import * as ast from "../ast.ts";
import { FileReporter, Loc } from "../diagnostics.ts";
import { Ty } from "../ty.ts";
import { Sym, Syms } from "./resolve.ts";

/**
 * Public entry point of the type checker.
 *
 * A thin facade over {@link CheckerCx}. Every query is memoized per AST node
 * (keyed by `node.id`), so repeated queries for the same node return the
 * cached type without re-checking it.
 */
export class Tys {
    constructor(
        private syms: Syms,
        private reporter: FileReporter,
    ) {
        this.cx = new CheckerCx(this.syms, this.reporter);
    }

    private cx: CheckerCx;

    /** Returns the `FnStmt` type of a function statement and checks its `return`s. */
    fnStmt(node: ast.NodeWithKind<"FnStmt">): Ty {
        return this.cx.fnStmt(node);
    }

    /** Returns the type of a function parameter or `let` binding. */
    param(node: ast.NodeWithKind<"Param">): Ty {
        return this.cx.param(node);
    }

    /**
     * Returns the type of a node evaluated as a place. Unlike {@link expr},
     * dereferences here are not rejected for unsized pointees.
     */
    place(node: ast.Node): Ty {
        return this.cx.place(node);
    }

    /** Returns the type of an expression node. */
    expr(node: ast.Node): Ty {
        return this.cx.expr(node);
    }

    /** Resolves a type-annotation node to a {@link Ty}. */
    ty(node: ast.Node): Ty {
        return this.cx.ty(node);
    }
}

/**
 * Shared checking context.
 *
 * Owns the per-node type cache and the diagnostics reporter, and dispatches
 * each kind of query to its dedicated checker. The cache is keyed only by
 * `node.id` and is shared between all query kinds.
 */
class CheckerCx {
    constructor(
        private syms: Syms,
        private reporter: FileReporter,
    ) {}

    private nodeTys = new Map<ast.Node["id"], Ty>();
    private stmtChecker = new StmtChecker(this);
    private paramChecker = new ParamChecker(this);
    private placeChecker = new PlaceChecker(this);
    private exprChecker = new ExprChecker(this);
    private tyChecker = new TyChecker(this);

    /**
     * Types a function statement.
     *
     * The signature is cached *before* the body is checked, so a function
     * whose body refers to itself (directly or mutually) resolves to its
     * signature instead of recursing forever.
     */
    fnStmt(node: ast.NodeWithKind<"FnStmt">): Ty {
        const cached = this.nodeTys.get(node.id);
        if (cached !== undefined) {
            return cached;
        }
        const ty = this.stmtChecker.checkFnSignature(node);
        this.nodeTys.set(node.id, ty);
        this.stmtChecker.checkFnBody(node);
        return ty;
    }

    param(node: ast.NodeWithKind<"Param">): Ty {
        return this.cache(node, () => this.paramChecker.checkParam(node));
    }

    place(node: ast.Node): Ty {
        return this.cache(node, () => this.placeChecker.checkPlace(node));
    }

    expr(node: ast.Node): Ty {
        return this.cache(node, () => this.exprChecker.checkExpr(node));
    }

    ty(node: ast.Node): Ty {
        return this.cache(node, () => this.tyChecker.checkTy(node));
    }

    /** Returns the cached type for `node`, computing and storing it on a miss. */
    private cache(node: ast.Node, action: () => Ty): Ty {
        const cached = this.nodeTys.get(node.id);
        if (cached !== undefined) {
            return cached;
        }
        const ty = action();
        this.nodeTys.set(node.id, ty);
        return ty;
    }

    error(loc: Loc, message: string) {
        this.reporter.error(loc, message);
    }

    info(loc: Loc, message: string) {
        this.reporter.info(loc, message);
    }

    /** Aborts checking via the reporter; never returns. */
    fail(): never {
        this.reporter.abort();
    }

    /** Looks up the symbol the resolver bound to `node`. */
    sym(node: ast.Node): Sym {
        return this.syms.get(node);
    }

    /** Reports an error at `loc` and aborts unless `left` is resolvable with `right`. */
    assertCompatible(left: Ty, right: Ty, loc: Loc): void {
        if (!left.resolvableWith(right)) {
            this.error(
                loc,
                `type '${left.pretty()}' not compatible with type '${right.pretty()}'`,
            );
            this.fail();
        }
    }
}

/** Checks statement-level constructs (currently function statements). */
class StmtChecker {
    constructor(
        private cx: CheckerCx,
    ) {}

