import * as ast from "../ast.ts";
import { FileReporter, Loc } from "../diagnostics.ts";
import { Ty } from "../ty.ts";
import { Syms } from "./resolve.ts";

/**
 * Type-checking entry point with per-node memoization.
 *
 * Results are cached by `node.id` in one map shared by `fnStmt`, `place` and
 * `expr`, so a node's type is computed at most once, however it is first
 * requested. The typing rules live in `Checker`, which calls back into this
 * class for sub-nodes so that their types are cached as well.
 */
export class Tys {
    private readonly nodeTys = new Map<ast.Node["id"], Ty>();
    private readonly checker: Checker;

    constructor(
        private readonly syms: Syms,
        private readonly reporter: FileReporter,
    ) {
        this.checker = new Checker(this, this.syms, this.reporter);
    }

    /**
     * Type of a function declaration: a `FnStmt` type wrapping its `Fn`
     * signature. The first call also checks the body's `return` statements.
     *
     * The signature is cached *before* the body is checked. A body that
     * refers back to this function (directly, or through mutual recursion)
     * then gets the cached signature instead of re-entering this method
     * forever.
     */
    fnStmt(node: ast.NodeWithKind<"FnStmt">): Ty {
        const cached = this.nodeTys.get(node.id);
        if (cached !== undefined) {
            return cached;
        }
        const ty = this.checker.checkFnSignature(node);
        this.nodeTys.set(node.id, ty);
        this.checker.checkFnBody(node);
        return ty;
    }

    /** Type of `node` used as a place (e.g. the value operand of `[]`). */
    place(node: ast.Node): Ty {
        return this.memo(node, () => this.checker.checkPlace(node));
    }

    /** Type of `node` used as a value expression or as a type annotation. */
    expr(node: ast.Node): Ty {
        return this.memo(node, () => this.checker.checkExpr(node));
    }

    /** Returns the cached type of `node`, computing and storing it on a miss. */
    private memo(node: ast.Node, compute: () => Ty): Ty {
        const cached = this.nodeTys.get(node.id);
        if (cached !== undefined) {
            return cached;
        }
        const ty = compute();
        this.nodeTys.set(node.id, ty);
        return ty;
    }
}

/**
 * The typing rules. Only reachable through `Tys`. It asks `Tys` for the
 * types of sub-nodes so that those are memoized.
 *
 * Diagnostics go to `reporter`. Each error is followed by `fail()`, which
 * calls `reporter.abort()` and is declared `never`.
 */
class Checker {
    constructor(
        private readonly tys: Tys,
        private readonly syms: Syms,
        private readonly reporter: FileReporter,
    ) {}

    /**
     * Builds the `FnStmt` type of a function from its parameter and return
     * type annotations alone, without looking at the body. A missing return
     * annotation means `void`.
     */
    checkFnSignature(stmt: ast.NodeWithKind<"FnStmt">): Ty {
        const k = stmt.kind;
        const params = k.params.map((param) => this.tys.expr(param));
        const retTy = k.retTy ? this.tys.expr(k.retTy) : Ty.Void;
        const ty = Ty.create("Fn", { params, retTy });
        return Ty.create("FnStmt", { stmt, ty });
    }

    /**
     * Checks every `return` reached by the body visitor against the declared
     * return type (`void` when none is declared). A bare `return` counts as
     * `void`.
     *
     * NOTE(review): whether `visit` also descends into nested function
     * bodies depends on the AST visitor. If it does, their returns are
     * checked against this function's return type too — confirm.
     */
    checkFnBody(stmt: ast.NodeWithKind<"FnStmt">): void {
        const retTyNode = stmt.kind.retTy;
        const retTy = retTyNode ? this.tys.expr(retTyNode) : Ty.Void;
        stmt.kind.body.visit({
            visit: (node) => {
                if (!node.is("ReturnStmt")) {
                    return;
                }
                const ty = node.kind.expr
                    ? this.tys.expr(node.kind.expr)
                    : Ty.Void;
                if (!ty.compatibleWith(retTy)) {
                    this.error(
                        node.loc,
                        `type '${ty.pretty()}' not compatible with return type '${retTy.pretty()}'`,
                    );
                    this.info(
                        retTyNode?.loc ?? stmt.loc,
                        `return type '${retTy.pretty()}' defined here`,
                    );
                    this.fail();
                }
            },
        });
    }

