fix new type checker
All checks were successful
Check / Explore-Gitea-Actions (push) Successful in 8s

This commit is contained in:
sfja 2026-04-16 02:02:41 +02:00
parent 9e24ff3002
commit aab3fa77ad
4 changed files with 182 additions and 55 deletions

View File

@ -2,14 +2,30 @@ import { Syms } from "./resolve.ts";
import * as ast from "../ast.ts"; import * as ast from "../ast.ts";
import { Ty } from "../ty.ts"; import { Ty } from "../ty.ts";
import { FileReporter, Loc } from "../diagnostics.ts"; import { FileReporter, Loc } from "../diagnostics.ts";
import { exit } from "node:process"; import * as stringify from "../stringify.ts";
export function checkFn( export class Checker {
fn: ast.FnStmt, private checkedFns = new Map<ast.FnStmt, CheckedFn>();
syms: Syms,
reporter: FileReporter, constructor(
): CheckedFn { private syms: Syms,
return new TypeChecker(fn, syms, reporter).check(); private reporter: FileReporter,
) {}
checkFn(fn: ast.FnStmt): CheckedFn {
const existing = this.checkedFns.get(fn);
if (existing) {
return existing;
}
const checkedFn = new TypeChecker(this, fn, this.syms, this.reporter)
.check();
checkedFn.checkForCheckerInternalTys(this.reporter);
this.checkedFns.set(fn, checkedFn);
return checkedFn;
}
} }
export class CheckedFn { export class CheckedFn {
@ -29,12 +45,26 @@ export class CheckedFn {
} }
return ty; return ty;
} }
checkForCheckerInternalTys(reporter: FileReporter) {
for (const [node, ty] of this.nodeTys) {
if (ty.isCheckerInternal()) {
reporter.error(
node.loc,
`concrete type must be known at this point, got temporary type '${ty.pretty()}'`,
);
reporter.abort();
}
}
}
} }
class TypeChecker { class TypeChecker {
private nodeTys = new Map<ast.Node, Ty>(); private nodeTys = new Map<ast.Node, Ty>();
private params!: Ty[];
constructor( constructor(
private cx: Checker,
private fn: ast.FnStmt, private fn: ast.FnStmt,
private syms: Syms, private syms: Syms,
private reporter: FileReporter, private reporter: FileReporter,
@ -51,6 +81,7 @@ class TypeChecker {
} }
return this.ty(param.kind.ty); return this.ty(param.kind.ty);
}); });
this.params = params;
const retTy = this.fn.kind.retTy const retTy = this.fn.kind.retTy
? this.ty(this.fn.kind.retTy) ? this.ty(this.fn.kind.retTy)
@ -92,7 +123,12 @@ class TypeChecker {
ty = res.ty; ty = res.ty;
} else { } else {
ty = this.expr(k.expr, Ty.Any); ty = this.expr(k.expr, Ty.Any);
if (ty.is("AnyInt")) {
ty = Ty.I32;
this.rewriteTree(k.expr, ty);
} }
}
this.nodeTys.set(node, ty); this.nodeTys.set(node, ty);
break; break;
} }
@ -102,7 +138,7 @@ class TypeChecker {
} }
case "AssignStmt": { case "AssignStmt": {
const placeTy = this.expr(k.place, Ty.Any); const placeTy = this.expr(k.place, Ty.Any);
const exprTy = this.expr(k.expr, Ty.Any); const exprTy = this.expr(k.expr, placeTy);
if (!placeTy.resolvableWith(exprTy)) { if (!placeTy.resolvableWith(exprTy)) {
this.reporter.error( this.reporter.error(
k.expr.loc, k.expr.loc,
@ -137,7 +173,9 @@ class TypeChecker {
case "BreakStmt": { case "BreakStmt": {
break; break;
} }
case "FnStmt": case "FnStmt": {
break;
}
case "Error": case "Error":
case "File": case "File":
case "Block": case "Block":
@ -163,6 +201,8 @@ class TypeChecker {
}, },
}); });
this.convertAnyIntToI32();
const ty = Ty.create("FnStmt", { const ty = Ty.create("FnStmt", {
stmt: this.fn, stmt: this.fn,
ty: Ty.create("Fn", { params, retTy }), ty: Ty.create("Fn", { params, retTy }),
@ -171,6 +211,38 @@ class TypeChecker {
return new CheckedFn(ty, this.nodeTys); return new CheckedFn(ty, this.nodeTys);
} }
private convertAnyIntToI32() {
for (const [node, ty] of this.nodeTys) {
if (ty.is("AnyInt")) {
this.rewriteTree(node, Ty.I64);
}
}
}
private place(expr: ast.Node, expected: Ty): Ty {
return this.cachedCheck(expr, () => this.checkPlace(expr, expected));
}
private checkPlace(expr: ast.Node, expected: Ty): Ty {
const k = expr.kind;
switch (k.tag) {
case "UnaryExpr": {
switch (k.op) {
case "Deref": {
const innerTy = this.checkPlace(
k.expr,
Ty.AnyDerefable(expected),
);
if (innerTy.is("Ptr") || innerTy.is("PtrMut")) {
return innerTy.kind.ty;
}
}
}
}
}
return this.expr(expr, expected);
}
private expr(expr: ast.Node, expected: Ty): Ty { private expr(expr: ast.Node, expected: Ty): Ty {
return this.cachedCheck(expr, () => this.checkExpr(expr, expected)); return this.cachedCheck(expr, () => this.checkExpr(expr, expected));
} }
@ -199,7 +271,8 @@ class TypeChecker {
case "IdentExpr": { case "IdentExpr": {
const sym = this.syms.get(expr); const sym = this.syms.get(expr);
if (sym.tag === "Fn") { if (sym.tag === "Fn") {
throw new Error("todo"); const fn = this.cx.checkFn(sym.stmt);
return fn.ty();
} }
if (sym.tag === "Bool") { if (sym.tag === "Bool") {
return Ty.Bool; return Ty.Bool;
@ -212,7 +285,7 @@ class TypeChecker {
this.reporter.abort(); this.reporter.abort();
} }
if (sym.tag === "FnParam") { if (sym.tag === "FnParam") {
return this.expr(sym.param, Ty.Any); return this.params[sym.idx];
} }
if (sym.tag === "Let") { if (sym.tag === "Let") {
const ty = this.nodeTys.get(sym.stmt); const ty = this.nodeTys.get(sym.stmt);
@ -276,20 +349,20 @@ class TypeChecker {
} }
} }
const expectedInner = expected.is("Any") const expectedInner = expected.isIndexable()
? expected ? expected.indexableTy()!
: expected.as("Array").kind.ty; : Ty.Any;
let res = this.resolve( let res = this.resolve(
this.expr(k.values[0], expectedInner), this.expr(k.values[0], expectedInner),
expected, expectedInner,
k.values[0].loc, k.values[0].loc,
); );
while (true) { while (true) {
for (const val of k.values.slice(1)) { for (const val of k.values.slice(1)) {
res = this.resolve( res = this.resolve(
this.expr(val, expectedInner), this.expr(val, expectedInner),
expected, expectedInner,
k.values[0].loc, k.values[0].loc,
); );
if (res.rewriteSubtree) { if (res.rewriteSubtree) {
@ -305,12 +378,12 @@ class TypeChecker {
} }
} }
return res.ty; return Ty.Array(res.ty, k.values.length);
} }
case "IndexExpr": { case "IndexExpr": {
const innerTy = this.expr( const innerTy = this.place(
k.value, k.value,
Ty.Indexable(expected), Ty.AnyIndexable(expected),
); );
if (!innerTy.isIndexable()) { if (!innerTy.isIndexable()) {
this.reporter.error( this.reporter.error(
@ -389,14 +462,18 @@ class TypeChecker {
return Ty.USize; return Ty.USize;
} }
if (sym.id === "print") { if (sym.id === "print") {
void k.args for (const arg of k.args) {
.map((arg) => this.expr(arg, Ty.Any)); const ty = this.expr(arg, Ty.Any);
if (ty.is("AnyInt")) {
this.rewriteTree(arg, Ty.I32);
}
}
return Ty.Void; return Ty.Void;
} }
throw new Error(`builtin '${sym.id}' not handled`); throw new Error(`builtin '${sym.id}' not handled`);
} }
const calleeTy = this.expr(k.value, Ty.Callable(expected)); const calleeTy = this.expr(k.value, Ty.AnyCallable(expected));
if (!calleeTy.isCallable()) { if (!calleeTy.isCallable()) {
this.reporter.error( this.reporter.error(
expr.loc, expr.loc,
@ -439,7 +516,7 @@ class TypeChecker {
} }
case "RefMut": { case "RefMut": {
const ty = this.expr(k.expr, expected); const ty = this.expr(k.expr, expected);
return Ty.Ptr(ty); return Ty.PtrMut(ty);
} }
case "Deref": { case "Deref": {
const ty = this.expr(k.expr, expected); const ty = this.expr(k.expr, expected);
@ -465,8 +542,13 @@ class TypeChecker {
} }
} }
case "BinaryExpr": { case "BinaryExpr": {
const left = this.expr(k.left, expected); const op = new BinaryOp(k.op);
const right = this.expr(k.right, expected); const expectedInner = op.isPropagating() ? expected : Ty.Any;
const left = this.expr(k.left, expectedInner);
const right = this.expr(k.right, expectedInner);
const res = this.resolve(left, right, expr.loc);
const result = binaryOpTests const result = binaryOpTests
.map((test) => test(k.op, left, right)) .map((test) => test(k.op, left, right))
.filter((result) => result) .filter((result) => result)
@ -478,16 +560,18 @@ class TypeChecker {
); );
this.reporter.abort(); this.reporter.abort();
} }
this.rewriteTree(k.left, res.ty);
this.rewriteTree(k.right, res.ty);
return result; return result;
} }
case "RangeExpr": { case "RangeExpr": {
for (const operandExpr of [k.begin, k.end]) { for (const operandExpr of [k.begin, k.end]) {
const operandTy = operandExpr && const operandTy = operandExpr &&
this.expr(operandExpr, Ty.USize); this.expr(operandExpr, Ty.USize);
if (operandTy && !operandTy.resolvableWith(Ty.I32)) { if (operandTy && !operandTy.resolvableWith(Ty.USize)) {
this.reporter.error( this.reporter.error(
operandExpr.loc, operandExpr.loc,
`range operand must be '${Ty.I32.pretty()}', not '${operandTy.pretty()}'`, `range operand must be '${Ty.USize.pretty()}', not '${operandTy.pretty()}'`,
); );
this.reporter.abort(); this.reporter.abort();
} }
@ -577,7 +661,7 @@ class TypeChecker {
case "ArrayTy": { case "ArrayTy": {
const ty = this.ty(k.ty); const ty = this.ty(k.ty);
const lengthTy = this.expr(k.length, Ty.USize); const lengthTy = this.expr(k.length, Ty.USize);
if (!lengthTy.resolvableWith(Ty.I32)) { if (!lengthTy.resolvableWith(Ty.USize)) {
this.reporter.error( this.reporter.error(
k.length.loc, k.length.loc,
`for array length, expected 'int', got '${lengthTy.pretty()}'`, `for array length, expected 'int', got '${lengthTy.pretty()}'`,
@ -623,6 +707,12 @@ class TypeChecker {
if (expected.is("Any")) { if (expected.is("Any")) {
return { ty, rewriteSubtree: false }; return { ty, rewriteSubtree: false };
} }
if (expected.is("Int") && ty.is("AnyInt")) {
return { ty: expected, rewriteSubtree: true };
}
if (expected.is("AnyInt") && ty.is("Int")) {
return { ty, rewriteSubtree: true };
}
if (!ty.resolvableWith(expected)) { if (!ty.resolvableWith(expected)) {
if (inCaseOfError) { if (inCaseOfError) {
inCaseOfError(); inCaseOfError();
@ -640,10 +730,23 @@ class TypeChecker {
} }
private rewriteTree(node: ast.Node, ty: Ty) { private rewriteTree(node: ast.Node, ty: Ty) {
this.nodeTys.set(node, ty);
const k = node.kind; const k = node.kind;
switch (k.tag) { switch (k.tag) {
case "IdentExpr":
break;
case "IntExpr":
break;
case "BinaryExpr": {
const op = new BinaryOp(k.op);
if (op.isPropagating()) {
this.rewriteTree(k.left, ty);
this.rewriteTree(k.right, ty);
}
break;
}
default: default:
throw new Error("not implemented"); throw new Error(`not implemented for '${k.tag}'`);
} }
} }
} }
@ -657,15 +760,10 @@ type BinaryOpTest = (op: ast.BinaryOp, left: Ty, right: Ty) => Ty | null;
const binaryOpTests: BinaryOpTest[] = [ const binaryOpTests: BinaryOpTest[] = [
(op, left, right) => { (op, left, right) => {
const ops: ast.BinaryOp[] = [ const ops: ast.BinaryOp[] = ["Add", "Sub", "Mul", "Div", "Rem"];
"Add",
"Sub",
"Mul",
"Div",
"Rem",
];
if ( if (
ops.includes(op) && left.is("Int") && left.resolvableWith(right) ops.includes(op) && (left.is("Int") || left.is("AnyInt")) &&
left.resolvableWith(right)
) { ) {
return left; return left;
} }
@ -674,10 +772,22 @@ const binaryOpTests: BinaryOpTest[] = [
(op, left, right) => { (op, left, right) => {
const ops: ast.BinaryOp[] = ["Eq", "Ne", "Lt", "Gt", "Lte", "Gte"]; const ops: ast.BinaryOp[] = ["Eq", "Ne", "Lt", "Gt", "Lte", "Gte"];
if ( if (
ops.includes(op) && left.is("Int") && left.resolvableWith(right) ops.includes(op) && (left.is("Int") || left.is("AnyInt")) &&
left.resolvableWith(right)
) { ) {
return Ty.Bool; return Ty.Bool;
} }
return null; return null;
}, },
]; ];
class BinaryOp {
constructor(
public op: ast.BinaryOp,
) {}
isPropagating(): boolean {
return (["Add", "Sub", "Mul", "Div", "Rem"] as ast.BinaryOp[])
.includes(this.op);
}
}

View File

@ -38,8 +38,10 @@ if (!mainFn) {
} }
const tys = new front.Tys(syms, fileRep); const tys = new front.Tys(syms, fileRep);
// const fnTys = front.checkFn(mainFn, syms, fileRep);
// const checker = new front.Checker(syms, fileRep);
const fnTys = checker.checkFn(mainFn);
// fileAst.visit({ // fileAst.visit({
// visit(node) { // visit(node) {
// switch (node.kind.tag) { // switch (node.kind.tag) {

View File

@ -80,7 +80,7 @@ export function tyPretty(ty: ty.Ty, colors = noColors): string {
case "Array": case "Array":
return `${c.punctuation}[${ return `${c.punctuation}[${
ty.kind.ty.pretty(c) ty.kind.ty.pretty(c)
}${c.punctuation}; ${ty.kind.length}${c.punctuation}]`; }${c.punctuation}; ${c.literal}${ty.kind.length}${c.punctuation}]`;
case "Slice": case "Slice":
return `${c.punctuation}[${ty.kind.ty.pretty(c)}${c.punctuation}]`; return `${c.punctuation}[${ty.kind.ty.pretty(c)}${c.punctuation}]`;
case "Range": case "Range":
@ -101,8 +101,6 @@ export function tyPretty(ty: ty.Ty, colors = noColors): string {
`${c.punctuation}, `, `${c.punctuation}, `,
) )
}${c.punctuation}) -> ${ty.kind.ty.kind.retTy.pretty(c)}`; }${c.punctuation}) -> ${ty.kind.ty.kind.retTy.pretty(c)}`;
case "AnyInt":
return `${c.typeIdent}{integer}`;
} }
return `{${ty.kind.tag}}`; return `{${ty.kind.tag}}`;
} }
@ -141,7 +139,7 @@ class MirFnPrettyStringifier {
const retTy = fnTy.kind.retTy.pretty(c); const retTy = fnTy.kind.retTy.pretty(c);
this.result += this.result +=
`${c.keyword}fn ${c.fnIdent}${ident}${c.punctuation}(${params}${c.punctuation}) -> ${retTy} ${c.punctuation}{\n`; `${c.keyword}fn ${c.fnIdent}${ident}${c.punctuation}(${params}${c.punctuation}) -> ${retTy}\n${c.punctuation}{\n`;
for (const bb of this.fn.bbs) { for (const bb of this.fn.bbs) {
this.basicBlock(bb); this.basicBlock(bb);

View File

@ -54,12 +54,16 @@ export class Ty {
/** Only used in type checker. */ /** Only used in type checker. */
static AnyInt = Ty.create("AnyInt", {}); static AnyInt = Ty.create("AnyInt", {});
/** Only used in type checker. */ /** Only used in type checker. */
static Indexable(ty: Ty): Ty { static AnyIndexable(ty: Ty): Ty {
return this.create("Indexable", { ty }); return this.create("AnyIndexable", { ty });
} }
/** Only used in type checker. */ /** Only used in type checker. */
static Callable(ty: Ty): Ty { static AnyCallable(ty: Ty): Ty {
return this.create("Callable", { ty }); return this.create("AnyCallable", { ty });
}
/** Only used in type checker. */
static AnyDerefable(ty: Ty): Ty {
return this.create("AnyDerefable", { ty });
} }
private internHash(): string { private internHash(): string {
@ -111,7 +115,8 @@ export class Ty {
return other.is("Void"); return other.is("Void");
} }
if (this.is("Int")) { if (this.is("Int")) {
return other.is("Int") && this.kind.intTy == other.kind.intTy; return other.is("Int") && this.kind.intTy == other.kind.intTy ||
other.is("AnyInt");
} }
if (this.is("Bool")) { if (this.is("Bool")) {
@ -186,11 +191,11 @@ export class Ty {
} }
isIndexable(): boolean { isIndexable(): boolean {
return this.is("Array") || this.is("Slice") || this.is("Indexable"); return this.is("Array") || this.is("Slice") || this.is("AnyIndexable");
} }
indexableTy(): Ty | null { indexableTy(): Ty | null {
if (this.is("Array") || this.is("Slice") || this.is("Indexable")) { if (this.is("Array") || this.is("Slice") || this.is("AnyIndexable")) {
return this.kind.ty; return this.kind.ty;
} }
return null; return null;
@ -208,9 +213,20 @@ export class Ty {
return this.kind.ty.as("Fn"); return this.kind.ty.as("Fn");
} }
isDerefable(): boolean {
return this.is("Ptr") || this.is("PtrMut") || this.is("AnyDerefable");
}
derefableTy(): Ty | null {
if (this.is("Ptr") || this.is("PtrMut") || this.is("AnyDerefable")) {
return this.kind.ty;
}
return null;
}
isCheckerInternal(): boolean { isCheckerInternal(): boolean {
return this.is("Any") || this.is("AnyInt") || this.is("Indexable") || return this.is("Any") || this.is("AnyInt") || this.is("AnyIndexable") ||
this.is("Callable"); this.is("AnyCallable") || this.is("AnyDerefable");
} }
pretty(colors?: stringify.PrettyColors): string { pretty(colors?: stringify.PrettyColors): string {
@ -231,5 +247,6 @@ export type TyKind =
| { tag: "Fn"; params: Ty[]; retTy: Ty } | { tag: "Fn"; params: Ty[]; retTy: Ty }
| { tag: "FnStmt"; ty: Ty; stmt: ast.NodeWithKind<"FnStmt"> } | { tag: "FnStmt"; ty: Ty; stmt: ast.NodeWithKind<"FnStmt"> }
| { tag: "Any" | "AnyInt" } | { tag: "Any" | "AnyInt" }
| { tag: "Indexable"; ty: Ty } | { tag: "AnyIndexable"; ty: Ty }
| { tag: "Callable"; ty: Ty }; | { tag: "AnyCallable"; ty: Ty }
| { tag: "AnyDerefable"; ty: Ty };