fix new type checker
All checks were successful
Check / Explore-Gitea-Actions (push) Successful in 8s

This commit is contained in:
sfja 2026-04-16 02:02:41 +02:00
parent 9e24ff3002
commit aab3fa77ad
4 changed files with 182 additions and 55 deletions

View File

@ -2,14 +2,30 @@ import { Syms } from "./resolve.ts";
import * as ast from "../ast.ts";
import { Ty } from "../ty.ts";
import { FileReporter, Loc } from "../diagnostics.ts";
import { exit } from "node:process";
import * as stringify from "../stringify.ts";
export function checkFn(
fn: ast.FnStmt,
syms: Syms,
reporter: FileReporter,
): CheckedFn {
return new TypeChecker(fn, syms, reporter).check();
export class Checker {
private checkedFns = new Map<ast.FnStmt, CheckedFn>();
constructor(
private syms: Syms,
private reporter: FileReporter,
) {}
checkFn(fn: ast.FnStmt): CheckedFn {
const existing = this.checkedFns.get(fn);
if (existing) {
return existing;
}
const checkedFn = new TypeChecker(this, fn, this.syms, this.reporter)
.check();
checkedFn.checkForCheckerInternalTys(this.reporter);
this.checkedFns.set(fn, checkedFn);
return checkedFn;
}
}
export class CheckedFn {
@ -29,12 +45,26 @@ export class CheckedFn {
}
return ty;
}
checkForCheckerInternalTys(reporter: FileReporter) {
for (const [node, ty] of this.nodeTys) {
if (ty.isCheckerInternal()) {
reporter.error(
node.loc,
`concrete type must be known at this point, got temporary type '${ty.pretty()}'`,
);
reporter.abort();
}
}
}
}
class TypeChecker {
private nodeTys = new Map<ast.Node, Ty>();
private params!: Ty[];
constructor(
private cx: Checker,
private fn: ast.FnStmt,
private syms: Syms,
private reporter: FileReporter,
@ -51,6 +81,7 @@ class TypeChecker {
}
return this.ty(param.kind.ty);
});
this.params = params;
const retTy = this.fn.kind.retTy
? this.ty(this.fn.kind.retTy)
@ -92,7 +123,12 @@ class TypeChecker {
ty = res.ty;
} else {
ty = this.expr(k.expr, Ty.Any);
if (ty.is("AnyInt")) {
ty = Ty.I32;
this.rewriteTree(k.expr, ty);
}
}
this.nodeTys.set(node, ty);
break;
}
@ -102,7 +138,7 @@ class TypeChecker {
}
case "AssignStmt": {
const placeTy = this.expr(k.place, Ty.Any);
const exprTy = this.expr(k.expr, Ty.Any);
const exprTy = this.expr(k.expr, placeTy);
if (!placeTy.resolvableWith(exprTy)) {
this.reporter.error(
k.expr.loc,
@ -137,7 +173,9 @@ class TypeChecker {
case "BreakStmt": {
break;
}
case "FnStmt":
case "FnStmt": {
break;
}
case "Error":
case "File":
case "Block":
@ -163,6 +201,8 @@ class TypeChecker {
},
});
this.convertAnyIntToI32();
const ty = Ty.create("FnStmt", {
stmt: this.fn,
ty: Ty.create("Fn", { params, retTy }),
@ -171,6 +211,38 @@ class TypeChecker {
return new CheckedFn(ty, this.nodeTys);
}
private convertAnyIntToI32() {
for (const [node, ty] of this.nodeTys) {
if (ty.is("AnyInt")) {
this.rewriteTree(node, Ty.I64);
}
}
}
private place(expr: ast.Node, expected: Ty): Ty {
return this.cachedCheck(expr, () => this.checkPlace(expr, expected));
}
private checkPlace(expr: ast.Node, expected: Ty): Ty {
const k = expr.kind;
switch (k.tag) {
case "UnaryExpr": {
switch (k.op) {
case "Deref": {
const innerTy = this.checkPlace(
k.expr,
Ty.AnyDerefable(expected),
);
if (innerTy.is("Ptr") || innerTy.is("PtrMut")) {
return innerTy.kind.ty;
}
}
}
}
}
return this.expr(expr, expected);
}
private expr(expr: ast.Node, expected: Ty): Ty {
return this.cachedCheck(expr, () => this.checkExpr(expr, expected));
}
@ -199,7 +271,8 @@ class TypeChecker {
case "IdentExpr": {
const sym = this.syms.get(expr);
if (sym.tag === "Fn") {
throw new Error("todo");
const fn = this.cx.checkFn(sym.stmt);
return fn.ty();
}
if (sym.tag === "Bool") {
return Ty.Bool;
@ -212,7 +285,7 @@ class TypeChecker {
this.reporter.abort();
}
if (sym.tag === "FnParam") {
return this.expr(sym.param, Ty.Any);
return this.params[sym.idx];
}
if (sym.tag === "Let") {
const ty = this.nodeTys.get(sym.stmt);
@ -276,20 +349,20 @@ class TypeChecker {
}
}
const expectedInner = expected.is("Any")
? expected
: expected.as("Array").kind.ty;
const expectedInner = expected.isIndexable()
? expected.indexableTy()!
: Ty.Any;
let res = this.resolve(
this.expr(k.values[0], expectedInner),
expected,
expectedInner,
k.values[0].loc,
);
while (true) {
for (const val of k.values.slice(1)) {
res = this.resolve(
this.expr(val, expectedInner),
expected,
expectedInner,
k.values[0].loc,
);
if (res.rewriteSubtree) {
@ -305,12 +378,12 @@ class TypeChecker {
}
}
return res.ty;
return Ty.Array(res.ty, k.values.length);
}
case "IndexExpr": {
const innerTy = this.expr(
const innerTy = this.place(
k.value,
Ty.Indexable(expected),
Ty.AnyIndexable(expected),
);
if (!innerTy.isIndexable()) {
this.reporter.error(
@ -389,14 +462,18 @@ class TypeChecker {
return Ty.USize;
}
if (sym.id === "print") {
void k.args
.map((arg) => this.expr(arg, Ty.Any));
for (const arg of k.args) {
const ty = this.expr(arg, Ty.Any);
if (ty.is("AnyInt")) {
this.rewriteTree(arg, Ty.I32);
}
}
return Ty.Void;
}
throw new Error(`builtin '${sym.id}' not handled`);
}
const calleeTy = this.expr(k.value, Ty.Callable(expected));
const calleeTy = this.expr(k.value, Ty.AnyCallable(expected));
if (!calleeTy.isCallable()) {
this.reporter.error(
expr.loc,
@ -439,7 +516,7 @@ class TypeChecker {
}
case "RefMut": {
const ty = this.expr(k.expr, expected);
return Ty.Ptr(ty);
return Ty.PtrMut(ty);
}
case "Deref": {
const ty = this.expr(k.expr, expected);
@ -465,8 +542,13 @@ class TypeChecker {
}
}
case "BinaryExpr": {
const left = this.expr(k.left, expected);
const right = this.expr(k.right, expected);
const op = new BinaryOp(k.op);
const expectedInner = op.isPropagating() ? expected : Ty.Any;
const left = this.expr(k.left, expectedInner);
const right = this.expr(k.right, expectedInner);
const res = this.resolve(left, right, expr.loc);
const result = binaryOpTests
.map((test) => test(k.op, left, right))
.filter((result) => result)
@ -478,16 +560,18 @@ class TypeChecker {
);
this.reporter.abort();
}
this.rewriteTree(k.left, res.ty);
this.rewriteTree(k.right, res.ty);
return result;
}
case "RangeExpr": {
for (const operandExpr of [k.begin, k.end]) {
const operandTy = operandExpr &&
this.expr(operandExpr, Ty.USize);
if (operandTy && !operandTy.resolvableWith(Ty.I32)) {
if (operandTy && !operandTy.resolvableWith(Ty.USize)) {
this.reporter.error(
operandExpr.loc,
`range operand must be '${Ty.I32.pretty()}', not '${operandTy.pretty()}'`,
`range operand must be '${Ty.USize.pretty()}', not '${operandTy.pretty()}'`,
);
this.reporter.abort();
}
@ -577,7 +661,7 @@ class TypeChecker {
case "ArrayTy": {
const ty = this.ty(k.ty);
const lengthTy = this.expr(k.length, Ty.USize);
if (!lengthTy.resolvableWith(Ty.I32)) {
if (!lengthTy.resolvableWith(Ty.USize)) {
this.reporter.error(
k.length.loc,
`for array length, expected 'int', got '${lengthTy.pretty()}'`,
@ -623,6 +707,12 @@ class TypeChecker {
if (expected.is("Any")) {
return { ty, rewriteSubtree: false };
}
if (expected.is("Int") && ty.is("AnyInt")) {
return { ty: expected, rewriteSubtree: true };
}
if (expected.is("AnyInt") && ty.is("Int")) {
return { ty, rewriteSubtree: true };
}
if (!ty.resolvableWith(expected)) {
if (inCaseOfError) {
inCaseOfError();
@ -640,10 +730,23 @@ class TypeChecker {
}
private rewriteTree(node: ast.Node, ty: Ty) {
this.nodeTys.set(node, ty);
const k = node.kind;
switch (k.tag) {
case "IdentExpr":
break;
case "IntExpr":
break;
case "BinaryExpr": {
const op = new BinaryOp(k.op);
if (op.isPropagating()) {
this.rewriteTree(k.left, ty);
this.rewriteTree(k.right, ty);
}
break;
}
default:
throw new Error("not implemented");
throw new Error(`not implemented for '${k.tag}'`);
}
}
}
@ -657,15 +760,10 @@ type BinaryOpTest = (op: ast.BinaryOp, left: Ty, right: Ty) => Ty | null;
const binaryOpTests: BinaryOpTest[] = [
(op, left, right) => {
const ops: ast.BinaryOp[] = [
"Add",
"Sub",
"Mul",
"Div",
"Rem",
];
const ops: ast.BinaryOp[] = ["Add", "Sub", "Mul", "Div", "Rem"];
if (
ops.includes(op) && left.is("Int") && left.resolvableWith(right)
ops.includes(op) && (left.is("Int") || left.is("AnyInt")) &&
left.resolvableWith(right)
) {
return left;
}
@ -674,10 +772,22 @@ const binaryOpTests: BinaryOpTest[] = [
(op, left, right) => {
const ops: ast.BinaryOp[] = ["Eq", "Ne", "Lt", "Gt", "Lte", "Gte"];
if (
ops.includes(op) && left.is("Int") && left.resolvableWith(right)
ops.includes(op) && (left.is("Int") || left.is("AnyInt")) &&
left.resolvableWith(right)
) {
return Ty.Bool;
}
return null;
},
];
class BinaryOp {
constructor(
public op: ast.BinaryOp,
) {}
isPropagating(): boolean {
return (["Add", "Sub", "Mul", "Div", "Rem"] as ast.BinaryOp[])
.includes(this.op);
}
}

View File

@ -38,8 +38,10 @@ if (!mainFn) {
}
const tys = new front.Tys(syms, fileRep);
// const fnTys = front.checkFn(mainFn, syms, fileRep);
//
const checker = new front.Checker(syms, fileRep);
const fnTys = checker.checkFn(mainFn);
// fileAst.visit({
// visit(node) {
// switch (node.kind.tag) {

View File

@ -80,7 +80,7 @@ export function tyPretty(ty: ty.Ty, colors = noColors): string {
case "Array":
return `${c.punctuation}[${
ty.kind.ty.pretty(c)
}${c.punctuation}; ${ty.kind.length}${c.punctuation}]`;
}${c.punctuation}; ${c.literal}${ty.kind.length}${c.punctuation}]`;
case "Slice":
return `${c.punctuation}[${ty.kind.ty.pretty(c)}${c.punctuation}]`;
case "Range":
@ -101,8 +101,6 @@ export function tyPretty(ty: ty.Ty, colors = noColors): string {
`${c.punctuation}, `,
)
}${c.punctuation}) -> ${ty.kind.ty.kind.retTy.pretty(c)}`;
case "AnyInt":
return `${c.typeIdent}{integer}`;
}
return `{${ty.kind.tag}}`;
}
@ -141,7 +139,7 @@ class MirFnPrettyStringifier {
const retTy = fnTy.kind.retTy.pretty(c);
this.result +=
`${c.keyword}fn ${c.fnIdent}${ident}${c.punctuation}(${params}${c.punctuation}) -> ${retTy} ${c.punctuation}{\n`;
`${c.keyword}fn ${c.fnIdent}${ident}${c.punctuation}(${params}${c.punctuation}) -> ${retTy}\n${c.punctuation}{\n`;
for (const bb of this.fn.bbs) {
this.basicBlock(bb);

View File

@ -54,12 +54,16 @@ export class Ty {
/** Only used in type checker. */
static AnyInt = Ty.create("AnyInt", {});
/** Only used in type checker. */
static Indexable(ty: Ty): Ty {
return this.create("Indexable", { ty });
static AnyIndexable(ty: Ty): Ty {
return this.create("AnyIndexable", { ty });
}
/** Only used in type checker. */
static Callable(ty: Ty): Ty {
return this.create("Callable", { ty });
static AnyCallable(ty: Ty): Ty {
return this.create("AnyCallable", { ty });
}
/** Only used in type checker. */
static AnyDerefable(ty: Ty): Ty {
return this.create("AnyDerefable", { ty });
}
private internHash(): string {
@ -111,7 +115,8 @@ export class Ty {
return other.is("Void");
}
if (this.is("Int")) {
return other.is("Int") && this.kind.intTy == other.kind.intTy;
return other.is("Int") && this.kind.intTy == other.kind.intTy ||
other.is("AnyInt");
}
if (this.is("Bool")) {
@ -186,11 +191,11 @@ export class Ty {
}
isIndexable(): boolean {
return this.is("Array") || this.is("Slice") || this.is("Indexable");
return this.is("Array") || this.is("Slice") || this.is("AnyIndexable");
}
indexableTy(): Ty | null {
if (this.is("Array") || this.is("Slice") || this.is("Indexable")) {
if (this.is("Array") || this.is("Slice") || this.is("AnyIndexable")) {
return this.kind.ty;
}
return null;
@ -208,9 +213,20 @@ export class Ty {
return this.kind.ty.as("Fn");
}
isDerefable(): boolean {
return this.is("Ptr") || this.is("PtrMut") || this.is("AnyDerefable");
}
derefableTy(): Ty | null {
if (this.is("Ptr") || this.is("PtrMut") || this.is("AnyDerefable")) {
return this.kind.ty;
}
return null;
}
isCheckerInternal(): boolean {
return this.is("Any") || this.is("AnyInt") || this.is("Indexable") ||
this.is("Callable");
return this.is("Any") || this.is("AnyInt") || this.is("AnyIndexable") ||
this.is("AnyCallable") || this.is("AnyDerefable");
}
pretty(colors?: stringify.PrettyColors): string {
@ -231,5 +247,6 @@ export type TyKind =
| { tag: "Fn"; params: Ty[]; retTy: Ty }
| { tag: "FnStmt"; ty: Ty; stmt: ast.NodeWithKind<"FnStmt"> }
| { tag: "Any" | "AnyInt" }
| { tag: "Indexable"; ty: Ty }
| { tag: "Callable"; ty: Ty };
| { tag: "AnyIndexable"; ty: Ty }
| { tag: "AnyCallable"; ty: Ty }
| { tag: "AnyDerefable"; ty: Ty };