import { Syms } from "./resolve.ts";
import * as ast from "../ast.ts";
import { Ty } from "../ty.ts";
import { FileReporter, Loc } from "../diagnostics.ts";
import * as stringify from "../stringify.ts";

/**
 * Entry point for type checking functions.
 *
 * Results are memoized per function statement, so a function that is
 * referenced from several call sites is only checked once.
 */
export class Checker {
    private checkedFns = new Map<ast.FnStmt, CheckedFn>();

    constructor(
        private syms: Syms,
        private reporter: FileReporter,
    ) {}

    /**
     * Type checks `fn` (or returns the memoized result) and verifies that no
     * checker-internal type is left on any of its nodes.
     */
    checkFn(fn: ast.FnStmt): CheckedFn {
        const existing = this.checkedFns.get(fn);
        if (existing) {
            return existing;
        }
        // NOTE(review): the cache is populated only after `check()` returns,
        // so a function whose body references itself re-enters here without a
        // cache hit — confirm whether recursion is meant to be supported.
        const checkedFn = new TypeChecker(this, fn, this.syms, this.reporter)
            .check();
        checkedFn.checkForCheckerInternalTys(this.reporter);
        this.checkedFns.set(fn, checkedFn);
        return checkedFn;
    }
}

/** The outcome of checking one function: its type and the type of each node. */
export class CheckedFn {
    constructor(
        private fnTy: Ty,
        private nodeTys: Map<ast.Node, Ty>,
    ) {}

    /** The type of the function itself. */
    ty(): Ty {
        return this.fnTy;
    }

    /**
     * The type recorded for `expr`.
     *
     * @throws Error if no type was recorded for the node.
     */
    exprTy(expr: ast.Node): Ty {
        const ty = this.nodeTys.get(expr);
        if (ty === undefined) {
            throw new Error(`no type for '${expr.kind.tag}'`);
        }
        return ty;
    }

    /**
     * Reports an error and aborts if any node still carries a type for which
     * `isCheckerInternal()` is true.
     */
    checkForCheckerInternalTys(reporter: FileReporter) {
        for (const [node, ty] of this.nodeTys) {
            if (ty.isCheckerInternal()) {
                reporter.error(
                    node.loc,
                    `concrete type must be known at this point, got temporary type '${ty.pretty()}'`,
                );
                reporter.abort();
            }
        }
    }
}

/** Checks a single function and records a type for every visited node. */
class TypeChecker {
    private nodeTys = new Map<ast.Node, Ty>();

    private generics?: Ty[];
    private params!: Ty[];

    constructor(
        private cx: Checker,
        private fn: ast.FnStmt,
        private syms: Syms,
        private reporter: FileReporter,
    ) {}

    /**
     * Resolves the signature, checks every statement in the body, defaults
     * the remaining untyped integers and returns the result.
     */
    check(): CheckedFn {
        const generics = this.fn.kind.genericParams
            ?.map((_node, idx) => Ty.Generic(idx));
        this.generics = generics;

        const params = this.fn.kind.params.map((node) => {
            const param = node.as("Param");
            if (!param.kind.ty) {
                this.reporter.error(param.loc, `parameter must have a type`);
                this.reporter.abort();
            }
            return this.ty(param.kind.ty);
        });
        this.params = params;

        const retTy = this.fn.kind.retTy
            ? this.ty(this.fn.kind.retTy)
            : Ty.Void;

        this.fn.kind.body.visit({
            visit: (node) => {
                this.checkStmt(node, retTy);
            },
        });

        this.convertAnyIntToI32();

        const ty = Ty.create("FnStmt", {
            stmt: this.fn,
            ty: Ty.Fn(params, retTy, generics ?? null),
        });

        return new CheckedFn(ty, this.nodeTys);
    }

