import * as ast from "./ast.ts";
import { Syms, Tys } from "./front/mod.ts";
import { Ty } from "./ty.ts";
import { BasicBlock, BinaryOp, Fn, Inst, InstKind } from "./mir.ts";

/**
 * Lowers function statements from the AST into MIR `Fn` objects and caches
 * the result per function statement id.
 */
export class MiddleLowerer {
    /** Already-lowered functions, keyed by the id of their `FnStmt`. */
    private fns = new Map<ast.FnStmt["id"], Fn>();

    constructor(
        private syms: Syms,
        private tys: Tys,
    ) {}

    /**
     * Returns the MIR function for `stmt`, lowering it on first request.
     *
     * NOTE(review): the result is stored only after lowering has finished, so
     * a function whose body refers to itself re-enters this method without a
     * cache hit — confirm that recursion is ruled out or handled elsewhere.
     */
    lowerFn(stmt: ast.FnStmt): Fn {
        const cached = this.fns.get(stmt.id);
        if (cached) {
            return cached;
        }
        const fn = new FnLowerer(this, this.syms, this.tys, stmt).lower();
        this.fns.set(stmt.id, fn);
        return fn;
    }

    /** Returns every function lowered so far, in insertion order. */
    allFns(): Fn[] {
        return Array.from(this.fns.values());
    }
}

/**
 * Lowers the body of a single function statement into a list of basic blocks.
 *
 * Instructions are always appended to the last block in `bbs`; control-flow
 * constructs push new blocks and patch jumps/branches into the blocks they
 * captured earlier.
 */
class FnLowerer {
    /** `Alloca` instructions, hoisted to the start of the first block by `lower`. */
    private allocs: Inst[] = [];
    private bbs: BasicBlock[] = [new BasicBlock()];
    /** Maps the id of a `let` parameter node to its `Alloca` instruction. */
    private localMap = new Map<ast.Node["id"], Inst>();
    /** Maps the id of a `while` statement to the block that follows the loop. */
    private loopEndMap = new Map<ast.Node["id"], BasicBlock>();

    constructor(
        private lowerer: MiddleLowerer,
        private syms: Syms,
        private tys: Tys,
        private stmt: ast.FnStmt,
    ) {}

    /**
     * Lowers the function body, appends a trailing `Return` of a void value,
     * and prepends all collected allocas to the first block.
     */
    lower(): Fn {
        const ty = this.tys.fnStmt(this.stmt);
        this.lowerBlock(this.stmt.kind.body.as("Block"));
        this.pushInst(Ty.Void, "Return", { source: this.makeVoid() });
        this.bbs[0].insts.unshift(...this.allocs);
        return new Fn(this.stmt, ty, this.bbs);
    }

    /** Lowers each statement of `block` in order. */
    private lowerBlock(block: ast.Block) {
        for (const stmt of block.kind.stmts) {
            this.lowerStmt(stmt);
        }
    }

    /**
     * Dispatches a statement node to its specific lowering routine.
     *
     * @throws Error if the statement kind is not supported.
     */
    private lowerStmt(stmt: ast.Node) {
        if (stmt.is("LetStmt")) {
            return this.lowerLetStmt(stmt);
        }
        if (stmt.is("ReturnStmt")) {
            return this.lowerReturnStmt(stmt);
        }
        if (stmt.is("IfStmt")) {
            return this.lowerIfStmt(stmt);
        }
        if (stmt.is("WhileStmt")) {
            return this.lowerWhileStmt(stmt);
        }
        if (stmt.is("BreakStmt")) {
            return this.lowerBreakStmt(stmt);
        }
        if (stmt.is("AssignStmt")) {
            return this.lowerAssignStmt(stmt);
        }
        if (stmt.is("ExprStmt")) {
            return this.lowerExpr(stmt.kind.expr);
        }
        throw new Error(`'${stmt.kind.tag}' not handled`);
    }

