diff --git a/src/ast.ts b/src/ast.ts index 4bd09da..e023866 100644 --- a/src/ast.ts +++ b/src/ast.ts @@ -72,6 +72,8 @@ export class Node { return visit(k.expr); case "LetStmt": return visit(k.param, k.expr); + case "IfStmt": + return visit(k.cond, k.truthy, k.falsy); case "Param": return visit(k.ty); case "IdentExpr": @@ -85,7 +87,7 @@ export class Node { case "IdentTy": return visit(); } - const _: never = k; + k satisfies never; } } @@ -104,6 +106,7 @@ export type NodeKind = } | { tag: "ReturnStmt"; expr: Node | null } | { tag: "LetStmt"; param: Node; expr: Node } + | { tag: "IfStmt"; cond: Node; truthy: Node; falsy: Node | null } | { tag: "Param"; ident: string; ty: Node | null } | { tag: "IdentExpr"; ident: string } | { tag: "IntExpr"; value: number } diff --git a/src/front/parse.ts b/src/front/parse.ts index 6c2f275..95508e3 100644 --- a/src/front/parse.ts +++ b/src/front/parse.ts @@ -48,6 +48,8 @@ export class Parser { return this.parseReturnStmt(); } else if (this.test("let")) { return this.parseLetStmt(); + } else if (this.test("if")) { + return this.parseIfStmt(); } else { const place = this.parseExpr(); if (this.eat("=")) { @@ -105,6 +107,18 @@ export class Parser { return ast.Node.create(loc, "LetStmt", { param, expr }); } + parseIfStmt(): ast.Node { + const loc = this.loc(); + this.step(); + const cond = this.parseExpr(); + const truthy = this.parseBlock(); + let falsy: ast.Node | null = null; + if (this.eat("else")) { + falsy = this.parseBlock(); + } + return ast.Node.create(loc, "IfStmt", { cond, truthy, falsy }); + } + parseParam(): ast.Node { const loc = this.loc(); const ident = this.mustEat("ident").value; diff --git a/src/main.ts b/src/main.ts index 0c8399a..5d5c300 100644 --- a/src/main.ts +++ b/src/main.ts @@ -1,7 +1,7 @@ import * as ast from "./ast.ts"; import * as front from "./front/mod.ts"; import * as middle from "./middle.ts"; -import { MirInterpreter } from "./mir_interpreter.ts"; +import { FnInterpreter } from "./mir_interpreter.ts"; const filename = Deno.args[0]; const text = await Deno.readTextFile(filename); @@ -38,5 +38,4 @@ if (!Deno.args.includes("--test")) { console.log(mainMiddleFn.pretty()); } -const interp = new MirInterpreter(); -interp.eval(mainMiddleFn); +new FnInterpreter(mainMiddleFn, []).eval(); diff --git a/src/middle.ts b/src/middle.ts index 7c0a972..e3d791e 100644 --- a/src/middle.ts +++ b/src/middle.ts @@ -21,7 +21,7 @@ export class MiddleLowerer { } class FnLowerer { - private bbs: BasicBlock[] = [new BasicBlock([])]; + private bbs: BasicBlock[] = [new BasicBlock()]; private localMap = new Map(); constructor( @@ -63,6 +63,44 @@ class FnLowerer { this.pushInst(Ty.Void, "Return", { source }); return; } + if (stmt.is("IfStmt")) { + const cond = this.lowerExpr(stmt.kind.cond); + const condBlock = this.bbs.at(-1)!; + + this.bbs.push(new BasicBlock()); + const truthy = this.bbs.at(-1)!; + this.lowerBlock(stmt.kind.truthy.as("Block")); + const truthyEnd = this.bbs.at(-1)!; + + let falsy: BasicBlock | null = null; + let falsyEnd: BasicBlock | null = null; + + if (stmt.kind.falsy) { + this.bbs.push(new BasicBlock()); + falsy = this.bbs.at(-1)!; + this.lowerBlock(stmt.kind.falsy.as("Block")); + falsyEnd = this.bbs.at(-1)!; + } + + this.bbs.push(new BasicBlock()); + const done = this.bbs.at(-1)!; + + condBlock.insts.push( + new Inst(Ty.Void, { + tag: "Branch", + cond, + truthy, + falsy: falsy ?? done, + }), + ); + truthyEnd.insts.push( + new Inst(Ty.Void, { tag: "Jump", target: falsy ?? done }), + ); + falsyEnd?.insts.push( + new Inst(Ty.Void, { tag: "Jump", target: done }), + ); + return; + } if (stmt.is("AssignStmt")) { const source = this.lowerExpr(stmt.kind.expr); const target = this.lowerAssignPlace(stmt.kind.place); @@ -257,9 +295,7 @@ class PrettyCx { } export class BasicBlock { - constructor( - public insts: Inst[], - ) {} + public insts: Inst[] = []; pretty(cx: PrettyCx): string { return `bb${cx.bbId(this)}:\n${ @@ -280,7 +316,9 @@ export class Inst { pretty(cx: PrettyCx): string { const r = (v: Inst) => `_${cx.regId(v)}`; - return `${r(this)}: ${this.ty.pretty()} = ${this.kind.tag}${ + return `${`${r(this)}:`.padEnd(4, " ")} ${ + this.ty.pretty().padEnd(4, " ") + } = ${this.kind.tag}${ (() => { const k = this.kind; switch (k.tag) { @@ -302,7 +340,13 @@ export class Inst { case "LocalLoad": return ` ${r(k.source)}`; case "LocalStore": - return ` ${r(k.target)}, ${r(k.source)}`; + return ` ${r(k.target)} = ${r(k.source)}`; + case "Jump": + return ` bb${cx.bbId(k.target)}`; + case "Branch": + return ` ${r(k.cond)} ? bb${cx.bbId(k.truthy)} : bb${ + cx.bbId(k.falsy) + }`; case "Return": return ` ${r(k.source)}`; case "Eq": @@ -325,7 +369,7 @@ export class Inst { case "DebugPrint": return ` ${k.args.map(r).join(", ")}`; } - const _: never = k; + k satisfies never; })() }`; } @@ -342,6 +386,8 @@ export type InstKind = | { tag: "AllocLocal" } | { tag: "LocalLoad"; source: Inst } | { tag: "LocalStore"; target: Inst; source: Inst } + | { tag: "Jump"; target: BasicBlock } + | { tag: "Branch"; cond: Inst; truthy: BasicBlock; falsy: BasicBlock } | { tag: "Return"; source: Inst } | { tag: BinaryOp; left: Inst; right: Inst } | { tag: "DebugPrint"; args: Inst[] }; diff --git a/src/mir_interpreter.ts b/src/mir_interpreter.ts index 54c98ec..41f0da8 100644 --- a/src/mir_interpreter.ts +++ b/src/mir_interpreter.ts @@ -1,19 +1,24 @@ import * as mir from "./middle.ts"; -export class MirInterpreter { - constructor() {} +export class FnInterpreter { + private regs = new Map(); + private locals: (Val | null)[] = []; + private localMap = new Map(); + private bb: mir.BasicBlock; + private instIdx = 0; - eval(fn: mir.Fn) { - this.evalFn(fn, []); + constructor( + private fn: mir.Fn, + private args: Val[], + ) { + this.bb = this.fn.bbs[0]; } - private evalFn(fn: mir.Fn, args: Val[]): Val { - const regs = new Map(); - const locals: (Val | null)[] = []; - const localMap = new Map(); + eval(): Val { + while (this.instIdx < this.bb.insts.length) { + const inst = this.bb.insts[this.instIdx]; + this.instIdx += 1; - let bb = fn.bbs[0]; - for (const inst of bb.insts) { const k = inst.kind; switch (k.tag) { case "Error": @@ -22,42 +27,61 @@ export class MirInterpreter { case "Int": case "Bool": case "Fn": - regs.set(inst, new Val(k)); - continue; + this.regs.set(inst, new Val(k)); + break; case "Param": - regs.set(inst, args[k.idx]); - continue; + this.regs.set(inst, this.args[k.idx]); + break; case "Call": { - const fn = regs.get(k.callee); + const fn = this.regs.get(k.callee); if (!fn || fn.kind.tag !== "Fn") { throw new Error(); } - const args = k.args.map((arg) => regs.get(arg)!); - const val = this.evalFn(fn.kind.fn, args); - regs.set(inst, val); - continue; + const args = k.args.map((arg) => this.regs.get(arg)!); + const val = new FnInterpreter(fn.kind.fn, args).eval(); + this.regs.set(inst, val); + break; } case "AllocLocal": - localMap.set(inst, locals.length); - locals.push(null); - continue; + this.localMap.set(inst, this.locals.length); + this.locals.push(null); + break; case "LocalLoad": - if (!localMap.has(k.source)) { + if (!this.localMap.has(k.source)) { throw new Error(); } - if (locals[localMap.get(k.source)!] === null) { + if (this.locals[this.localMap.get(k.source)!] === null) { throw new Error(); } - regs.set(inst, locals[localMap.get(k.source)!]!); - continue; + this.regs.set( + inst, + this.locals[this.localMap.get(k.source)!]!, + ); + break; case "LocalStore": - if (!localMap.has(k.target)) { + if (!this.localMap.has(k.target)) { throw new Error(); } - locals[localMap.get(k.target)!] = regs.get(k.source)!; - continue; + this.locals[this.localMap.get(k.target)!] = this.regs.get( + k.source, + )!; + break; + case "Jump": { + this.bb = k.target; + this.instIdx = 0; + break; + } + case "Branch": { + const cond = this.regs.get(k.cond)!; + if (cond.kind.tag !== "Bool") { + throw new Error(); + } + this.bb = cond.kind.value ? k.truthy : k.falsy; + this.instIdx = 0; + break; + } case "Return": - return regs.get(k.source)!; + return this.regs.get(k.source)!; case "Eq": case "Ne": case "Lt": @@ -73,71 +97,77 @@ export class MirInterpreter { case "Sub": case "Mul": case "Div": - case "Rem": { - const left = regs.get(k.left)!; - const right = regs.get(k.right)!; - const lk = left.kind; - const rk = right.kind; - - if (lk.tag === "Int" && rk.tag === "Int") { - const left = lk.value; - const right = lk.value; - - const value = (() => { - const Int = (value: number) => - new Val({ tag: "Int", value }); - const Bool = (value: boolean) => - new Val({ tag: "Bool", value }); - - switch (k.tag) { - case "Eq": - return Bool(left === right); - case "Ne": - return Bool(left !== right); - case "Lt": - return Bool(left < right); - case "Gt": - return Bool(left > right); - case "Lte": - return Bool(left <= right); - case "Gte": - return Bool(left >= right); - case "BitOr": - case "BitXor": - case "BitAnd": - case "Shl": - case "Shr": - break; - case "Add": - return Int(left + right); - case "Sub": - return Int(left - right); - case "Mul": - return Int(left * right); - case "Div": - return Int(Math.floor(left / right)); - case "Rem": - return Int(left % right); - } - throw new Error(`'${k.tag}' not handled`); - })(); - regs.set(inst, value); - continue; - } - throw new Error(`'${k.tag}' not handled`); - } + case "Rem": + this.evalBinaryOp(inst, k); + break; case "DebugPrint": console.log( k.args - .map((a) => regs.get(a)!.pretty()) + .map((a) => this.regs.get(a)!.pretty()) .join(", "), ); - continue; + break; + default: + k satisfies never; } - k satisfies never; } return Val.Void; } + + private evalBinaryOp( + inst: mir.Inst, + k: mir.InstKind & { tag: mir.BinaryOp }, + ) { + const left = this.regs.get(k.left)!; + const right = this.regs.get(k.right)!; + + if (left.kind.tag === "Int" && right.kind.tag === "Int") { + const l = left.kind.value; + const r = right.kind.value; + + const value = (() => { + const Int = (value: number) => new Val({ tag: "Int", value }); + const Bool = (value: boolean) => + new Val({ tag: "Bool", value }); + + switch (k.tag) { + case "Eq": + return Bool(l === r); + case "Ne": + return Bool(l !== r); + case "Lt": + return Bool(l < r); + case "Gt": + return Bool(l > r); + case "Lte": + return Bool(l <= r); + case "Gte": + return Bool(l >= r); + case "BitOr": + case "BitXor": + case "BitAnd": + case "Shl": + case "Shr": + break; + case "Add": + return Int(l + r); + case "Sub": + return Int(l - r); + case "Mul": + return Int(l * r); + case "Div": + return Int(Math.floor(l / r)); + case "Rem": + return Int(l % r); + } + throw new Error(`'${k.tag}' not handled`); + })(); + + this.regs.set(inst, value); + return; + } + throw new Error(`'${k.tag}' not handled`); + } } class Val {