import * as mir from "./middle.ts";

/**
 * Interpreter for a single MIR function.
 *
 * Execution starts at the first basic block of `fn` and walks instructions
 * in order. `Jump` and `Branch` switch the current block, `Call` evaluates the
 * callee in a fresh `FnInterpreter`, and `Return` ends evaluation.
 */
export class FnInterpreter {
    /** Results of already-executed instructions, keyed by the instruction. */
    private regs = new Map<mir.Inst, Val>();
    /** Storage slots for locals; `null` marks a slot that was never stored to. */
    private locals: (Val | null)[] = [];
    /** Maps an `AllocLocal` instruction to its index in `locals`. */
    private localMap = new Map<mir.Inst, number>();
    /** Basic block currently being executed. */
    private bb: mir.BasicBlock;
    /** Index of the next instruction to execute in `bb`. */
    private instIdx = 0;

    /**
     * @param fn Function to interpret; evaluation begins at `fn.bbs[0]`.
     * @param args Argument values, looked up by `Param` instructions via their index.
     */
    constructor(
        private fn: mir.Fn,
        private args: Val[],
    ) {
        this.bb = this.fn.bbs[0];
    }

    /**
     * Runs the function to completion.
     *
     * @returns The value of the executed `Return` instruction, or `Val.Void`
     * when execution runs off the end of the current block.
     * @throws Error on an `Error` instruction, an operand of the wrong kind,
     * a read of a register or local that was never written, a missing
     * argument, division by zero, or an unknown instruction.
     */
    eval(): Val {
        while (this.instIdx < this.bb.insts.length) {
            const inst = this.bb.insts[this.instIdx];
            this.instIdx += 1;
            const k = inst.kind;
            switch (k.tag) {
                case "Error":
                    throw new Error("cannot evaluate 'Error' instruction");
                case "Void":
                case "Int":
                case "Bool":
                case "Fn":
                    // Literal instruction kinds double as value kinds.
                    this.regs.set(inst, new Val(k));
                    break;
                case "Param": {
                    const arg = this.args[k.idx];
                    if (arg === undefined) {
                        throw new Error(
                            `missing argument for parameter ${k.idx}`,
                        );
                    }
                    this.regs.set(inst, arg);
                    break;
                }
                case "Call": {
                    const callee = this.reg(k.callee);
                    if (callee.kind.tag !== "Fn") {
                        throw new Error(
                            `cannot call value of kind '${callee.kind.tag}'`,
                        );
                    }
                    const callArgs = k.args.map((arg) => this.reg(arg));
                    const result = new FnInterpreter(callee.kind.fn, callArgs)
                        .eval();
                    this.regs.set(inst, result);
                    break;
                }
                case "AllocLocal":
                    this.localMap.set(inst, this.locals.length);
                    this.locals.push(null);
                    break;
                case "LocalLoad": {
                    const stored = this.locals[this.localSlot(k.source)];
                    if (stored == null) {
                        throw new Error("load from uninitialized local");
                    }
                    this.regs.set(inst, stored);
                    break;
                }
                case "LocalStore":
                    this.locals[this.localSlot(k.target)] = this.reg(k.source);
                    break;
                case "Jump":
                    this.bb = k.target;
                    this.instIdx = 0;
                    break;
                case "Branch": {
                    const cond = this.reg(k.cond);
                    if (cond.kind.tag !== "Bool") {
                        throw new Error(
                            `branch condition must be 'Bool', got '${cond.kind.tag}'`,
                        );
                    }
                    this.bb = cond.kind.value ? k.truthy : k.falsy;
                    this.instIdx = 0;
                    break;
                }
                case "Return":
                    return this.reg(k.source);
                case "Not": {
                    const source = this.reg(k.source);
                    if (source.kind.tag !== "Bool") {
                        throw new Error(
                            `'Not' expects 'Bool', got '${source.kind.tag}'`,
                        );
                    }
                    this.regs.set(
                        inst,
                        new Val({ tag: "Bool", value: !source.kind.value }),
                    );
                    break;
                }
                case "Negate": {
                    const source = this.reg(k.source);
                    if (source.kind.tag !== "Int") {
                        throw new Error(
                            `'Negate' expects 'Int', got '${source.kind.tag}'`,
                        );
                    }
                    this.regs.set(
                        inst,
                        new Val({ tag: "Int", value: -source.kind.value }),
                    );
                    break;
                }
                case "Eq":
                case "Ne":
                case "Lt":
                case "Gt":
                case "Lte":
                case "Gte":
                case "BitOr":
                case "BitXor":
                case "BitAnd":
                case "Shl":
                case "Shr":
                case "Add":
                case "Sub":
                case "Mul":
                case "Div":
                case "Rem":
                    this.evalBinaryOp(inst, k);
                    break;
                case "DebugPrint":
                    console.log(
                        k.args
                            .map((a) => this.reg(a).pretty())
                            .join(", "),
                    );
                    break;
                default:
                    k satisfies never;
                    // The type-level check above is erased at runtime; fail
                    // loudly instead of silently skipping the instruction.
                    throw new Error(
                        `unhandled instruction '${(k as { tag: string }).tag}'`,
                    );
            }
        }
        return Val.Void;
    }