    /** Builds the `FnStmt` type from the parameter types and declared return type. */
    checkFnSignature(stmt: ast.NodeWithKind<"FnStmt">): Ty {
        const k = stmt.kind;
        const params = k.params
            .map((param) => this.cx.param(param.as("Param")));
        // A missing return annotation means the function returns void.
        const retTy = k.retTy ? this.cx.ty(k.retTy) : Ty.Void;
        const ty = Ty.create("Fn", { params, retTy });
        return Ty.create("FnStmt", { stmt, ty });
    }

    /**
     * Checks every `return` statement reached by `body.visit` against the
     * declared return type; a bare `return` counts as `void`.
     *
     * NOTE(review): whether `visit` also descends into nested function
     * statements is not visible here — if it does, their returns would be
     * checked against this function's return type. Confirm.
     */
    checkFnBody(stmt: ast.NodeWithKind<"FnStmt">): void {
        const k = stmt.kind;
        // Already cached by checkFnSignature, so this does not re-check.
        const retTy = k.retTy ? this.cx.ty(k.retTy) : Ty.Void;
        k.body.visit({
            visit: (node) => {
                if (!node.is("ReturnStmt")) {
                    return;
                }
                const ty = node.kind.expr
                    ? this.cx.expr(node.kind.expr)
                    : Ty.Void;
                if (ty.resolvableWith(retTy)) {
                    return;
                }
                this.cx.error(
                    node.loc,
                    `type '${ty.pretty()}' not compatible with return type '${retTy.pretty()}'`,
                );
                this.cx.info(
                    k.retTy?.loc ?? stmt.loc,
                    `return type '${retTy.pretty()}' defined here`,
                );
                this.cx.fail();
            },
        });
    }
}

/** Types `Param` nodes, which are bound either by `let` or as function parameters. */
class ParamChecker {
    constructor(
        private cx: CheckerCx,
    ) {}

    /**
     * - `let` bindings take the type of their initializer; an explicit
     *   annotation must be compatible with it.
     * - Function parameters must carry an explicit type.
     *
     * @throws Error for symbol kinds not handled here (internal error).
     */
    checkParam(node: ast.NodeWithKind<"Param">): Ty {
        const sym = this.cx.sym(node);
        if (sym.tag === "Let") {
            const exprTy = this.cx.expr(sym.stmt.kind.expr);
            if (node.kind.ty) {
                const explicitTy = this.cx.ty(node.kind.ty);
                this.cx.assertCompatible(
                    exprTy,
                    explicitTy,
                    sym.stmt.kind.expr.loc,
                );
            }
            return exprTy;
        }
        if (sym.tag === "FnParam") {
            if (!node.kind.ty) {
                this.cx.error(node.loc, `parameter must have a type`);
                this.cx.fail();
            }
            return this.cx.ty(node.kind.ty);
        }
        throw new Error(`'${sym.tag}' not handled`);
    }
}

/** Types nodes in place position. */
class PlaceChecker {
    constructor(
        private cx: CheckerCx,
    ) {}

    /**
     * A dereference of a (mutable) pointer yields the pointee type without the
     * sized check applied to expressions; everything else is typed as an
     * ordinary expression.
     */
    checkPlace(node: ast.Node): Ty {
        if (node.is("UnaryExpr") && node.kind.op === "Deref") {
            const exprTy = this.checkPlace(node.kind.expr);
            if (exprTy.is("Ptr") || exprTy.is("PtrMut")) {
                return exprTy.kind.ty;
            }
        }
        return this.cx.expr(node);
    }
}

/** Types expression nodes. */
class ExprChecker {
    constructor(
        private cx: CheckerCx,
    ) {}

    /** @throws Error for expression kinds not handled here (internal error). */
    checkExpr(node: ast.Node): Ty {
        const tag = node.kind.tag;
        switch (tag) {
            case "IdentExpr":
                return this.checkIdentExpr(node.as(tag));
            case "IntExpr":
                return this.checkIntExpr(node.as(tag));
            case "StrExpr":
                // String literals are pointers to a byte slice.
                return Ty.create("Ptr", {
                    ty: Ty.create("Slice", { ty: Ty.U8 }),
                });
            case "ArrayExpr":
                return this.checkArrayExpr(node.as(tag));
            case "IndexExpr":
                return this.checkIndexExpr(node.as(tag));
            case "CallExpr":
                return this.checkCallExpr(node.as(tag));
            case "UnaryExpr":
                return this.checkUnaryExpr(node.as(tag));
            case "BinaryExpr":
                return this.checkBinaryExpr(node.as(tag));
            case "RangeExpr":
                return this.checkRangeExpr(node.as(tag));
            default:
                throw new Error(`'${node.kind.tag}' not handled`);
        }
    }