    /**
     * Type of `node` in place position.
     *
     * A dereference of a `Ptr`/`PtrMut` yields the pointee type here without
     * the `isSized()` check that `checkExpr` applies to dereferences.
     * Everything else, including a dereference of a non-pointer, is typed as
     * an ordinary expression.
     */
    checkPlace(node: ast.Node): Ty {
        if (node.is("UnaryExpr") && node.kind.op === "Deref") {
            const innerTy = this.checkPlace(node.kind.expr);
            if (innerTy.is("Ptr") || innerTy.is("PtrMut")) {
                return innerTy.kind.ty;
            }
        }
        return this.checkExpr(node);
    }

    /**
     * Type of `node` as a value expression, binding, or type annotation.
     * Dispatches on the node kind.
     *
     * @throws Error for a node kind this checker does not handle (an
     *   internal error, not a user diagnostic).
     */
    checkExpr(node: ast.Node): Ty {
        // Captured before the `is` guards below narrow `node`.
        const tag = node.kind.tag;

        // Bindings and value expressions.
        if (node.is("Param")) {
            return this.checkParam(node);
        }
        if (node.is("IdentExpr")) {
            return this.checkIdentExpr(node);
        }
        if (node.is("IntExpr")) {
            return Ty.Int;
        }
        if (node.is("ArrayExpr")) {
            return this.checkArrayExpr(node);
        }
        if (node.is("IndexExpr")) {
            return this.checkIndexExpr(node);
        }
        if (node.is("CallExpr")) {
            return this.checkCall(node);
        }
        if (node.is("UnaryExpr")) {
            return this.checkUnaryExpr(node);
        }
        if (node.is("BinaryExpr")) {
            return this.checkBinaryExpr(node);
        }
        if (node.is("RangeExpr")) {
            return this.checkRangeExpr(node);
        }

        // Type annotations.
        if (node.is("IdentTy")) {
            return this.checkIdentTy(node);
        }
        if (node.is("PtrTy")) {
            return Ty.create("Ptr", { ty: this.tys.expr(node.kind.ty) });
        }
        if (node.is("PtrMutTy")) {
            return Ty.create("PtrMut", { ty: this.tys.expr(node.kind.ty) });
        }
        if (node.is("ArrayTy")) {
            return this.checkArrayTy(node);
        }
        if (node.is("SliceTy")) {
            return Ty.create("Slice", { ty: this.tys.expr(node.kind.ty) });
        }

        throw new Error(`'${tag}' not handled`);
    }

    /**
     * Type of a binding.
     *
     * - A `let` binding takes its initializer's type, which must be
     *   compatible with the annotation when one is present.
     * - A function parameter must be annotated and takes that type.
     */
    private checkParam(node: ast.NodeWithKind<"Param">): Ty {
        const sym = this.syms.get(node);
        if (sym.tag === "Let") {
            const exprTy = this.tys.expr(sym.stmt.kind.expr);
            if (node.kind.ty) {
                const explicitTy = this.tys.expr(node.kind.ty);
                this.assertCompatible(
                    exprTy,
                    explicitTy,
                    sym.stmt.kind.expr.loc,
                );
            }
            return exprTy;
        }
        if (sym.tag === "FnParam") {
            if (!node.kind.ty) {
                this.error(node.loc, `parameter must have a type`);
                this.fail();
            }
            return this.tys.expr(node.kind.ty);
        }
        throw new Error(`'${sym.tag}' not handled`);
    }

    /** Type of a name reference, looked up through the resolved symbol. */
    private checkIdentExpr(node: ast.NodeWithKind<"IdentExpr">): Ty {
        const sym = this.syms.get(node);
        if (sym.tag === "Fn") {
            return this.tys.fnStmt(sym.stmt);
        }
        if (sym.tag === "Bool") {
            return Ty.Bool;
        }
        if (sym.tag === "Builtin") {
            // Builtins are only valid as a call target (see `checkCall`).
            this.error(node.loc, `invalid use of builtin '${sym.id}'`);
            this.fail();
        }
        if (sym.tag === "FnParam") {
            return this.tys.expr(sym.param);
        }
        if (sym.tag === "Let") {
            return this.tys.expr(sym.param);
        }
        throw new Error(`'${sym.tag}' not handled`);
    }

    /**
     * Type of an array literal. The element type comes from the first
     * element, and all later elements must be compatible with it. An empty
     * literal is rejected because its element type cannot be inferred.
     */
    private checkArrayExpr(node: ast.NodeWithKind<"ArrayExpr">): Ty {
        let elemTy: Ty | null = null;
        for (const value of node.kind.values) {
            const valueTy = this.tys.expr(value);
            if (elemTy) {
                this.assertCompatible(elemTy, valueTy, value.loc);
            } else {
                elemTy = valueTy;
            }
        }
        if (!elemTy) {
            this.error(node.loc, `could not infer type of empty array`);
            this.fail();
        }
        const length = node.kind.values.length;
        return Ty.create("Array", { ty: elemTy, length });
    }