    /**
     * Evaluates a binary instruction and stores the result in `inst`'s register.
     *
     * Two `Int` operands support every operator; two `Bool` operands support
     * `Eq` and `Ne`. Any other combination throws.
     */
    private evalBinaryOp(
        inst: mir.Inst,
        k: mir.InstKind & { tag: mir.BinaryOp },
    ) {
        const left = this.reg(k.left);
        const right = this.reg(k.right);
        if (left.kind.tag === "Int" && right.kind.tag === "Int") {
            this.regs.set(
                inst,
                this.intBinaryOp(k.tag, left.kind.value, right.kind.value),
            );
            return;
        }
        if (left.kind.tag === "Bool" && right.kind.tag === "Bool") {
            const same = left.kind.value === right.kind.value;
            if (k.tag === "Eq" || k.tag === "Ne") {
                this.regs.set(
                    inst,
                    new Val({
                        tag: "Bool",
                        value: k.tag === "Eq" ? same : !same,
                    }),
                );
                return;
            }
        }
        throw new Error(
            `'${k.tag}' not handled for operands '${left.kind.tag}' and '${right.kind.tag}'`,
        );
    }

    /** Applies binary operator `op` to two integer operands. */
    private intBinaryOp(op: mir.BinaryOp, l: number, r: number): Val {
        const int = (value: number) => new Val({ tag: "Int", value });
        const bool = (value: boolean) => new Val({ tag: "Bool", value });
        switch (op) {
            case "Eq":
                return bool(l === r);
            case "Ne":
                return bool(l !== r);
            case "Lt":
                return bool(l < r);
            case "Gt":
                return bool(l > r);
            case "Lte":
                return bool(l <= r);
            case "Gte":
                return bool(l >= r);
            // JavaScript bitwise operators work on 32-bit signed integers.
            case "BitOr":
                return int(l | r);
            case "BitXor":
                return int(l ^ r);
            case "BitAnd":
                return int(l & r);
            case "Shl":
                return int(l << r);
            case "Shr":
                // NOTE(review): uses the sign-preserving shift; confirm that
                // is the intended meaning of 'Shr'.
                return int(l >> r);
            case "Add":
                return int(l + r);
            case "Sub":
                return int(l - r);
            case "Mul":
                return int(l * r);
            case "Div":
                if (r === 0) {
                    throw new Error("division by zero");
                }
                // NOTE(review): this rounds toward negative infinity while
                // 'Rem' below truncates toward zero, so the two disagree for
                // operands of opposite sign — confirm which is intended.
                return int(Math.floor(l / r));
            case "Rem":
                if (r === 0) {
                    throw new Error("division by zero");
                }
                return int(l % r);
        }
        throw new Error(`'${op}' not handled`);
    }

    /** Returns the value stored for `inst`, throwing if none was recorded. */
    private reg(inst: mir.Inst): Val {
        const value = this.regs.get(inst);
        if (value === undefined) {
            throw new Error("read of unset register");
        }
        return value;
    }

    /** Returns the `locals` index for an `AllocLocal` instruction, throwing if unknown. */
    private localSlot(alloc: mir.Inst): number {
        const slot = this.localMap.get(alloc);
        if (slot === undefined) {
            throw new Error("reference to unallocated local");
        }
        return slot;
    }
}

/** A runtime value produced by the interpreter. */
class Val {
    constructor(
        public kind: ValKind,
    ) {}

    /** Shared void value, returned when a function ends without `Return`. */
    static Void = new Val({ tag: "Void" });

    /** Renders the value as text, as used by `DebugPrint`. */
    pretty(): string {
        const k = this.kind;
        switch (k.tag) {
            case "Void":
                return "void";
            case "Int":
            case "Bool":
                return `${k.value}`;
            case "Fn":
                return `<${k.fn.ty.pretty()}>`;
            default:
                k satisfies never;
                throw new Error(
                    `unhandled value kind '${(k as { tag: string }).tag}'`,
                );
        }
    }
}

/** The kinds of value the interpreter can produce, discriminated by `tag`. */
type ValKind =
    | { tag: "Void" }
    | { tag: "Int"; value: number }
    | { tag: "Bool"; value: boolean }
    | { tag: "Fn"; fn: mir.Fn };