import { Loc } from "./diagnostics.ts";

/**
 * Convenience wrapper around {@link Node.create}.
 *
 * @param loc Source location stored on the node.
 * @param tag Discriminant selecting which `NodeKind` variant to build.
 * @param kind The remaining fields of that variant (everything except `tag`).
 * @returns A freshly created node with a new id.
 */
export function create<Tag extends NodeKind["tag"]>(
  loc: Loc,
  tag: Tag,
  kind: Omit<NodeKind & { tag: Tag }, "tag">,
): Node {
  return Node.create(loc, tag, kind);
}

/**
 * A single syntax-tree node: an id, a source location and a tagged
 * `kind` payload. Instances are only created through {@link Node.create}
 * (the constructor is private).
 */
export class Node {
  /** Next id to hand out; incremented once per created node. */
  private static idCounter = 0;

  /**
   * Builds a node whose `kind` is `{ tag, ...kind }` and assigns it the
   * next value of the static id counter.
   *
   * @param loc Source location stored on the node.
   * @param tag Discriminant selecting which `NodeKind` variant to build.
   * @param kind The remaining fields of that variant (everything except `tag`).
   */
  static create<Tag extends NodeKind["tag"]>(
    loc: Loc,
    tag: Tag,
    kind: Omit<NodeKind & { tag: Tag }, "tag">,
  ): Node {
    return new Node(
      Node.idCounter++,
      loc,
      // The spread of a generic `Omit<...>` cannot be proven to rebuild the
      // variant, so the compiler needs this assertion.
      { tag, ...kind } as NodeKind & { tag: Tag },
    );
  }

  private constructor(
    public id: number,
    public loc: Loc,
    public kind: NodeKind,
  ) {}

  /**
   * Returns this node typed as the requested kind.
   *
   * @throws Error if `this.kind.tag` is not `tag`.
   */
  as<Tag extends NodeKind["tag"]>(tag: Tag): NodeWithKind<Tag> {
    this.assertIs(tag);
    return this;
  }

  /**
   * Asserts that this node has the given tag, narrowing its type.
   *
   * @throws Error if `this.kind.tag` is not `tag`.
   */
  assertIs<Tag extends NodeKind["tag"]>(
    tag: Tag,
  ): asserts this is NodeWithKind<Tag> {
    if (this.kind.tag !== tag) {
      throw new Error(
        `expected node of kind '${tag}', got '${this.kind.tag}'`,
      );
    }
  }

  /** Type guard: `true` exactly when `this.kind.tag === tag`. */
  is<Tag extends NodeKind["tag"]>(tag: Tag): this is NodeWithKind<Tag> {
    return this.kind.tag === tag;
  }

  /**
   * Calls `v.visit(this)` and then descends into the children, unless the
   * visitor returned `"break"`, in which case the children are skipped.
   */
  visit(v: Visitor) {
    if (v.visit(this) === "break") {
      return;
    }
    this.visitBelow(v);
  }

  /**
   * Visits the child nodes of this node (not the node itself), in the
   * order listed per variant below. `null` children are skipped.
   */
  visitBelow(v: Visitor) {
    const visit = (...nodes: (Node | null)[]) => {
      for (const node of nodes) {
        node?.visit(v);
      }
    };
    const k = this.kind;
    switch (k.tag) {
      case "Error":
        return visit();
      case "File":
        return visit(...k.stmts);
      case "Block":
        return visit(...k.stmts);
      case "ExprStmt":
        return visit(k.expr);
      case "AssignStmt":
        return visit(k.place, k.expr);
      case "FnStmt":
        return visit(...k.params, k.retTy, k.body);
      case "ReturnStmt":
        return visit(k.expr);
      case "LetStmt":
        return visit(k.param, k.expr);
      case "IfStmt":
        return visit(k.cond, k.truthy, k.falsy);
      case "WhileStmt":
        return visit(k.cond, k.body);
      case "BreakStmt":
        return visit();
      case "Param":
        return visit(k.ty);
      case "IdentExpr":
        return visit();
      case "IntExpr":
        return visit();
      case "ArrayExpr":
        return visit(...k.values);
      case "IndexExpr":
        return visit(k.value, k.arg);
      case "CallExpr":
        return visit(k.value, ...k.args);
      case "UnaryExpr":
        return visit(k.expr);
      case "BinaryExpr":
        return visit(k.left, k.right);
      case "RangeExpr":
        return visit(k.begin, k.end);
      case "IdentTy":
        return visit();
      case "PtrTy":
      case "PtrMutTy":
        return visit(k.ty);
      case "ArrayTy":
        return visit(k.ty, k.length);
      case "SliceTy":
        return visit(k.ty);
    }
    // Compile-time exhaustiveness check: every variant returned above.
    k satisfies never;
  }
}