    /**
     * Type of `value[arg]` on an array or slice:
     * - an `int` index yields the element type;
     * - a range yields a slice of the element type.
     */
    private checkIndexExpr(node: ast.NodeWithKind<"IndexExpr">): Ty {
        const exprTy = this.tys.place(node.kind.value);
        const argTy = this.tys.expr(node.kind.arg);
        if (exprTy.is("Array") || exprTy.is("Slice")) {
            if (argTy.compatibleWith(Ty.Int)) {
                return exprTy.kind.ty;
            }
            if (argTy.compatibleWith(Ty.create("Range", {}))) {
                return Ty.create("Slice", { ty: exprTy.kind.ty });
            }
        }
        this.error(
            node.loc,
            `cannot use index operator on '${exprTy.pretty()}' with '${argTy.pretty()}'`,
        );
        this.fail();
    }

    /**
     * Type of a prefix operator application.
     *
     * In value position, a dereference is rejected when the pointee is not
     * `isSized()`. `checkPlace` handles dereferences in place position.
     */
    private checkUnaryExpr(node: ast.NodeWithKind<"UnaryExpr">): Ty {
        const exprTy = this.tys.expr(node.kind.expr);
        switch (node.kind.op) {
            case "Negate":
                if (exprTy.compatibleWith(Ty.Int)) {
                    return Ty.Int;
                }
                break;
            case "Not":
                if (exprTy.compatibleWith(Ty.Bool)) {
                    return Ty.Bool;
                }
                break;
            case "Ref":
                return Ty.create("Ptr", { ty: exprTy });
            case "RefMut":
                return Ty.create("PtrMut", { ty: exprTy });
            case "Deref":
                if (exprTy.is("Ptr") || exprTy.is("PtrMut")) {
                    if (!exprTy.kind.ty.isSized()) {
                        this.error(
                            node.loc,
                            `cannot dereference unsized type '${exprTy.kind.ty.pretty()}' in an expression`,
                        );
                        this.fail();
                    }
                    return exprTy.kind.ty;
                }
                break;
        }
        this.error(
            node.loc,
            `operator '${node.kind.tok}' cannot be applied to type '${exprTy.pretty()}'`,
        );
        this.fail();
    }

    /**
     * Type of a binary operator application. The result comes from the
     * first entry in `binaryOpPatterns` whose operator matches and whose
     * operand types are compatible.
     */
    private checkBinaryExpr(node: ast.NodeWithKind<"BinaryExpr">): Ty {
        const left = this.tys.expr(node.kind.left);
        const right = this.tys.expr(node.kind.right);
        const binaryOp = binaryOpPatterns.find((pat) =>
            pat.op === node.kind.op &&
            left.compatibleWith(pat.left) &&
            right.compatibleWith(pat.right)
        );
        if (!binaryOp) {
            this.error(
                node.loc,
                `operator '${node.kind.tok}' cannot be applied to types '${left.pretty()}' and '${right.pretty()}'`,
            );
            this.fail();
        }
        return binaryOp.result;
    }

    /** Type of a range. Each bound that is present must be an `int`. */
    private checkRangeExpr(node: ast.NodeWithKind<"RangeExpr">): Ty {
        for (const operand of [node.kind.begin, node.kind.end]) {
            if (!operand) {
                continue;
            }
            const operandTy = this.tys.expr(operand);
            if (!operandTy.compatibleWith(Ty.Int)) {
                this.error(
                    operand.loc,
                    `range operand must be '${Ty.Int.pretty()}', not '${operandTy.pretty()}'`,
                );
                this.fail();
            }
        }
        return Ty.create("Range", {});
    }

    /** Type named by a bare identifier: `void`, `int` or `bool`. */
    private checkIdentTy(node: ast.NodeWithKind<"IdentTy">): Ty {
        switch (node.kind.ident) {
            case "void":
                return Ty.Void;
            case "int":
                return Ty.Int;
            case "bool":
                return Ty.Bool;
            default:
                this.error(node.loc, `unknown type '${node.kind.ident}'`);
                // Abort here. Without this, control fell through to the
                // internal "not handled" error below the dispatch in
                // `checkExpr`.
                this.fail();
        }
    }