    /**
     * Checks one visited node. Only statement nodes do any work; all other
     * node kinds are typed on demand from the statement that contains them.
     */
    private checkStmt(node: ast.Node, retTy: Ty): void {
        const k = node.kind;
        switch (k.tag) {
            case "ReturnStmt": {
                const ty = k.expr ? this.expr(k.expr, retTy) : Ty.Void;
                if (!ty.resolvableWith(retTy)) {
                    this.reporter.error(
                        node.loc,
                        `type '${ty.pretty()}' not compatible with return type '${retTy.pretty()}'`,
                    );
                    this.reporter.info(
                        this.fn.kind.retTy?.loc ?? this.fn.loc,
                        `return type '${retTy.pretty()}' defined here`,
                    );
                    this.reporter.abort();
                }
                break;
            }
            case "LetStmt": {
                let ty: Ty;
                const paramTy = k.param.as("Param").kind.ty;
                if (paramTy) {
                    const explicitTy = this.ty(paramTy);
                    const exprTy = this.expr(k.expr, explicitTy);
                    const res = this.resolve(exprTy, explicitTy, k.expr.loc);
                    if (res.rewriteSubtree) {
                        this.rewriteTree(k.expr, res.ty);
                    }
                    ty = res.ty;
                } else {
                    ty = this.expr(k.expr, Ty.Any);
                    // Without an annotation an untyped integer defaults to i32.
                    if (ty.is("AnyInt")) {
                        ty = Ty.I32;
                        this.rewriteTree(k.expr, ty);
                    }
                }
                this.nodeTys.set(node, ty);
                break;
            }
            case "ExprStmt": {
                this.expr(k.expr, Ty.Any);
                break;
            }
            case "AssignStmt": {
                const placeTy = this.place(k.place, Ty.Any);
                const exprTy = this.expr(k.expr, placeTy);
                if (!placeTy.resolvableWith(exprTy)) {
                    this.reporter.error(
                        k.expr.loc,
                        `type '${exprTy.pretty()}' not assignable to type '${placeTy.pretty()}'`,
                    );
                    this.reporter.abort();
                }
                break;
            }
            case "IfStmt":
            case "WhileStmt": {
                this.checkCond(k.cond);
                break;
            }
            case "BreakStmt":
            case "FnStmt":
            case "Error":
            case "File":
            case "Block":
            case "Param":
            case "IdentExpr":
            case "IntExpr":
            case "StrExpr":
            case "ArrayExpr":
            case "IndexExpr":
            case "CallExpr":
            case "UnaryExpr":
            case "BinaryExpr":
            case "RangeExpr":
            case "IdentTy":
            case "PtrTy":
            case "PtrMutTy":
            case "ArrayTy":
            case "SliceTy":
            case "Generic":
                break;
            default:
                k satisfies never;
        }
    }

    /** Checks that the condition of an `if`/`while` statement is a `bool`. */
    private checkCond(cond: ast.Node): void {
        const condTy = this.expr(cond, Ty.Bool);
        if (!condTy.is("Bool")) {
            this.reporter.error(
                cond.loc,
                `expected condition to be 'bool', got '${condTy.pretty()}'`,
            );
            this.reporter.abort();
        }
    }

    /** Rewrites every node that is still an untyped integer to `i32`. */
    private convertAnyIntToI32() {
        for (const [node, ty] of this.nodeTys) {
            if (ty.is("AnyInt")) {
                this.rewriteTree(node, Ty.I32);
            }
        }
    }

    /** Cached variant of {@link checkPlace}. */
    private place(expr: ast.Node, expected: Ty): Ty {
        return this.cachedCheck(expr, () => this.checkPlace(expr, expected));
    }

    /**
     * Types an expression used as an assignment/indexing target. A
     * dereference of a pointer yields the pointee type directly; everything
     * else is checked as an ordinary expression.
     */
    private checkPlace(expr: ast.Node, expected: Ty): Ty {
        const k = expr.kind;
        if (k.tag === "UnaryExpr" && k.op === "Deref") {
            const innerTy = this.checkPlace(
                k.expr,
                Ty.AnyDerefable(expected),
            );
            if (innerTy.is("Ptr") || innerTy.is("PtrMut")) {
                return innerTy.kind.ty;
            }
        }
        return this.expr(expr, expected);
    }