/**
 * Payload of a {@link Node}: a discriminated union keyed on `tag`.
 * Fields typed `Node | null` are optional children.
 */
export type NodeKind =
  | { tag: "Error" }
  | { tag: "File"; stmts: Node[] }
  | { tag: "Block"; stmts: Node[] }
  | { tag: "ExprStmt"; expr: Node }
  | { tag: "AssignStmt"; place: Node; expr: Node }
  | {
    tag: "FnStmt";
    ident: string;
    params: Node[];
    retTy: Node | null;
    body: Node;
  }
  | { tag: "ReturnStmt"; expr: Node | null }
  | { tag: "LetStmt"; param: Node; expr: Node }
  | { tag: "IfStmt"; cond: Node; truthy: Node; falsy: Node | null }
  | { tag: "WhileStmt"; cond: Node; body: Node }
  | { tag: "BreakStmt" }
  | { tag: "Param"; ident: string; ty: Node | null }
  | { tag: "IdentExpr"; ident: string }
  | { tag: "IntExpr"; value: number }
  | { tag: "ArrayExpr"; values: Node[] }
  | { tag: "IndexExpr"; value: Node; arg: Node }
  | { tag: "CallExpr"; value: Node; args: Node[] }
  | { tag: "UnaryExpr"; op: UnaryOp; expr: Node; tok: string }
  | { tag: "BinaryExpr"; op: BinaryOp; left: Node; right: Node; tok: string }
  | {
    tag: "RangeExpr";
    begin: Node | null;
    end: Node | null;
    limit: RangeLimit;
  }
  | { tag: "IdentTy"; ident: string }
  | { tag: "PtrTy" | "PtrMutTy"; ty: Node }
  | { tag: "ArrayTy"; ty: Node; length: Node }
  | { tag: "SliceTy"; ty: Node };

/** Operators a `UnaryExpr` node may carry in its `op` field. */
export type UnaryOp =
  | "Not"
  | "Negate"
  | "Ref"
  | "RefMut"
  | "Deref";

/** Operators a `BinaryExpr` node may carry in its `op` field. */
export type BinaryOp =
  | "Or"
  | "And"
  | "Eq"
  | "Ne"
  | "Lt"
  | "Gt"
  | "Lte"
  | "Gte"
  | "BitOr"
  | "BitXor"
  | "BitAnd"
  | "Shl"
  | "Shr"
  | "Add"
  | "Subtract"
  | "Multiply"
  | "Divide"
  | "Remainder";

/** Value of the `limit` field of a `RangeExpr` node. */
export type RangeLimit = "Inclusive" | "Exclusive";

/**
 * Callback object for {@link Node.visit}. Returning `"break"` from
 * `visit` stops the traversal from descending into that node's children.
 */
export interface Visitor {
  visit(node: Node): void | "break";
}

/** A {@link Node} whose `kind.tag` is statically known to be `Tag`. */
export type NodeWithKind<
  Tag extends NodeKind["tag"],
> = Node & { kind: { tag: Tag } };

export type Block = NodeWithKind<"Block">;
export type FnStmt = NodeWithKind<"FnStmt">;
export type Param = NodeWithKind<"Param">;

/**
 * Free-function counterpart of {@link Node.assertIs}: asserts that `node`
 * has the given tag, narrowing its type.
 *
 * @throws Error if `node.kind.tag` is not `tag`.
 */
export function assertNodeWithKind<
  Tag extends NodeKind["tag"],
>(node: Node, tag: Tag): asserts node is NodeWithKind<Tag> {
  if (node.kind.tag !== tag) {
    throw new Error(
      `expected node of kind '${tag}', got '${node.kind.tag}'`,
    );
  }
}