    /**
     * Emits a `Return` of the given expression (or of a void value when there
     * is none) and opens a fresh block for any statements that follow.
     */
    private lowerReturnStmt(stmt: ast.NodeWithKind<"ReturnStmt">) {
        const source = stmt.kind.expr
            ? this.lowerExpr(stmt.kind.expr)
            : this.makeVoid();
        this.pushInst(Ty.Void, "Return", { source });
        this.pushBlock();
    }

    /**
     * Lowers `if`/`else` into: condition block → branch → truthy block(s)
     * [→ falsy block(s)] → done block. Both arms end with a jump to `done`.
     */
    private lowerIfStmt(stmt: ast.NodeWithKind<"IfStmt">) {
        const cond = this.lowerExpr(stmt.kind.cond);
        const condBlock = this.currentBlock();

        const truthy = this.pushBlock();
        this.lowerBlock(stmt.kind.truthy.as("Block"));
        // The arm may have created further blocks; the jump goes in the last.
        const truthyEnd = this.currentBlock();

        let falsy: BasicBlock | null = null;
        let falsyEnd: BasicBlock | null = null;
        if (stmt.kind.falsy) {
            falsy = this.pushBlock();
            this.lowerBlock(stmt.kind.falsy.as("Block"));
            falsyEnd = this.currentBlock();
        }

        const done = this.pushBlock();

        condBlock.insts.push(
            new Inst(Ty.Void, {
                tag: "Branch",
                cond,
                truthy,
                falsy: falsy ?? done,
            }),
        );
        // The truthy arm must skip over the falsy arm, so it always targets
        // `done` (never the falsy block).
        truthyEnd.insts.push(
            new Inst(Ty.Void, { tag: "Jump", target: done }),
        );
        falsyEnd?.insts.push(
            new Inst(Ty.Void, { tag: "Jump", target: done }),
        );
    }

    /**
     * Lowers `while` with the body placed before the condition:
     * before → jump cond; body → jump cond; cond → branch(body, after).
     */
    private lowerWhileStmt(stmt: ast.NodeWithKind<"WhileStmt">) {
        const before = this.currentBlock();

        const body = this.pushBlock();
        // Created up front (but appended last) so `break` inside the body can
        // already target it.
        const after = new BasicBlock();
        this.loopEndMap.set(stmt.id, after);
        this.lowerBlock(stmt.kind.body.as("Block"));
        const bodyEnd = this.currentBlock();

        const condBlock = this.pushBlock();
        const cond = this.lowerExpr(stmt.kind.cond);
        const condBlockEnd = this.currentBlock();

        this.bbs.push(after);

        before.insts.push(
            new Inst(Ty.Void, { tag: "Jump", target: condBlock }),
        );
        condBlockEnd.insts.push(
            new Inst(Ty.Void, {
                tag: "Branch",
                cond: cond,
                truthy: body,
                falsy: after,
            }),
        );
        bodyEnd.insts.push(
            new Inst(Ty.Void, { tag: "Jump", target: condBlock }),
        );
    }

    /**
     * Emits a jump to the end block of the loop the `break` resolves to and
     * opens a fresh block for any statements that follow.
     *
     * @throws Error if the symbol is not a loop or the loop has no registered
     * end block.
     */
    private lowerBreakStmt(stmt: ast.NodeWithKind<"BreakStmt">) {
        const sym = this.syms.get(stmt);
        if (sym.tag !== "Loop") {
            throw new Error(`break resolved to '${sym.tag}', expected 'Loop'`);
        }
        const loopEnd = this.loopEndMap.get(sym.stmt.id);
        if (!loopEnd) {
            throw new Error("break target loop has no registered end block");
        }
        this.pushInst(Ty.Void, "Jump", { target: loopEnd });
        this.pushBlock();
    }

