From aab3fa77ad428cc26370e7e382507cf61c10b11a Mon Sep 17 00:00:00 2001 From: sfja Date: Thu, 16 Apr 2026 02:02:41 +0200 Subject: [PATCH] fix new type checker --- src/front/check2.ts | 186 +++++++++++++++++++++++++++++++++++--------- src/main.ts | 6 +- src/stringify.ts | 6 +- src/ty.ts | 39 +++++++--- 4 files changed, 182 insertions(+), 55 deletions(-) diff --git a/src/front/check2.ts b/src/front/check2.ts index 1042184..333b742 100644 --- a/src/front/check2.ts +++ b/src/front/check2.ts @@ -2,14 +2,30 @@ import { Syms } from "./resolve.ts"; import * as ast from "../ast.ts"; import { Ty } from "../ty.ts"; import { FileReporter, Loc } from "../diagnostics.ts"; -import { exit } from "node:process"; +import * as stringify from "../stringify.ts"; -export function checkFn( - fn: ast.FnStmt, - syms: Syms, - reporter: FileReporter, -): CheckedFn { - return new TypeChecker(fn, syms, reporter).check(); +export class Checker { + private checkedFns = new Map(); + + constructor( + private syms: Syms, + private reporter: FileReporter, + ) {} + + checkFn(fn: ast.FnStmt): CheckedFn { + const existing = this.checkedFns.get(fn); + if (existing) { + return existing; + } + + const checkedFn = new TypeChecker(this, fn, this.syms, this.reporter) + .check(); + + checkedFn.checkForCheckerInternalTys(this.reporter); + + this.checkedFns.set(fn, checkedFn); + return checkedFn; + } } export class CheckedFn { @@ -29,12 +45,26 @@ export class CheckedFn { } return ty; } + + checkForCheckerInternalTys(reporter: FileReporter) { + for (const [node, ty] of this.nodeTys) { + if (ty.isCheckerInternal()) { + reporter.error( + node.loc, + `concrete type must be known at this point, got temporary type '${ty.pretty()}'`, + ); + reporter.abort(); + } + } + } } class TypeChecker { private nodeTys = new Map(); + private params!: Ty[]; constructor( + private cx: Checker, private fn: ast.FnStmt, private syms: Syms, private reporter: FileReporter, @@ -51,6 +81,7 @@ class TypeChecker { } return this.ty(param.kind.ty); }); + this.params = params; const retTy = this.fn.kind.retTy ? this.ty(this.fn.kind.retTy) @@ -92,7 +123,12 @@ class TypeChecker { ty = res.ty; } else { ty = this.expr(k.expr, Ty.Any); + if (ty.is("AnyInt")) { + ty = Ty.I32; + this.rewriteTree(k.expr, ty); + } } + this.nodeTys.set(node, ty); break; } @@ -102,7 +138,7 @@ class TypeChecker { } case "AssignStmt": { const placeTy = this.expr(k.place, Ty.Any); - const exprTy = this.expr(k.expr, Ty.Any); + const exprTy = this.expr(k.expr, placeTy); if (!placeTy.resolvableWith(exprTy)) { this.reporter.error( k.expr.loc, @@ -137,7 +173,9 @@ class TypeChecker { case "BreakStmt": { break; } - case "FnStmt": + case "FnStmt": { + break; + } case "Error": case "File": case "Block": @@ -163,6 +201,8 @@ class TypeChecker { }, }); + this.convertAnyIntToI32(); + const ty = Ty.create("FnStmt", { stmt: this.fn, ty: Ty.create("Fn", { params, retTy }), @@ -171,6 +211,38 @@ class TypeChecker { return new CheckedFn(ty, this.nodeTys); } + private convertAnyIntToI32() { + for (const [node, ty] of this.nodeTys) { + if (ty.is("AnyInt")) { + this.rewriteTree(node, Ty.I64); + } + } + } + + private place(expr: ast.Node, expected: Ty): Ty { + return this.cachedCheck(expr, () => this.checkPlace(expr, expected)); + } + + private checkPlace(expr: ast.Node, expected: Ty): Ty { + const k = expr.kind; + switch (k.tag) { + case "UnaryExpr": { + switch (k.op) { + case "Deref": { + const innerTy = this.checkPlace( + k.expr, + Ty.AnyDerefable(expected), + ); + if (innerTy.is("Ptr") || innerTy.is("PtrMut")) { + return innerTy.kind.ty; + } + } + } + } + } + return this.expr(expr, expected); + } + private expr(expr: ast.Node, expected: Ty): Ty { return this.cachedCheck(expr, () => this.checkExpr(expr, expected)); } @@ -199,7 +271,8 @@ class TypeChecker { case "IdentExpr": { const sym = this.syms.get(expr); if (sym.tag === "Fn") { - throw new Error("todo"); + const fn = this.cx.checkFn(sym.stmt); + return fn.ty(); } if (sym.tag === "Bool") { return Ty.Bool; @@ -212,7 +285,7 @@ class TypeChecker { this.reporter.abort(); } if (sym.tag === "FnParam") { - return this.expr(sym.param, Ty.Any); + return this.params[sym.idx]; } if (sym.tag === "Let") { const ty = this.nodeTys.get(sym.stmt); @@ -276,20 +349,20 @@ class TypeChecker { } } - const expectedInner = expected.is("Any") - ? expected - : expected.as("Array").kind.ty; + const expectedInner = expected.isIndexable() + ? expected.indexableTy()! + : Ty.Any; let res = this.resolve( this.expr(k.values[0], expectedInner), - expected, + expectedInner, k.values[0].loc, ); while (true) { for (const val of k.values.slice(1)) { res = this.resolve( this.expr(val, expectedInner), - expected, + expectedInner, k.values[0].loc, ); if (res.rewriteSubtree) { @@ -305,12 +378,12 @@ class TypeChecker { } } - return res.ty; + return Ty.Array(res.ty, k.values.length); } case "IndexExpr": { - const innerTy = this.expr( + const innerTy = this.place( k.value, - Ty.Indexable(expected), + Ty.AnyIndexable(expected), ); if (!innerTy.isIndexable()) { this.reporter.error( @@ -389,14 +462,18 @@ class TypeChecker { return Ty.USize; } if (sym.id === "print") { - void k.args - .map((arg) => this.expr(arg, Ty.Any)); + for (const arg of k.args) { + const ty = this.expr(arg, Ty.Any); + if (ty.is("AnyInt")) { + this.rewriteTree(arg, Ty.I32); + } + } return Ty.Void; } throw new Error(`builtin '${sym.id}' not handled`); } - const calleeTy = this.expr(k.value, Ty.Callable(expected)); + const calleeTy = this.expr(k.value, Ty.AnyCallable(expected)); if (!calleeTy.isCallable()) { this.reporter.error( expr.loc, @@ -439,7 +516,7 @@ class TypeChecker { } case "RefMut": { const ty = this.expr(k.expr, expected); - return Ty.Ptr(ty); + return Ty.PtrMut(ty); } case "Deref": { const ty = this.expr(k.expr, expected); @@ -465,8 +542,13 @@ class TypeChecker { } } case "BinaryExpr": { - const left = this.expr(k.left, expected); - const right = this.expr(k.right, expected); + const op = new BinaryOp(k.op); + const expectedInner = op.isPropagating() ? expected : Ty.Any; + const left = this.expr(k.left, expectedInner); + const right = this.expr(k.right, expectedInner); + + const res = this.resolve(left, right, expr.loc); + const result = binaryOpTests .map((test) => test(k.op, left, right)) .filter((result) => result) @@ -478,16 +560,18 @@ class TypeChecker { ); this.reporter.abort(); } + this.rewriteTree(k.left, res.ty); + this.rewriteTree(k.right, res.ty); return result; } case "RangeExpr": { for (const operandExpr of [k.begin, k.end]) { const operandTy = operandExpr && this.expr(operandExpr, Ty.USize); - if (operandTy && !operandTy.resolvableWith(Ty.I32)) { + if (operandTy && !operandTy.resolvableWith(Ty.USize)) { this.reporter.error( operandExpr.loc, - `range operand must be '${Ty.I32.pretty()}', not '${operandTy.pretty()}'`, + `range operand must be '${Ty.USize.pretty()}', not '${operandTy.pretty()}'`, ); this.reporter.abort(); } @@ -577,7 +661,7 @@ class TypeChecker { case "ArrayTy": { const ty = this.ty(k.ty); const lengthTy = this.expr(k.length, Ty.USize); - if (!lengthTy.resolvableWith(Ty.I32)) { + if (!lengthTy.resolvableWith(Ty.USize)) { this.reporter.error( k.length.loc, `for array length, expected 'int', got '${lengthTy.pretty()}'`, @@ -623,6 +707,12 @@ class TypeChecker { if (expected.is("Any")) { return { ty, rewriteSubtree: false }; } + if (expected.is("Int") && ty.is("AnyInt")) { + return { ty: expected, rewriteSubtree: true }; + } + if (expected.is("AnyInt") && ty.is("Int")) { + return { ty, rewriteSubtree: true }; + } if (!ty.resolvableWith(expected)) { if (inCaseOfError) { inCaseOfError(); @@ -640,10 +730,23 @@ class TypeChecker { } private rewriteTree(node: ast.Node, ty: Ty) { + this.nodeTys.set(node, ty); const k = node.kind; switch (k.tag) { + case "IdentExpr": + break; + case "IntExpr": + break; + case "BinaryExpr": { + const op = new BinaryOp(k.op); + if (op.isPropagating()) { + this.rewriteTree(k.left, ty); + this.rewriteTree(k.right, ty); + } + break; + } default: - throw new Error("not implemented"); + throw new Error(`not implemented for '${k.tag}'`); } } } @@ -657,15 +760,10 @@ type BinaryOpTest = (op: ast.BinaryOp, left: Ty, right: Ty) => Ty | null; const binaryOpTests: BinaryOpTest[] = [ (op, left, right) => { - const ops: ast.BinaryOp[] = [ - "Add", - "Sub", - "Mul", - "Div", - "Rem", - ]; + const ops: ast.BinaryOp[] = ["Add", "Sub", "Mul", "Div", "Rem"]; if ( - ops.includes(op) && left.is("Int") && left.resolvableWith(right) + ops.includes(op) && (left.is("Int") || left.is("AnyInt")) && + left.resolvableWith(right) ) { return left; } @@ -674,10 +772,22 @@ const binaryOpTests: BinaryOpTest[] = [ (op, left, right) => { const ops: ast.BinaryOp[] = ["Eq", "Ne", "Lt", "Gt", "Lte", "Gte"]; if ( - ops.includes(op) && left.is("Int") && left.resolvableWith(right) + ops.includes(op) && (left.is("Int") || left.is("AnyInt")) && + left.resolvableWith(right) ) { return Ty.Bool; } return null; }, ]; + +class BinaryOp { + constructor( + public op: ast.BinaryOp, + ) {} + + isPropagating(): boolean { + return (["Add", "Sub", "Mul", "Div", "Rem"] as ast.BinaryOp[]) + .includes(this.op); + } +} diff --git a/src/main.ts b/src/main.ts index 91acd9d..4da6e30 100644 --- a/src/main.ts +++ b/src/main.ts @@ -38,8 +38,10 @@ if (!mainFn) { } const tys = new front.Tys(syms, fileRep); -// const fnTys = front.checkFn(mainFn, syms, fileRep); -// + +const checker = new front.Checker(syms, fileRep); +const fnTys = checker.checkFn(mainFn); + // fileAst.visit({ // visit(node) { // switch (node.kind.tag) { diff --git a/src/stringify.ts b/src/stringify.ts index d98920e..6faad36 100644 --- a/src/stringify.ts +++ b/src/stringify.ts @@ -80,7 +80,7 @@ export function tyPretty(ty: ty.Ty, colors = noColors): string { case "Array": return `${c.punctuation}[${ ty.kind.ty.pretty(c) - }${c.punctuation}; ${ty.kind.length}${c.punctuation}]`; + }${c.punctuation}; ${c.literal}${ty.kind.length}${c.punctuation}]`; case "Slice": return `${c.punctuation}[${ty.kind.ty.pretty(c)}${c.punctuation}]`; case "Range": @@ -101,8 +101,6 @@ export function tyPretty(ty: ty.Ty, colors = noColors): string { `${c.punctuation}, `, ) }${c.punctuation}) -> ${ty.kind.ty.kind.retTy.pretty(c)}`; - case "AnyInt": - return `${c.typeIdent}{integer}`; } return `{${ty.kind.tag}}`; } @@ -141,7 +139,7 @@ class MirFnPrettyStringifier { const retTy = fnTy.kind.retTy.pretty(c); this.result += - `${c.keyword}fn ${c.fnIdent}${ident}${c.punctuation}(${params}${c.punctuation}) -> ${retTy} ${c.punctuation}{\n`; + `${c.keyword}fn ${c.fnIdent}${ident}${c.punctuation}(${params}${c.punctuation}) -> ${retTy}\n${c.punctuation}{\n`; for (const bb of this.fn.bbs) { this.basicBlock(bb); diff --git a/src/ty.ts b/src/ty.ts index 562a0bb..cb2e143 100644 --- a/src/ty.ts +++ b/src/ty.ts @@ -54,12 +54,16 @@ export class Ty { /** Only used in type checker. */ static AnyInt = Ty.create("AnyInt", {}); /** Only used in type checker. */ - static Indexable(ty: Ty): Ty { - return this.create("Indexable", { ty }); + static AnyIndexable(ty: Ty): Ty { + return this.create("AnyIndexable", { ty }); } /** Only used in type checker. */ - static Callable(ty: Ty): Ty { - return this.create("Callable", { ty }); + static AnyCallable(ty: Ty): Ty { + return this.create("AnyCallable", { ty }); + } + /** Only used in type checker. */ + static AnyDerefable(ty: Ty): Ty { + return this.create("AnyDerefable", { ty }); } private internHash(): string { @@ -111,7 +115,8 @@ export class Ty { return other.is("Void"); } if (this.is("Int")) { - return other.is("Int") && this.kind.intTy == other.kind.intTy; + return other.is("Int") && this.kind.intTy == other.kind.intTy || + other.is("AnyInt"); } if (this.is("Bool")) { @@ -186,11 +191,11 @@ export class Ty { } isIndexable(): boolean { - return this.is("Array") || this.is("Slice") || this.is("Indexable"); + return this.is("Array") || this.is("Slice") || this.is("AnyIndexable"); } indexableTy(): Ty | null { - if (this.is("Array") || this.is("Slice") || this.is("Indexable")) { + if (this.is("Array") || this.is("Slice") || this.is("AnyIndexable")) { return this.kind.ty; } return null; @@ -208,9 +213,20 @@ export class Ty { return this.kind.ty.as("Fn"); } + isDerefable(): boolean { + return this.is("Ptr") || this.is("PtrMut") || this.is("AnyDerefable"); + } + + derefableTy(): Ty | null { + if (this.is("Ptr") || this.is("PtrMut") || this.is("AnyDerefable")) { + return this.kind.ty; + } + return null; + } + isCheckerInternal(): boolean { - return this.is("Any") || this.is("AnyInt") || this.is("Indexable") || - this.is("Callable"); + return this.is("Any") || this.is("AnyInt") || this.is("AnyIndexable") || + this.is("AnyCallable") || this.is("AnyDerefable"); } pretty(colors?: stringify.PrettyColors): string { @@ -231,5 +247,6 @@ export type TyKind = | { tag: "Fn"; params: Ty[]; retTy: Ty } | { tag: "FnStmt"; ty: Ty; stmt: ast.NodeWithKind<"FnStmt"> } | { tag: "Any" | "AnyInt" } - | { tag: "Indexable"; ty: Ty } - | { tag: "Callable"; ty: Ty }; + | { tag: "AnyIndexable"; ty: Ty } + | { tag: "AnyCallable"; ty: Ty } + | { tag: "AnyDerefable"; ty: Ty };