    /** Cached variant of {@link checkExpr}. */
    private expr(expr: ast.Node, expected: Ty): Ty {
        return this.cachedCheck(expr, () => this.checkExpr(expr, expected));
    }

    /**
     * Computes the type of `expr`, using `expected` as a hint.
     *
     * @throws Error when `expr` is not an expression node or hits a case the
     *   checker does not implement.
     */
    private checkExpr(expr: ast.Node, expected: Ty): Ty {
        const k = expr.kind;
        switch (k.tag) {
            case "IdentExpr": {
                const sym = this.syms.get(expr);
                if (sym.tag === "Fn") {
                    const fn = this.cx.checkFn(sym.stmt);
                    return fn.ty();
                }
                if (sym.tag === "Bool") {
                    return Ty.Bool;
                }
                if (sym.tag === "Builtin") {
                    this.reporter.error(
                        expr.loc,
                        `invalid use of builtin '${sym.id}'`,
                    );
                    this.reporter.abort();
                }
                if (sym.tag === "FnParam") {
                    return this.params[sym.idx];
                }
                if (sym.tag === "Let") {
                    const ty = this.nodeTys.get(sym.stmt);
                    if (!ty) {
                        throw new Error();
                    }
                    return ty;
                }
                throw new Error(`'${sym.tag}' not handled`);
            }
            case "IntExpr": {
                const intTy = (() => {
                    switch (k.intTy) {
                        case "u8":
                            return Ty.U8;
                        case "u16":
                            return Ty.U16;
                        case "u32":
                            return Ty.U32;
                        case "u64":
                            return Ty.U64;
                        case "usize":
                            return Ty.USize;
                        case "i8":
                            return Ty.I8;
                        case "i16":
                            return Ty.I16;
                        case "i32":
                            return Ty.I32;
                        case "i64":
                            return Ty.I64;
                        case "isize":
                            return Ty.ISize;
                        case "any":
                            return Ty.AnyInt;
                        default:
                            throw new Error(
                                `intType '${k.intTy}' not handled`,
                            );
                    }
                })();
                return this.resolve(intTy, expected, expr.loc).ty;
            }
            case "StrExpr": {
                return this.resolve(
                    Ty.Ptr(Ty.Slice(Ty.U8)),
                    expected,
                    expr.loc,
                ).ty;
            }
            case "ArrayExpr": {
                if (k.values.length === 0) {
                    if (expected.is("Any")) {
                        return Ty.Array(Ty.Any, 0);
                    } else {
                        return this.resolve(
                            Ty.Array(Ty.Any, 0),
                            expected,
                            expr.loc,
                        ).ty;
                    }
                }
                const expectedInner = expected.isIndexable()
                    ? expected.indexableTy()!
                    : Ty.Any;
                let res = this.resolve(
                    this.expr(k.values[0], expectedInner),
                    expectedInner,
                    k.values[0].loc,
                );
                // Repeat until no element asks for a rewrite; whenever one
                // does, every element is rewritten to the resolved type.
                // NOTE(review): a resolution that keeps requesting a rewrite
                // would make this loop forever — confirm it cannot happen.
                while (true) {
                    for (const val of k.values.slice(1)) {
                        res = this.resolve(
                            this.expr(val, expectedInner),
                            expectedInner,
                            val.loc,
                        );
                        if (res.rewriteSubtree) {
                            break;
                        }
                    }
                    if (!res.rewriteSubtree) {
                        break;
                    }
                    for (const val of k.values) {
                        this.rewriteTree(val, res.ty);
                    }
                }
                return Ty.Array(res.ty, k.values.length);
            }
            case "IndexExpr": {
                const innerTy = this.place(
                    k.value,
                    Ty.AnyIndexable(expected),
                );
                if (!innerTy.isIndexable()) {
                    this.reporter.error(
                        expr.loc,
                        `expected indexable type, got '${innerTy.pretty()}'`,
                    );
                    this.reporter.abort();
                }
                let argTy = this.expr(k.arg, Ty.Any);
                if (argTy.is("AnyInt")) {
                    // An untyped integer index becomes usize.
                    argTy = Ty.USize;
                    this.rewriteTree(k.arg, argTy);
                    return innerTy.indexableTy()!;
                } else if (argTy.is("Range")) {
                    return Ty.Slice(innerTy.indexableTy()!);
                } else {
                    throw new Error(
                        `indexing with '${argTy.pretty()}' not implemented`,
                    );
                }
            }
            case "CallExpr": {
                // `calleeTy` is only used to point diagnostics at the callee's
                // definition; builtin calls have none and omit it.
                const checkArgs = (
                    callableTy: Ty & { kind: { tag: "Fn" } },
                    calleeTy?: Ty,
                ) => {
                    const params = callableTy.kind.params;
                    if (k.args.length !== params.length) {
                        this.reporter.error(
                            expr.loc,
                            `incorrect amount of arguments. got ${k.args.length} expected ${params.length}`,
                        );
                        if (calleeTy?.is("FnStmt")) {
                            this.reporter.info(
                                calleeTy.kind.stmt.loc,
                                "function defined here",
                            );
                        }
                        this.reporter.abort();
                    }
                    const args = k.args
                        .map((arg, i) => this.expr(arg, params[i]));
                    for (const i of args.keys()) {
                        this.resolve(args[i], params[i], k.args[i].loc, () => {
                            this.reporter.error(
                                k.args[i].loc,
                                `type '${
                                    args[i].pretty()
                                }' not compatible with type '${
                                    params[i].pretty()
                                }', for argument ${i}`,
                            );
                            if (calleeTy?.is("FnStmt")) {
                                this.reporter.info(
                                    calleeTy.kind.stmt.kind.params[i].loc,
                                    `parameter '${
                                        calleeTy.kind.stmt.kind.params[i]
                                            .as("Param").kind.ident
                                    }' defined here`,
                                );
                            }
                        });
                    }
                };

                const sym = k.value.is("IdentExpr")
                    ? this.syms.get(k.value)
                    : null;
                if (sym?.tag === "Builtin") {
                    if (sym.id === "len") {
                        checkArgs(Ty.Fn([Ty.Any], Ty.USize, null).as("Fn"));
                        return Ty.USize;
                    }
                    if (sym.id === "print") {
                        for (const arg of k.args) {
                            const ty = this.expr(arg, Ty.Any);
                            if (ty.is("AnyInt")) {
                                this.rewriteTree(arg, Ty.I32);
                            }
                        }
                        return Ty.Void;
                    }
                    throw new Error(`builtin '${sym.id}' not handled`);
                }

                const calleeTy = this.expr(k.value, Ty.AnyCallable(expected));
                if (!calleeTy.isCallable()) {
                    this.reporter.error(
                        expr.loc,
                        `expected callable type, got '${calleeTy.pretty()}'`,
                    );
                    this.reporter.abort();
                }

                const callableTy = calleeTy.callableTy();
                if (callableTy.kind.generics !== null) {
                    // Typed as returning void so the compiler still checks the
                    // unfinished code below, although it never runs.
                    ((() => {
                        throw new Error("generics not implemented");
                    }) as () => void)();

                    let generics: Ty[];
                    if (k.generics) {
                        generics = k.generics.map((ty) => this.ty(ty));
                        if (
                            generics.length !== callableTy.kind.generics.length
                        ) {
                            this.reporter.error(
                                expr.loc,
                                `expected ${callableTy.kind.generics.length} generic type arguments, got ${generics.length}`,
                            );
                            this.reporter.abort();
                        }
                    } else {
                        generics = callableTy.kind.generics
                            .map((_ty) => Ty.Any);
                    }
                    const newCalleeTy = Ty.Fn(
                        callableTy.kind.params
                            .map((ty) =>
                                ty.is("Generic") ? generics[ty.kind.idx] : ty
                            ),
                        callableTy.kind.retTy.is("Generic")
                            ? generics[callableTy.kind.retTy.kind.idx]
                            : callableTy.kind.retTy,
                        null,
                    ).as("Fn");
                    this.rewriteTree(k.value, newCalleeTy);
                    checkArgs(newCalleeTy, calleeTy);
                    return newCalleeTy.kind.retTy;
                } else {
                    if (k.generics) {
                        this.reporter.error(expr.loc, "no generics expected");
                        this.reporter.abort();
                    }
                    checkArgs(callableTy, calleeTy);
                    return callableTy.kind.retTy;
                }
            }
            case "UnaryExpr": {
                switch (k.op) {
                    case "Not": {
                        const ty = this.expr(k.expr, Ty.Bool);
                        if (!ty.is("Bool")) {
                            this.reporter.error(
                                expr.loc,
                                `expected 'bool', got '${ty.pretty()}'`,
                            );
                            this.reporter.abort();
                        }
                        return ty;
                    }
                    case "Neg": {
                        const ty = this.expr(k.expr, expected);
                        // Untyped integers are accepted too, matching the
                        // binary operator tests.
                        if (!ty.is("Int") && !ty.is("AnyInt")) {
                            this.reporter.error(
                                expr.loc,
                                `cannot apply - to '${ty.pretty()}'`,
                            );
                            this.reporter.abort();
                        }
                        return ty;
                    }
                    case "Ref": {
                        const ty = this.expr(k.expr, expected);
                        return Ty.Ptr(ty);
                    }
                    case "RefMut": {
                        const ty = this.expr(k.expr, expected);
                        return Ty.PtrMut(ty);
                    }
                    case "Deref": {
                        const ty = this.expr(k.expr, expected);
                        if (!ty.is("Ptr") && !ty.is("PtrMut")) {
                            this.reporter.error(
                                expr.loc,
                                `cannot dereference type '${ty.pretty()}'`,
                            );
                            this.reporter.abort();
                        }
                        if (!ty.kind.ty.isSized()) {
                            this.reporter.error(
                                expr.loc,
                                `cannot dereference unsized type '${ty.kind.ty.pretty()}' in an expression`,
                            );
                            this.reporter.abort();
                        }
                        return ty.kind.ty;
                    }
                    default:
                        k satisfies never;
                        throw new Error();
                }
            }
            case "BinaryExpr": {
                const op = new BinaryOp(k.op);
                // Only arithmetic passes the expected type to its operands.
                const expectedInner = op.isPropagating() ? expected : Ty.Any;
                const left = this.expr(k.left, expectedInner);
                const right = this.expr(k.right, expectedInner);

                const res = this.resolve(left, right, expr.loc);

                // First operator rule that yields a result type wins.
                let result: Ty | null = null;
                for (const test of binaryOpTests) {
                    result = test(k.op, left, right);
                    if (result) {
                        break;
                    }
                }

                if (!result) {
                    this.reporter.error(
                        expr.loc,
                        `operator '${k.tok}' cannot be applied to types '${left.pretty()}' and '${right.pretty()}'`,
                    );
                    this.reporter.abort();
                }

                this.rewriteTree(k.left, res.ty);
                this.rewriteTree(k.right, res.ty);

                // Arithmetic yields the unified operand type, which may differ
                // from `left` when `left` was an untyped integer.
                return op.isPropagating() ? res.ty : result;
            }
            case "RangeExpr": {
                for (const operandExpr of [k.begin, k.end]) {
                    const operandTy = operandExpr &&
                        this.expr(operandExpr, Ty.USize);
                    if (operandTy && !operandTy.resolvableWith(Ty.USize)) {
                        this.reporter.error(
                            operandExpr.loc,
                            `range operand must be '${Ty.USize.pretty()}', not '${operandTy.pretty()}'`,
                        );
                        this.reporter.abort();
                    }
                }
                return Ty.create("Range", {});
            }
            case "Error":
            case "File":
            case "Block":
            case "ExprStmt":
            case "AssignStmt":
            case "FnStmt":
            case "ReturnStmt":
            case "LetStmt":
            case "IfStmt":
            case "WhileStmt":
            case "BreakStmt":
            case "Param":
            case "IdentTy":
            case "PtrTy":
            case "PtrMutTy":
            case "ArrayTy":
            case "SliceTy":
            case "Generic":
                throw new Error(`node '${k.tag}' not an expression`);
            default:
                k satisfies never;
                throw new Error();
        }
    }