    private checkIdentExpr(node: ast.NodeWithKind<"IdentExpr">): Ty {
        const sym = this.cx.sym(node);
        if (sym.tag === "Fn") {
            return this.cx.fnStmt(sym.stmt);
        }
        if (sym.tag === "Bool") {
            return Ty.Bool;
        }
        if (sym.tag === "Builtin") {
            // Builtins may only appear as the callee of a call expression.
            this.cx.error(node.loc, `invalid use of builtin '${sym.id}'`);
            this.cx.fail();
        }
        // Parameter and `let` symbols point at a `Param` node, which must be
        // typed by the param checker; the expression checker has no `Param`
        // case and would throw on a cache miss.
        if (sym.tag === "FnParam") {
            return this.cx.param(sym.param.as("Param"));
        }
        if (sym.tag === "Let") {
            return this.cx.param(sym.param.as("Param"));
        }
        throw new Error(`'${sym.tag}' not handled`);
    }

    private checkIntExpr(node: ast.NodeWithKind<"IntExpr">): Ty {
        switch (node.kind.intTy) {
            case "u8":
                return Ty.U8;
            case "u16":
                return Ty.U16;
            case "u32":
                return Ty.U32;
            case "u64":
                return Ty.U64;
            case "usize":
                return Ty.USize;
            case "i8":
                return Ty.I8;
            case "i16":
                return Ty.I16;
            case "i32":
                return Ty.I32;
            case "i64":
                return Ty.I64;
            case "isize":
                return Ty.ISize;
            case "any":
                // Unsuffixed literals default to i32.
                return Ty.I32;
            default:
                throw new Error(`intType '${node.kind.intTy}' not handled`);
        }
    }

    /** All elements must be compatible with the first; empty arrays are rejected. */
    private checkArrayExpr(node: ast.NodeWithKind<"ArrayExpr">): Ty {
        let ty: Ty | null = null;
        for (const value of node.kind.values) {
            const valueTy = this.cx.expr(value);
            if (ty) {
                this.cx.assertCompatible(ty, valueTy, value.loc);
            } else {
                ty = valueTy;
            }
        }
        if (!ty) {
            this.cx.error(node.loc, `could not infer type of empty array`);
            this.cx.fail();
        }
        const length = node.kind.values.length;
        return Ty.create("Array", { ty, length });
    }

    /** Indexing an array/slice by an integer yields an element; by a range, a slice. */
    private checkIndexExpr(node: ast.NodeWithKind<"IndexExpr">): Ty {
        const exprTy = this.cx.place(node.kind.value);
        const argTy = this.cx.expr(node.kind.arg);
        if (exprTy.is("Array") || exprTy.is("Slice")) {
            if (argTy.resolvableWith(Ty.I32)) {
                return exprTy.kind.ty;
            }
            if (argTy.resolvableWith(Ty.create("Range", {}))) {
                return Ty.create("Slice", { ty: exprTy.kind.ty });
            }
        }
        this.cx.error(
            node.loc,
            `cannot use index operator on '${exprTy.pretty()}' with '${argTy.pretty()}'`,
        );
        this.cx.fail();
    }

    /**
     * Builtin callees are handled separately; otherwise the callee must be a
     * `Fn` or `FnStmt` type, and the arguments must match its parameters in
     * count and type.
     */
    private checkCallExpr(node: ast.NodeWithKind<"CallExpr">): Ty {
        if (node.kind.value.is("IdentExpr")) {
            const sym = this.cx.sym(node.kind.value);
            if (sym.tag === "Builtin") {
                return this.checkCallExprBuiltin(node, sym);
            }
        }
        const calleeTy = this.cx.expr(node.kind.value);
        const callableTy = calleeTy.is("Fn")
            ? calleeTy
            : calleeTy.is("FnStmt")
            ? calleeTy.kind.ty as Ty & { kind: { tag: "Fn" } }
            : null;
        if (!callableTy) {
            this.cx.error(
                node.loc,
                `type '${calleeTy.pretty()}' not callable`,
            );
            this.cx.fail();
        }
        const args = node.kind.args
            .map((arg) => this.cx.expr(arg));
        const params = callableTy.kind.params;
        if (args.length !== params.length) {
            this.reportArgsIncorrectAmount(
                node,
                args.length,
                params.length,
                calleeTy,
            );
        }
        for (const i of args.keys()) {
            if (!args[i].resolvableWith(params[i])) {
                this.reportArgTypeNotCompatible(
                    node,
                    args,
                    params,
                    calleeTy,
                    i,
                );
            }
        }
        return callableTy.kind.retTy;
    }

