diff --git a/src/front/check.ts b/src/front/check.ts index 0246e6d..5c1c151 100644 --- a/src/front/check.ts +++ b/src/front/check.ts @@ -94,7 +94,7 @@ class CheckerCx { } assertCompatible(left: Ty, right: Ty, loc: Loc): void { - if (!left.compatibleWith(right)) { + if (!left.resolvableWith(right)) { this.error( loc, `type '${left.pretty()}' not compatible with type '${right.pretty()}'`, @@ -122,7 +122,7 @@ class StmtChecker { const ty = node.kind.expr ? this.cx.expr(node.kind.expr) : Ty.Void; - if (!ty.compatibleWith(retTy)) { + if (!ty.resolvableWith(retTy)) { this.cx.error( node.loc, `type '${ty.pretty()}' not compatible with return type '${retTy.pretty()}'`, @@ -292,13 +292,13 @@ class ExprChecker { const argTy = this.cx.expr(node.kind.arg); if ( (exprTy.is("Array") || exprTy.is("Slice")) && - argTy.compatibleWith(Ty.I32) + argTy.resolvableWith(Ty.I32) ) { return exprTy.kind.ty; } if ( (exprTy.is("Array") || exprTy.is("Slice")) && - argTy.compatibleWith(Ty.create("Range", {})) + argTy.resolvableWith(Ty.create("Range", {})) ) { return Ty.create("Slice", { ty: exprTy.kind.ty }); } @@ -345,7 +345,7 @@ class ExprChecker { ); } for (const i of args.keys()) { - if (!args[i].compatibleWith(params[i])) { + if (!args[i].resolvableWith(params[i])) { this.reportArgTypeNotCompatible( node, args, @@ -404,10 +404,10 @@ class ExprChecker { private checkUnaryExpr(node: ast.NodeWithKind<"UnaryExpr">): Ty { const exprTy = this.cx.expr(node.kind.expr); - if (node.kind.op === "Neg" && exprTy.compatibleWith(Ty.I32)) { + if (node.kind.op === "Neg" && exprTy.resolvableWith(Ty.I32)) { return Ty.I32; } - if (node.kind.op === "Not" && exprTy.compatibleWith(Ty.Bool)) { + if (node.kind.op === "Not" && exprTy.resolvableWith(Ty.Bool)) { return Ty.Bool; } if (node.kind.op === "Ref") { @@ -455,7 +455,7 @@ class ExprChecker { private checkRangeExpr(node: ast.NodeWithKind<"RangeExpr">): Ty { for (const operandExpr of [node.kind.begin, node.kind.end]) { const operandTy = operandExpr && this.cx.expr(operandExpr); - if (operandTy && !operandTy.compatibleWith(Ty.I32)) { + if (operandTy && !operandTy.resolvableWith(Ty.I32)) { this.cx.error( operandExpr.loc, `range operand must be '${Ty.I32.pretty()}', not '${operandTy.pretty()}'`, @@ -567,7 +567,7 @@ class TyChecker { if (node.is("ArrayTy")) { const ty = this.cx.ty(node.kind.ty); const lengthTy = this.cx.expr(node.kind.length); - if (!lengthTy.compatibleWith(Ty.I32)) { + if (!lengthTy.resolvableWith(Ty.I32)) { this.cx.error( node.kind.length.loc, `for array length, expected 'int', got '${lengthTy.pretty()}'`, @@ -606,7 +606,7 @@ const binaryOpTests: BinaryOpTest[] = [ "Rem", ]; if ( - ops.includes(op) && left.is("Int") && left.compatibleWith(right) + ops.includes(op) && left.is("Int") && left.resolvableWith(right) ) { return left; } @@ -615,7 +615,7 @@ const binaryOpTests: BinaryOpTest[] = [ (op, left, right) => { const ops: ast.BinaryOp[] = ["Eq", "Ne", "Lt", "Gt", "Lte", "Gte"]; if ( - ops.includes(op) && left.is("Int") && left.compatibleWith(right) + ops.includes(op) && left.is("Int") && left.resolvableWith(right) ) { return Ty.Bool; } diff --git a/src/front/check2.ts b/src/front/check2.ts index 77c384a..2cbe3d7 100644 --- a/src/front/check2.ts +++ b/src/front/check2.ts @@ -1,32 +1,195 @@ import { Syms } from "./resolve.ts"; import * as ast from "../ast.ts"; import { Ty } from "../ty.ts"; +import { FileReporter, Loc } from "../diagnostics.ts"; -export function checkFn(fn: ast.FnStmt, syms: Syms) { +export function checkFn( + fn: ast.FnStmt, + syms: Syms, + reporter: FileReporter, +): CheckedFn { + return new TypeChecker(fn, syms, reporter).check(); } -export class FnTys { +export class CheckedFn { constructor( + private fnTy: Ty, private nodeTys: Map, ) {} - get(node: ast.Node): Ty { - const ty = this.nodeTys.get(node); + ty(): Ty { + return this.fnTy; + } + + exprTy(expr: ast.Node): Ty { + const ty = this.nodeTys.get(expr); if (ty === undefined) { - throw new Error(`no type for '${node.kind.tag}'`); + throw new Error(`no type for '${expr.kind.tag}'`); } return ty; } } -class FnChecker { +class TypeChecker { private nodeTys = new Map(); constructor( private fn: ast.FnStmt, private syms: Syms, + private reporter: FileReporter, ) {} - check() { + check(): CheckedFn { + const params = this.fn.kind.params + .map((node) => { + const param = node.as("Param"); + if (!param.kind.ty) { + this.reporter + .error(param.loc, `parameter must have a type`); + this.reporter.abort(); + } + return this.ty(param.kind.ty); + }); + + const retTy = this.fn.kind.retTy + ? this.ty(this.fn.kind.retTy) + : Ty.Void; + + this.fn.kind.body.visit({ + visit: (node) => { + if (!node.is("ReturnStmt")) { + return; + } + const ty = node.kind.expr + ? this.expr(node.kind.expr, retTy) + : Ty.Void; + if (!ty.resolvableWith(retTy)) { + this.reporter.error( + node.loc, + `type '${ty.pretty()}' not compatible with return type '${retTy.pretty()}'`, + ); + this.reporter.info( + this.fn.kind.retTy?.loc ?? this.fn.loc, + `return type '${retTy}' defined here`, + ); + this.reporter.abort(); + } + }, + }); + + const ty = Ty.create("FnStmt", { + stmt: this.fn, + ty: Ty.create("Fn", { params, retTy }), + }); + + return new CheckedFn(ty, this.nodeTys); + } + + private expr(expr: ast.Node, expected: Ty): Ty { + return this.cachedCheck(expr, () => this.checkExpr(expr, expected)); + } + + private checkExpr(expr: ast.Node, expected: Ty): Ty { + const k = expr.kind; + switch (k.tag) { + case "Error": + case "File": + case "Block": + case "ExprStmt": + case "AssignStmt": + case "FnStmt": + case "ReturnStmt": + case "LetStmt": + case "IfStmt": + case "WhileStmt": + case "BreakStmt": + case "Param": + case "IdentTy": + case "PtrTy": + case "PtrMutTy": + case "ArrayTy": + case "SliceTy": + throw new Error(`node '${k.tag}' not an expression`); + case "IdentExpr": { + throw new Error("not implemented"); + } + case "IntExpr": { + return this.resolve(Ty.AnyInt, expected, expr.loc).ty; + } + case "StrExpr": { + return this.resolve( + Ty.Ptr(Ty.Slice(Ty.U8)), + expected, + expr.loc, + ).ty; + } + case "ArrayExpr": { + if (k.values.length === 0) { + if (expected.is("Any")) { + return Ty.Array(Ty.Any, 0); + } else { + return this.resolve( + Ty.Array(Ty.Any, 0), + expected, + expr.loc, + ).ty; + } + } + + throw new Error("not implemented"); + } + case "IndexExpr": { + throw new Error("not implemented"); + } + case "CallExpr": { + throw new Error("not implemented"); + } + case "UnaryExpr": { + throw new Error("not implemented"); + } + case "BinaryExpr": { + throw new Error("not implemented"); + } + case "RangeExpr": { + throw new Error("not implemented"); + } + default: + k satisfies never; + throw new Error(); + } + } + + private ty(ty: ast.Node): Ty { + return this.cachedCheck(ty, () => this.checkTy(ty)); + } + + private checkTy(ty: ast.Node): Ty { + throw new Error("not implemented"); + } + + private cachedCheck(node: ast.Node, func: () => Ty) { + const existing = this.nodeTys.get(node); + if (existing !== undefined) { + return existing; + } + const ty = func(); + this.nodeTys.set(node, ty); + return ty; + } + + private resolve(ty: Ty, expected: Ty, loc: Loc): TyRes { + if (!ty.resolvableWith(expected)) { + this.reporter.error( + loc, + `expected type '${expected.pretty()}', got '${ty.pretty()}'`, + ); + this.reporter.abort(); + } + throw new Error("not implemented"); } } + +type TyRes = { + ty: Ty; + rewriteSubtree: boolean; +}; diff --git a/src/middle.ts b/src/middle.ts index 410a472..7339bf0 100644 --- a/src/middle.ts +++ b/src/middle.ts @@ -346,16 +346,16 @@ class FnLowerer { const operandTy = this.tys.expr(expr.kind.expr); if ( expr.kind.op === "Neg" && - operandTy.compatibleWith(Ty.I32) && - resultTy.compatibleWith(Ty.I32) + operandTy.resolvableWith(Ty.I32) && + resultTy.resolvableWith(Ty.I32) ) { const operand = this.lowerExpr(expr.kind.expr); return this.pushInst(Ty.I32, "Negate", { source: operand }); } if ( expr.kind.op === "Not" && - operandTy.compatibleWith(Ty.Bool) && - resultTy.compatibleWith(Ty.Bool) + operandTy.resolvableWith(Ty.Bool) && + resultTy.resolvableWith(Ty.Bool) ) { const operand = this.lowerExpr(expr.kind.expr); return this.pushInst(Ty.Bool, "Not", { source: operand }); @@ -428,8 +428,8 @@ const binaryOpTests: BinaryOpTest[] = [ if ( ops.includes(op) && left.is("Int") && - left.compatibleWith(right) && - result.compatibleWith(left) + left.resolvableWith(right) && + result.resolvableWith(left) ) { return op as BinaryOp; } @@ -440,7 +440,7 @@ const binaryOpTests: BinaryOpTest[] = [ if ( ops.includes(op) && left.is("Int") && - left.compatibleWith(right) && + left.resolvableWith(right) && result.is("Bool") ) { return op as BinaryOp; diff --git a/src/ty.ts b/src/ty.ts index bec9a77..f406c28 100644 --- a/src/ty.ts +++ b/src/ty.ts @@ -24,7 +24,6 @@ export class Ty { static Error = Ty.create("Error", {}); static Void = Ty.create("Void", {}); - static IntLiteral = Ty.create("IntLiteral", {}); static I8 = Ty.create("Int", { intTy: "i8" }); static I16 = Ty.create("Int", { intTy: "i16" }); static I32 = Ty.create("Int", { intTy: "i32" }); @@ -37,6 +36,24 @@ export class Ty { static USize = Ty.create("Int", { intTy: "usize" }); static Bool = Ty.create("Bool", {}); + static Ptr(ty: Ty): Ty { + return this.create("Ptr", { ty }); + } + static PtrMut(ty: Ty): Ty { + return this.create("PtrMut", { ty }); + } + static Slice(ty: Ty): Ty { + return this.create("Slice", { ty }); + } + static Array(ty: Ty, length: number): Ty { + return this.create("Array", { ty, length }); + } + + /** Only used in type checker. */ + static Any = Ty.create("Any", {}); + /** Only used in type checker. */ + static AnyInt = Ty.create("AnyInt", {}); + private internHash(): string { return JSON.stringify(this.kind); } @@ -52,7 +69,28 @@ export class Ty { return this.kind.tag === tag; } - compatibleWith(other: Ty): boolean { + assertIs< + Tag extends TyKind["tag"], + >(tag: Tag): asserts this is Ty & { kind: { tag: Tag } } { + if (this.kind.tag !== tag) { + throw new Error(`ty is not '${tag}'`); + } + } + + as< + Tag extends TyKind["tag"], + >(tag: Tag): Ty & { kind: { tag: Tag } } { + this.assertIs(tag); + return this; + } + + /** + * Used for checking and inference in the type checker + * to say if two types are able to implicitly be converted + * into each other or to a third type. Past type inference + * this function is meaningless. + */ + resolvableWith(other: Ty): boolean { // types are interned--we can just do this if (this === other) { return true; @@ -64,10 +102,6 @@ export class Ty { if (this.is("Void")) { return other.is("Void"); } - if (this.is("IntLiteral")) { - // return other.is("Int") || other.is("UInt"); - return false; - } if (this.is("Int")) { return other.is("Int") && this.kind.intTy == other.kind.intTy; } @@ -79,7 +113,7 @@ export class Ty { if (!other.is("Ptr")) { return false; } - if (!this.kind.ty.compatibleWith(other.kind.ty)) { + if (!this.kind.ty.resolvableWith(other.kind.ty)) { return false; } return true; @@ -88,7 +122,7 @@ export class Ty { if (!other.is("Array")) { return false; } - if (!this.kind.ty.compatibleWith(other.kind.ty)) { + if (!this.kind.ty.resolvableWith(other.kind.ty)) { return false; } if (this.kind.length !== other.kind.length) { @@ -100,7 +134,7 @@ export class Ty { if (!other.is("Slice")) { return false; } - if (!this.kind.ty.compatibleWith(other.kind.ty)) { + if (!this.kind.ty.resolvableWith(other.kind.ty)) { return false; } return true; @@ -113,28 +147,26 @@ export class Ty { return false; } for (const i of this.kind.params.keys()) { - if (!this.kind.params[i].compatibleWith(other.kind.params[i])) { + if (!this.kind.params[i].resolvableWith(other.kind.params[i])) { return false; } } - if (!this.kind.retTy.compatibleWith(other.kind.retTy)) { + if (!this.kind.retTy.resolvableWith(other.kind.retTy)) { return false; } return true; } if (this.is("FnStmt")) { - if (!other.is("FnStmt")) { - return false; - } - if (!this.kind.ty.compatibleWith(other.kind.ty)) { - return false; - } - // redundant; sanity check - if (this.kind.stmt.id !== other.kind.stmt.id) { - throw new Error(); - } + // Since FnStmt tys are only compatible with itself, + // we can count on the ty cache for this check. + return false; + } + if (this.is("Any")) { return true; } + if (this.is("AnyInt")) { + return other.is("Int"); + } throw new Error(`'${this.kind.tag}' not handled`); } @@ -145,6 +177,15 @@ export class Ty { return true; } + innerFn(): Ty & { kind: { tag: "Fn" } } { + this.assertIs("FnStmt"); + return this.kind.ty.as("Fn"); + } + + isCheckerInternal(): boolean { + return this.is("Any") || this.is("AnyInt"); + } + pretty(colors?: stringify.PrettyColors): string { return stringify.tyPretty(this, colors); } @@ -153,7 +194,6 @@ export class Ty { export type TyKind = | { tag: "Error" } | { tag: "Void" } - | { tag: "IntLiteral" } | { tag: "Int"; intTy: ast.IntTy } | { tag: "Bool" } | { tag: "Ptr"; ty: Ty } @@ -162,4 +202,5 @@ export type TyKind = | { tag: "Slice"; ty: Ty } | { tag: "Range" } | { tag: "Fn"; params: Ty[]; retTy: Ty } - | { tag: "FnStmt"; ty: Ty; stmt: ast.NodeWithKind<"FnStmt"> }; + | { tag: "FnStmt"; ty: Ty; stmt: ast.NodeWithKind<"FnStmt"> } + | { tag: "Any" | "AnyInt" };