    /** Cached variant of {@link checkTy}. */
    private ty(ty: ast.Node): Ty {
        return this.cachedCheck(ty, () => this.checkTy(ty));
    }

    /**
     * Converts a type node into a `Ty`.
     *
     * @throws Error when `ty` is not a type node.
     */
    private checkTy(ty: ast.Node): Ty {
        const k = ty.kind;
        switch (k.tag) {
            case "IdentTy": {
                const sym = this.syms.get(ty);
                if (sym.tag === "BuiltinTy") {
                    switch (sym.ident) {
                        case "void":
                            return Ty.Void;
                        case "bool":
                            return Ty.Bool;
                        case "i8":
                            return Ty.I8;
                        case "i16":
                            return Ty.I16;
                        case "i32":
                            return Ty.I32;
                        case "i64":
                            return Ty.I64;
                        case "isize":
                            return Ty.ISize;
                        case "u8":
                            return Ty.U8;
                        case "u16":
                            return Ty.U16;
                        case "u32":
                            return Ty.U32;
                        case "u64":
                            return Ty.U64;
                        case "usize":
                            return Ty.USize;
                        default:
                            throw new Error(
                                `unknown type '${k.ident}'`,
                            );
                    }
                }
                if (sym.tag === "Generic") {
                    if (!this.generics) {
                        throw new Error();
                    }
                    return this.generics[sym.idx];
                }
                this.reporter.error(ty.loc, `symbol is not a type`);
                return this.reporter.abort();
            }
            case "PtrTy": {
                const ty = this.ty(k.ty);
                return Ty.create("Ptr", { ty });
            }
            case "PtrMutTy": {
                const ty = this.ty(k.ty);
                return Ty.create("PtrMut", { ty });
            }
            case "ArrayTy": {
                const ty = this.ty(k.ty);
                const lengthTy = this.expr(k.length, Ty.USize);
                if (!lengthTy.resolvableWith(Ty.USize)) {
                    this.reporter.error(
                        k.length.loc,
                        `for array length, expected 'int', got '${lengthTy.pretty()}'`,
                    );
                    this.reporter.abort();
                }
                // The length must be a literal so it is known here.
                if (!k.length.is("IntExpr")) {
                    this.reporter.error(
                        k.length.loc,
                        `array length must be an 'int' expression`,
                    );
                    this.reporter.abort();
                }
                const length = k.length.kind.value;
                return Ty.create("Array", { ty, length });
            }
            case "SliceTy": {
                const ty = this.ty(k.ty);
                return Ty.create("Slice", { ty });
            }
            case "Error":
            case "File":
            case "Block":
            case "ExprStmt":
            case "AssignStmt":
            case "FnStmt":
            case "ReturnStmt":
            case "LetStmt":
            case "IfStmt":
            case "WhileStmt":
            case "BreakStmt":
            case "Param":
            case "IdentExpr":
            case "IntExpr":
            case "StrExpr":
            case "ArrayExpr":
            case "IndexExpr":
            case "CallExpr":
            case "UnaryExpr":
            case "BinaryExpr":
            case "RangeExpr":
            case "Generic":
                throw new Error(`node '${k.tag}' not a type`);
        }
    }