    /**
     * Types calls to builtins:
     * - `len(x)` takes exactly one array, or pointer to array/slice, and returns i32.
     * - `print(...)` accepts any arguments and returns void.
     *
     * @throws Error if called with a non-identifier callee or non-builtin
     *   symbol, or for an unknown builtin (internal errors).
     */
    private checkCallExprBuiltin(
        node: ast.NodeWithKind<"CallExpr">,
        sym: Sym,
    ): Ty {
        if (!node.kind.value.is("IdentExpr")) {
            throw new Error("builtin call must have an identifier callee");
        }
        if (sym.tag !== "Builtin") {
            throw new Error(`expected builtin symbol, got '${sym.tag}'`);
        }
        if (sym.id === "len") {
            if (node.kind.args.length !== 1) {
                this.reportArgsIncorrectAmount(
                    node,
                    node.kind.args.length,
                    1,
                    null,
                );
            }
            const argTy = this.cx.expr(node.kind.args[0]);
            const isLenCompatible = argTy.is("Array") ||
                (argTy.is("Ptr") &&
                    (argTy.kind.ty.is("Array") || argTy.kind.ty.is("Slice")));
            if (!isLenCompatible) {
                this.reportArgTypeNotCompatible(
                    node,
                    [argTy],
                    [Ty.Error],
                    null,
                    0,
                );
            }
            return Ty.I32;
        }
        if (sym.id === "print") {
            // Arguments are still type-checked even though any type is accepted.
            for (const arg of node.kind.args) {
                this.cx.expr(arg);
            }
            return Ty.Void;
        }
        throw new Error(`builtin '${sym.id}' not handled`);
    }

    private checkUnaryExpr(node: ast.NodeWithKind<"UnaryExpr">): Ty {
        const exprTy = this.cx.expr(node.kind.expr);
        if (node.kind.op === "Neg" && exprTy.resolvableWith(Ty.I32)) {
            return Ty.I32;
        }
        if (node.kind.op === "Not" && exprTy.resolvableWith(Ty.Bool)) {
            return Ty.Bool;
        }
        if (node.kind.op === "Ref") {
            return Ty.create("Ptr", { ty: exprTy });
        }
        if (node.kind.op === "RefMut") {
            return Ty.create("PtrMut", { ty: exprTy });
        }
        if (
            node.kind.op === "Deref" &&
            (exprTy.is("Ptr") || exprTy.is("PtrMut"))
        ) {
            // In expression position the pointee is read by value, so it must be sized.
            if (!exprTy.kind.ty.isSized()) {
                this.cx.error(
                    node.loc,
                    `cannot dereference unsized type '${exprTy.kind.ty.pretty()}' in an expression`,
                );
                this.cx.fail();
            }
            return exprTy.kind.ty;
        }
        this.cx.error(
            node.loc,
            `operator '${node.kind.tok}' cannot be applied to type '${exprTy.pretty()}'`,
        );
        this.cx.fail();
    }

    /** The first entry of `binaryOpTests` that accepts the operands decides the type. */
    private checkBinaryExpr(node: ast.NodeWithKind<"BinaryExpr">): Ty {
        const left = this.cx.expr(node.kind.left);
        const right = this.cx.expr(node.kind.right);
        let result: Ty | null = null;
        for (const test of binaryOpTests) {
            result = test(node.kind.op, left, right);
            if (result) {
                break;
            }
        }
        if (!result) {
            this.cx.error(
                node.loc,
                `operator '${node.kind.tok}' cannot be applied to types '${left.pretty()}' and '${right.pretty()}'`,
            );
            this.cx.fail();
        }
        return result;
    }

    /** Each present range bound must be compatible with i32. */
    private checkRangeExpr(node: ast.NodeWithKind<"RangeExpr">): Ty {
        for (const operandExpr of [node.kind.begin, node.kind.end]) {
            if (!operandExpr) {
                continue;
            }
            const operandTy = this.cx.expr(operandExpr);
            if (!operandTy.resolvableWith(Ty.I32)) {
                this.cx.error(
                    operandExpr.loc,
                    `range operand must be '${Ty.I32.pretty()}', not '${operandTy.pretty()}'`,
                );
                this.cx.fail();
            }
        }
        return Ty.create("Range", {});
    }

    /** Reports an arity mismatch (pointing at the definition when known) and aborts. */
    private reportArgsIncorrectAmount(
        node: ast.NodeWithKind<"CallExpr">,
        argsLength: number,
        paramsLength: number,
        calleeTy: Ty | null,
    ): never {
        this.cx.error(
            node.loc,
            `incorrect amount of arguments. got ${argsLength} expected ${paramsLength}`,
        );
        if (calleeTy?.is("FnStmt")) {
            this.cx.info(
                calleeTy.kind.stmt.loc,
                "function defined here",
            );
        }
        this.cx.fail();
    }