    /**
     * Type of an array annotation. The length must be an `int` literal,
     * because its value is read directly from the node.
     */
    private checkArrayTy(node: ast.NodeWithKind<"ArrayTy">): Ty {
        const ty = this.tys.expr(node.kind.ty);
        const lengthNode = node.kind.length;
        const lengthTy = this.tys.expr(lengthNode);
        if (!lengthTy.compatibleWith(Ty.Int)) {
            this.error(
                lengthNode.loc,
                `for array length, expected 'int', got '${lengthTy.pretty()}'`,
            );
            this.fail();
        }
        if (!lengthNode.is("IntExpr")) {
            this.error(
                lengthNode.loc,
                `array length must be an 'int' expression`,
            );
            this.fail();
        }
        const length = lengthNode.kind.value;
        return Ty.create("Array", { ty, length });
    }

    /**
     * Type of a call.
     *
     * - The `debug_print` builtin accepts any arguments (each is still
     *   type-checked) and returns `void`.
     * - Otherwise the callee must be an `Fn`, or a `FnStmt` wrapping one.
     *   Argument count and types must match its parameters.
     */
    private checkCall(node: ast.NodeWithKind<"CallExpr">): Ty {
        if (node.kind.value.is("IdentExpr")) {
            const sym = this.syms.get(node.kind.value);
            if (sym && sym.tag === "Builtin" && sym.id === "debug_print") {
                for (const arg of node.kind.args) {
                    this.tys.expr(arg);
                }
                return Ty.Void;
            }
        }

        const calleeTy = this.tys.expr(node.kind.value);
        const fnTy = calleeTy.is("FnStmt") ? calleeTy.kind.ty : calleeTy;
        if (!fnTy.is("Fn")) {
            this.error(
                node.loc,
                `type '${calleeTy.pretty()}' not callable`,
            );
            this.fail();
        }

        const args = node.kind.args.map((arg) => this.tys.expr(arg));
        const params = fnTy.kind.params;
        if (args.length !== params.length) {
            this.error(
                node.loc,
                `incorrect amount of arguments. got ${args.length} expected ${params.length}`,
            );
            if (calleeTy.is("FnStmt")) {
                this.info(calleeTy.kind.stmt.loc, "function defined here");
            }
            this.fail();
        }
        for (const [i, argTy] of args.entries()) {
            if (!argTy.compatibleWith(params[i])) {
                this.error(
                    node.kind.args[i].loc,
                    `type '${argTy.pretty()}' not compatible with type '${
                        params[i].pretty()
                    }', for argument ${i}`,
                );
                if (calleeTy.is("FnStmt")) {
                    const param = calleeTy.kind.stmt.kind.params[i];
                    this.info(
                        param.loc,
                        `parameter '${
                            param.as("Param").kind.ident
                        }' defined here`,
                    );
                }
                this.fail();
            }
        }
        return fnTy.kind.retTy;
    }

    /** Reports an error at `loc` and aborts unless `left` is compatible with `right`. */
    private assertCompatible(left: Ty, right: Ty, loc: Loc): void {
        if (!left.compatibleWith(right)) {
            this.error(
                loc,
                `type '${left.pretty()}' not compatible with type '${right.pretty()}'`,
            );
            this.fail();
        }
    }

    private error(loc: Loc, message: string) {
        this.reporter.error(loc, message);
    }

    private info(loc: Loc, message: string) {
        this.reporter.info(loc, message);
    }

    /** Aborts checking via the reporter. Never returns. */
    private fail(): never {
        this.reporter.abort();
    }
}

/** One accepted operand-type combination for a binary operator, and its result. */
type BinaryOpPattern = {
    op: ast.BinaryOp;
    left: Ty;
    right: Ty;
    result: Ty;
};

/**
 * Accepted binary operator signatures, searched in order by
 * `Checker.checkBinaryExpr`. All operators here take `int` operands:
 * arithmetic returns `int`, comparisons return `bool`.
 */
const binaryOpPatterns: readonly BinaryOpPattern[] = [
    { op: "Add", left: Ty.Int, right: Ty.Int, result: Ty.Int },
    { op: "Subtract", left: Ty.Int, right: Ty.Int, result: Ty.Int },
    { op: "Multiply", left: Ty.Int, right: Ty.Int, result: Ty.Int },
    { op: "Divide", left: Ty.Int, right: Ty.Int, result: Ty.Int },
    { op: "Remainder", left: Ty.Int, right: Ty.Int, result: Ty.Int },
    { op: "Eq", left: Ty.Int, right: Ty.Int, result: Ty.Bool },
    { op: "Ne", left: Ty.Int, right: Ty.Int, result: Ty.Bool },
    { op: "Lt", left: Ty.Int, right: Ty.Int, result: Ty.Bool },
    { op: "Gt", left: Ty.Int, right: Ty.Int, result: Ty.Bool },
    { op: "Lte", left: Ty.Int, right: Ty.Int, result: Ty.Bool },
    { op: "Gte", left: Ty.Int, right: Ty.Int, result: Ty.Bool },
];