    /**
     * Returns the type already recorded for `node`, or runs `func`, records
     * its result and returns it.
     */
    private cachedCheck(node: ast.Node, func: () => Ty): Ty {
        const existing = this.nodeTys.get(node);
        if (existing !== undefined) {
            return existing;
        }
        const ty = func();
        this.nodeTys.set(node, ty);
        return ty;
    }

    /**
     * Reconciles `ty` with `expected`.
     *
     * Returns the resulting type and whether nodes that were already typed
     * have to be rewritten to it. If the types are not resolvable,
     * `inCaseOfError` (or a default diagnostic at `loc`) is reported and the
     * reporter aborts.
     *
     * @throws Error for resolvable pairs that no rule here covers.
     */
    private resolve(
        ty: Ty,
        expected: Ty,
        loc: Loc,
        inCaseOfError?: () => void,
    ): TyRes {
        if (ty === expected) {
            return { ty, rewriteSubtree: false };
        }
        if (expected.is("Any")) {
            return { ty, rewriteSubtree: false };
        }
        if (expected.is("Int") && ty.is("AnyInt")) {
            return { ty: expected, rewriteSubtree: true };
        }
        if (expected.is("AnyInt") && ty.is("Int")) {
            return { ty, rewriteSubtree: true };
        }

        if (!ty.resolvableWith(expected)) {
            if (inCaseOfError) {
                inCaseOfError();
            } else {
                this.reporter.error(
                    loc,
                    `expected type '${expected.pretty()}', got '${ty.pretty()}'`,
                );
            }
            this.reporter.abort();
        }
        throw new Error(
            `resolving '${ty.pretty()}' with '${expected.pretty()}' not implemented`,
        );
    }