    /** Reports argument `i` (0-based) as incompatible with its parameter and aborts. */
    private reportArgTypeNotCompatible(
        node: ast.NodeWithKind<"CallExpr">,
        args: Ty[],
        params: Ty[],
        calleeTy: Ty | null,
        i: number,
    ): never {
        this.cx.error(
            node.kind.args[i].loc,
            `type '${args[i].pretty()}' not compatible with type '${
                params[i].pretty()
            }', for argument ${i}`,
        );
        if (calleeTy?.is("FnStmt")) {
            const param = calleeTy.kind.stmt.kind.params[i];
            this.cx.info(
                param.loc,
                `parameter '${param.as("Param").kind.ident}' defined here`,
            );
        }
        this.cx.fail();
    }
}

/** Resolves type-annotation nodes to {@link Ty} values. */
class TyChecker {
    constructor(
        private cx: CheckerCx,
    ) {}

    /** @throws Error for type-node kinds or builtin names not handled here (internal error). */
    checkTy(node: ast.Node): Ty {
        if (node.is("IdentTy")) {
            const sym = this.cx.sym(node);
            if (sym.tag === "BuiltinTy") {
                switch (sym.ident) {
                    case "void":
                        return Ty.Void;
                    case "bool":
                        return Ty.Bool;
                    case "i8":
                        return Ty.I8;
                    case "i16":
                        return Ty.I16;
                    case "i32":
                        return Ty.I32;
                    case "i64":
                        return Ty.I64;
                    case "isize":
                        return Ty.ISize;
                    case "u8":
                        return Ty.U8;
                    case "u16":
                        return Ty.U16;
                    case "u32":
                        return Ty.U32;
                    case "u64":
                        return Ty.U64;
                    case "usize":
                        return Ty.USize;
                    default:
                        throw new Error(
                            `unknown type '${node.kind.ident}'`,
                        );
                }
            }
            this.cx.error(node.loc, `symbol is not a type`);
            this.cx.fail();
        }
        if (node.is("PtrTy")) {
            const ty = this.cx.ty(node.kind.ty);
            return Ty.create("Ptr", { ty });
        }
        if (node.is("PtrMutTy")) {
            const ty = this.cx.ty(node.kind.ty);
            return Ty.create("PtrMut", { ty });
        }
        if (node.is("ArrayTy")) {
            const ty = this.cx.ty(node.kind.ty);
            const lengthTy = this.cx.expr(node.kind.length);
            if (!lengthTy.resolvableWith(Ty.I32)) {
                this.cx.error(
                    node.kind.length.loc,
                    `for array length, expected 'int', got '${lengthTy.pretty()}'`,
                );
                this.cx.fail();
            }
            // The length must be a literal so it is known while type checking.
            if (!node.kind.length.is("IntExpr")) {
                this.cx.error(
                    node.kind.length.loc,
                    `array length must be an 'int' expression`,
                );
                this.cx.fail();
            }
            const length = node.kind.length.kind.value;
            return Ty.create("Array", { ty, length });
        }
        if (node.is("SliceTy")) {
            const ty = this.cx.ty(node.kind.ty);
            return Ty.create("Slice", { ty });
        }
        throw new Error(`'${node.kind.tag}' not handled`);
    }
}

/** Returns the result type if the test accepts `op` on the operands, else null. */
type BinaryOpTest = (op: ast.BinaryOp, left: Ty, right: Ty) => Ty | null;

/** Binary operator rules, tried in order by `ExprChecker.checkBinaryExpr`. */
const binaryOpTests: BinaryOpTest[] = [
    // Arithmetic on integers yields the left operand's type.
    (op, left, right) => {
        const ops: ast.BinaryOp[] = [
            "Add",
            "Sub",
            "Mul",
            "Div",
            "Rem",
        ];
        if (
            ops.includes(op) &&
            left.is("Int") &&
            left.resolvableWith(right)
        ) {
            return left;
        }
        return null;
    },
    // Comparisons between integers yield bool.
    (op, left, right) => {
        const ops: ast.BinaryOp[] = ["Eq", "Ne", "Lt", "Gt", "Lte", "Gte"];
        if (
            ops.includes(op) &&
            left.is("Int") &&
            left.resolvableWith(right)
        ) {
            return Ty.Bool;
        }
        return null;
    },
];