    /**
     * Lowers `let`: evaluates the initializer, records an `Alloca` for the
     * local, stores the value into it and registers the local by param id.
     */
    private lowerLetStmt(stmt: ast.NodeWithKind<"LetStmt">) {
        const ty = this.tys.param(stmt.kind.param.as("Param"));
        const expr = this.lowerExpr(stmt.kind.expr);
        const local = new Inst(
            Ty.create("PtrMut", { ty }),
            { tag: "Alloca" },
        );
        this.allocs.push(local);
        this.pushInst(Ty.Void, "Store", {
            target: local,
            source: expr,
        });
        this.localMap.set(stmt.kind.param.id, local);
    }

    /** Lowers an assignment: the value is evaluated before the place. */
    private lowerAssignStmt(stmt: ast.NodeWithKind<"AssignStmt">) {
        const source = this.lowerExpr(stmt.kind.expr);
        const target = this.lowerPlace(stmt.kind.place);
        this.pushInst(Ty.Void, "Store", { target, source });
    }

    /**
     * Lowers a place expression to an instruction yielding the most direct
     * pointer to it.
     *
     * @throws Error for unsupported place kinds or index argument types.
     */
    private lowerPlace(place: ast.Node): Inst {
        // evaluate to most direct pointer
        const _ty = this.tys.place(place);

        if (place.is("IdentExpr")) {
            const sym = this.syms.get(place);
            if (sym.tag === "Let") {
                return this.getLocal(sym.param.id);
            }
            if (sym.tag === "FnParam") {
                return this.lowerExpr(place);
            }
            throw new Error(`'${sym.tag}' not handled`);
        }

        if (place.is("UnaryExpr") && place.kind.op === "Deref") {
            // The operand of a deref already evaluates to the pointer.
            return this.lowerExpr(place.kind.expr);
        }

        if (place.is("IndexExpr")) {
            const value = place.kind.value;
            const valueTy = this.tys.place(value);
            const arg = place.kind.arg;
            const argTy = this.tys.expr(arg);
            if (valueTy.is("Array") || valueTy.is("Slice")) {
                const valueInst = this.lowerPlace(place.kind.value);
                if (argTy.is("Int")) {
                    const argInst = this.lowerExpr(arg);
                    return this.pushInst(
                        Ty.create("PtrMut", { ty: valueTy.kind.ty }),
                        "GetElemPtr",
                        { base: valueInst, offset: argInst },
                    );
                }
                if (argTy.is("Range")) {
                    if (!arg.is("RangeExpr")) {
                        throw new Error("not supported yet");
                    }
                    const begin = arg.kind.begin &&
                        this.lowerExpr(arg.kind.begin);
                    const end = arg.kind.end &&
                        this.lowerExpr(arg.kind.end);
                    return this.pushInst(
                        Ty.create("PtrMut", {
                            ty: Ty.create("Slice", { ty: valueTy.kind.ty }),
                        }),
                        "Slice",
                        { value: valueInst, begin, end },
                    );
                }
            }
            throw new Error(
                `${place.kind.tag} with arg ${argTy.pretty()} not handled`,
            );
        }

        throw new Error(`'${place.kind.tag}' not handled`);
    }