    /**
     * Overwrites the recorded type of `node` with `ty` and pushes it down
     * through arithmetic operands and negation.
     *
     * @throws Error for node kinds that cannot be rewritten.
     */
    private rewriteTree(node: ast.Node, ty: Ty) {
        this.nodeTys.set(node, ty);
        const k = node.kind;
        switch (k.tag) {
            case "IdentExpr":
                break;
            case "IntExpr":
                break;
            case "UnaryExpr": {
                // Negation has the type of its operand; other unary operators
                // change the type and are not supported here.
                if (k.op !== "Neg") {
                    throw new Error(`not implemented for '${k.tag}'`);
                }
                this.rewriteTree(k.expr, ty);
                break;
            }
            case "BinaryExpr": {
                const op = new BinaryOp(k.op);
                if (op.isPropagating()) {
                    this.rewriteTree(k.left, ty);
                    this.rewriteTree(k.right, ty);
                }
                break;
            }
            default:
                throw new Error(`not implemented for '${k.tag}'`);
        }
    }
}

/** Result of resolving two types against each other. */
type TyRes = {
    ty: Ty;
    rewriteSubtree: boolean;
};

/** Returns the result type of `op` on the operand types, or null if N/A. */
type BinaryOpTest = (op: ast.BinaryOp, left: Ty, right: Ty) => Ty | null;

