diff --git a/src/front/check2.ts b/src/front/check2.ts index 333b742..45f5d92 100644 --- a/src/front/check2.ts +++ b/src/front/check2.ts @@ -137,7 +137,7 @@ class TypeChecker { break; } case "AssignStmt": { - const placeTy = this.expr(k.place, Ty.Any); + const placeTy = this.place(k.place, Ty.Any); const exprTy = this.expr(k.expr, placeTy); if (!placeTy.resolvableWith(exprTy)) { this.reporter.error( diff --git a/src/main.ts b/src/main.ts index 4da6e30..2170d37 100644 --- a/src/main.ts +++ b/src/main.ts @@ -70,7 +70,7 @@ const fnTys = checker.checkFn(mainFn); // }, // }); -const m = new middle.MiddleLowerer(syms, tys); +const m = new middle.MiddleLowerer(syms, checker); const mainMiddleFn = m.lowerFn(mainFn); if (Deno.args.includes("--print-mir")) { diff --git a/src/middle.ts b/src/middle.ts index 7339bf0..2834d68 100644 --- a/src/middle.ts +++ b/src/middle.ts @@ -1,5 +1,5 @@ import * as ast from "./ast.ts"; -import { Syms, Tys } from "./front/mod.ts"; +import { CheckedFn, Checker, Syms, Tys } from "./front/mod.ts"; import { Ty } from "./ty.ts"; import { BasicBlock, BinaryOp, Fn, Inst, InstKind } from "./mir.ts"; @@ -8,14 +8,19 @@ export class MiddleLowerer { constructor( private syms: Syms, - private tys: Tys, + private checker: Checker, ) {} lowerFn(stmt: ast.FnStmt): Fn { if (this.fns.has(stmt.id)) { return this.fns.get(stmt.id)!; } - const fn = new FnLowerer(this, this.syms, this.tys, stmt).lower(); + const fn = new FnLowerer( + this, + this.syms, + this.checker.checkFn(stmt), + stmt, + ).lower(); this.fns.set(stmt.id, fn); return fn; } @@ -34,12 +39,12 @@ class FnLowerer { constructor( private lowerer: MiddleLowerer, private syms: Syms, - private tys: Tys, + private tys: CheckedFn, private stmt: ast.FnStmt, ) {} lower(): Fn { - const ty = this.tys.fnStmt(this.stmt); + const ty = this.tys.ty(); this.lowerBlock(this.stmt.kind.body.as("Block")); this.pushInst(Ty.Void, "Return", { source: this.makeVoid() }); this.bbs[0].insts.unshift(...this.allocs); @@ -172,7 +177,7 @@ class FnLowerer { } private lowerLetStmt(stmt: ast.NodeWithKind<"LetStmt">) { - const ty = this.tys.param(stmt.kind.param.as("Param")); + const ty = this.tys.exprTy(stmt.kind.expr); const expr = this.lowerExpr(stmt.kind.expr); const local = new Inst( Ty.create("PtrMut", { ty }), @@ -195,7 +200,7 @@ class FnLowerer { private lowerPlace(place: ast.Node): Inst { // evaluate to most direct pointer - const _ty = this.tys.place(place); + const _ty = this.tys.exprTy(place); if (place.is("IdentExpr")) { const sym = this.syms.get(place); @@ -218,9 +223,9 @@ class FnLowerer { if (place.is("IndexExpr")) { const value = place.kind.value; - const valueTy = this.tys.place(value); + const valueTy = this.tys.exprTy(value); const arg = place.kind.arg; - const argTy = this.tys.expr(arg); + const argTy = this.tys.exprTy(arg); if (valueTy.is("Array") || valueTy.is("Slice")) { const valueInst = this.lowerPlace(place.kind.value); if (argTy.is("Int")) { @@ -257,7 +262,7 @@ class FnLowerer { } private lowerExpr(expr: ast.Node): Inst { - const ty = this.tys.expr(expr); + const ty = this.tys.exprTy(expr); if (expr.is("IdentExpr")) { const sym = this.syms.get(expr); if (sym.tag === "Fn") { @@ -265,7 +270,9 @@ class FnLowerer { return this.pushInst(fn.ty, "Fn", { fn }); } if (sym.tag === "FnParam") { - const ty = this.tys.expr(sym.param); + const ty = this.tys.ty().as("FnStmt") + .kind.ty.as("Fn") + .kind.params[sym.idx]; return this.pushInst(ty, "Param", { idx: sym.idx }); } if (sym.tag === "Builtin") { @@ -290,7 +297,7 @@ class FnLowerer { return this.pushInst(ty, "Str", { value: expr.kind.value }); } if (expr.is("ArrayExpr")) { - const ty = this.tys.expr(expr); + const ty = this.tys.exprTy(expr); const values = expr.kind.values .map((value) => this.lowerExpr(value)); return this.pushInst(ty, "Array", { values }); @@ -323,8 +330,8 @@ class FnLowerer { return this.lowerUnaryExpr(expr); } if (expr.is("BinaryExpr")) { - const leftTy = this.tys.expr(expr.kind.left); - const rightTy = this.tys.expr(expr.kind.right); + const leftTy = this.tys.exprTy(expr.kind.left); + const rightTy = this.tys.exprTy(expr.kind.right); const binaryOp = binaryOpTests .map((test) => test(expr.kind.op, leftTy, rightTy, ty)) .filter((tested) => tested) @@ -342,8 +349,8 @@ class FnLowerer { } private lowerUnaryExpr(expr: ast.NodeWithKind<"UnaryExpr">) { - const resultTy = this.tys.expr(expr); - const operandTy = this.tys.expr(expr.kind.expr); + const resultTy = this.tys.exprTy(expr); + const operandTy = this.tys.exprTy(expr.kind.expr); if ( expr.kind.op === "Neg" && operandTy.resolvableWith(Ty.I32) && @@ -376,7 +383,7 @@ class FnLowerer { ); } if (place.is("IndexExpr")) { - const placeTy = this.tys.expr(place); + const placeTy = this.tys.exprTy(place); const placeInst = this.lowerPlace(place); if (placeTy.is("Slice")) { return placeInst;