    /**
     * Lowers a value expression and returns the instruction producing it.
     *
     * @throws Error for unsupported expression kinds, symbols, builtins or
     * binary operator/type combinations.
     */
    private lowerExpr(expr: ast.Node): Inst {
        const ty = this.tys.expr(expr);

        if (expr.is("IdentExpr")) {
            const sym = this.syms.get(expr);
            if (sym.tag === "Fn") {
                const fn = this.lowerer.lowerFn(sym.stmt);
                return this.pushInst(fn.ty, "Fn", { fn });
            }
            if (sym.tag === "FnParam") {
                const paramTy = this.tys.expr(sym.param);
                return this.pushInst(paramTy, "Param", { idx: sym.idx });
            }
            if (sym.tag === "Builtin") {
                // Builtins are only valid as direct callees (see CallExpr).
                throw new Error("handle elsewhere");
            }
            if (sym.tag === "Let") {
                const source = this.lowerPlace(expr);
                return this.pushInst(ty, "Load", { source });
            }
            if (sym.tag === "Bool") {
                return this.pushInst(Ty.Bool, "Bool", { value: sym.value });
            }
            throw new Error(`'${sym.tag}' not handled`);
        }

        if (expr.is("IntExpr")) {
            return this.pushInst(ty, "Int", {
                value: expr.kind.value,
                intTy: expr.kind.intTy,
            });
        }

        if (expr.is("StrExpr")) {
            return this.pushInst(ty, "Str", { value: expr.kind.value });
        }

        if (expr.is("ArrayExpr")) {
            const values = expr.kind.values
                .map((value) => this.lowerExpr(value));
            return this.pushInst(ty, "Array", { values });
        }

        if (expr.is("IndexExpr")) {
            const source = this.lowerPlace(expr);
            return this.pushInst(ty, "Load", { source });
        }

        if (expr.is("CallExpr")) {
            // Arguments are lowered before the callee.
            const args = expr.kind.args
                .map((arg) => this.lowerExpr(arg));
            if (expr.kind.value.is("IdentExpr")) {
                const sym = this.syms.get(expr.kind.value);
                if (sym.tag === "Builtin") {
                    if (sym.id === "len") {
                        return this.pushInst(ty, "Len", { source: args[0] });
                    }
                    if (sym.id === "print") {
                        return this.pushInst(ty, "DebugPrint", { args });
                    }
                    throw new Error(`builtin '${sym.id}' not handled`);
                }
            }
            const callee = this.lowerExpr(expr.kind.value);
            return this.pushInst(ty, "Call", { callee, args });
        }

        if (expr.is("UnaryExpr")) {
            return this.lowerUnaryExpr(expr);
        }

        if (expr.is("BinaryExpr")) {
            const leftTy = this.tys.expr(expr.kind.left);
            const rightTy = this.tys.expr(expr.kind.right);
            const binaryOp = selectBinaryOp(
                expr.kind.op,
                leftTy,
                rightTy,
                ty,
            );
            if (!binaryOp) {
                throw new Error(
                    `'${expr.kind.op}' with '${ty.pretty()}' not handled`,
                );
            }
            const left = this.lowerExpr(expr.kind.left);
            const right = this.lowerExpr(expr.kind.right);
            return this.pushInst(ty, binaryOp, { left, right });
        }

        throw new Error(`'${expr.kind.tag}' not handled`);
    }

    /**
     * Lowers `Neg`, `Not`, `Ref`/`RefMut` and `Deref` unary expressions.
     *
     * @throws Error for unsupported operator/type or operand combinations.
     */
    private lowerUnaryExpr(expr: ast.NodeWithKind<"UnaryExpr">): Inst {
        const resultTy = this.tys.expr(expr);
        const operandTy = this.tys.expr(expr.kind.expr);

        if (
            expr.kind.op === "Neg" &&
            operandTy.resolvableWith(Ty.I32) &&
            resultTy.resolvableWith(Ty.I32)
        ) {
            const operand = this.lowerExpr(expr.kind.expr);
            return this.pushInst(Ty.I32, "Negate", { source: operand });
        }

        if (
            expr.kind.op === "Not" &&
            operandTy.resolvableWith(Ty.Bool) &&
            resultTy.resolvableWith(Ty.Bool)
        ) {
            const operand = this.lowerExpr(expr.kind.expr);
            return this.pushInst(Ty.Bool, "Not", { source: operand });
        }

        if (expr.kind.op === "Ref" || expr.kind.op === "RefMut") {
            const place = expr.kind.expr;
            if (place.is("IdentExpr")) {
                const sym = this.syms.get(place);
                if (sym.tag === "Let") {
                    // A reference to a local is the local's alloca itself.
                    return this.getLocal(sym.param.id);
                }
                throw new Error(
                    `${expr.kind.op} with sym ${sym.tag} not handled`,
                );
            }
            if (place.is("IndexExpr")) {
                const placeTy = this.tys.expr(place);
                const placeInst = this.lowerPlace(place);
                if (placeTy.is("Slice")) {
                    return placeInst;
                }
                // NOTE(review): for non-slice elements this loads the element
                // value instead of yielding its pointer — confirm intended.
                return this.pushInst(placeTy, "Load", {
                    source: placeInst,
                });
            }
            throw new Error(
                `${expr.kind.op} with place ${place.kind.tag} not handled`,
            );
        }

        if (expr.kind.op === "Deref") {
            const source = this.lowerExpr(expr.kind.expr);
            return this.pushInst(resultTy, "Load", { source });
        }

        throw new Error(
            `'${expr.kind.op}' with '${resultTy.pretty()}' not handled`,
        );
    }