const arithmeticOps: ast.BinaryOp[] = ["Add", "Sub", "Mul", "Div", "Rem"];
const comparisonOps: ast.BinaryOp[] = ["Eq", "Ne", "Lt", "Gt", "Lte", "Gte"];

/** True for concrete and for untyped integer types. */
function isIntLike(ty: Ty): boolean {
    return ty.is("Int") || ty.is("AnyInt");
}

const binaryOpTests: BinaryOpTest[] = [
    // Integer arithmetic yields the left operand type.
    (op, left, right) => {
        if (
            arithmeticOps.includes(op) && isIntLike(left) &&
            left.resolvableWith(right)
        ) {
            return left;
        }
        return null;
    },
    // Integer comparison yields bool.
    (op, left, right) => {
        if (
            comparisonOps.includes(op) && isIntLike(left) &&
            left.resolvableWith(right)
        ) {
            return Ty.Bool;
        }
        return null;
    },
];

/** Wraps a binary operator to answer questions about it. */
class BinaryOp {
    constructor(
        public op: ast.BinaryOp,
    ) {}

    /**
     * Whether the operator is arithmetic, i.e. its operands share the
     * expected/result type of the whole expression.
     */
    isPropagating(): boolean {
        return arithmeticOps.includes(this.op);
    }
}