    /** Emits a `Void` value instruction into the current block. */
    private makeVoid(): Inst {
        return this.pushInst(Ty.Void, "Void", {});
    }

    /**
     * Looks up the alloca registered for a `let` parameter id.
     *
     * @throws Error if no local has been registered under that id.
     */
    private getLocal(paramId: ast.Node["id"]): Inst {
        const local = this.localMap.get(paramId);
        if (!local) {
            throw new Error("local used before its 'let' was lowered");
        }
        return local;
    }

    /** Returns the block instructions are currently appended to. */
    private currentBlock(): BasicBlock {
        const block = this.bbs.at(-1);
        if (!block) {
            throw new Error("no current basic block");
        }
        return block;
    }

    /** Appends a new empty block, making it the current one, and returns it. */
    private pushBlock(): BasicBlock {
        const block = new BasicBlock();
        this.bbs.push(block);
        return block;
    }

    /**
     * Creates an instruction of kind `tag` with payload `kind`, appends it to
     * the current block and returns it.
     */
    private pushInst<Tag extends InstKind["tag"]>(
        ty: Ty,
        tag: Tag,
        kind: Omit<InstKind & { tag: Tag }, "tag">,
    ): Inst {
        const inst = new Inst(ty, { tag, ...kind } as InstKind);
        this.currentBlock().insts.push(inst);
        return inst;
    }
}

/**
 * Decides whether an AST binary operator applied to the given operand and
 * result types maps to a MIR binary operator; returns `null` if it does not.
 */
type BinaryOpTest = (
    op: ast.BinaryOp,
    left: Ty,
    right: Ty,
    result: Ty,
) => BinaryOp | null;

/** Operator rules, tried in order; the first non-null result wins. */
const binaryOpTests: BinaryOpTest[] = [
    // Integer arithmetic: operands resolve with each other and the result
    // resolves with the left operand.
    (op, left, right, result) => {
        const ops: ast.BinaryOp[] = ["Add", "Sub", "Mul", "Div", "Rem"];
        if (
            ops.includes(op) &&
            left.is("Int") &&
            left.resolvableWith(right) &&
            result.resolvableWith(left)
        ) {
            return op as BinaryOp;
        }
        return null;
    },
    // Integer comparison: operands resolve with each other, result is Bool.
    (op, left, right, result) => {
        const ops = ["Eq", "Ne", "Lt", "Gt", "Lte", "Gte"];
        if (
            ops.includes(op) &&
            left.is("Int") &&
            left.resolvableWith(right) &&
            result.is("Bool")
        ) {
            return op as BinaryOp;
        }
        return null;
    },
];

/** Returns the MIR operator chosen by the first matching rule, or `null`. */
function selectBinaryOp(
    op: ast.BinaryOp,
    left: Ty,
    right: Ty,
    result: Ty,
): BinaryOp | null {
    for (const test of binaryOpTests) {
        const matched = test(op, left, right, result);
        if (matched) {
            return matched;
        }
    }
    return null;
}