451 lines
15 KiB
TypeScript
451 lines
15 KiB
TypeScript
import * as ast from "./ast.ts";
|
|
import { Syms, Tys } from "./front/mod.ts";
|
|
import { Ty } from "./ty.ts";
|
|
import { BasicBlock, BinaryOp, Fn, Inst, InstKind } from "./mir.ts";
|
|
|
|
/**
 * Entry point for lowering AST function statements into MIR.
 *
 * Each `FnStmt` is lowered at most once; later requests for the same
 * statement (keyed by `stmt.id`) return the cached `Fn`.
 */
export class MiddleLowerer {
  /** Finished MIR functions, keyed by the `id` of their `FnStmt`. */
  private fns = new Map<number, Fn>();
  /** Ids of functions whose lowering has started but not yet finished. */
  private inProgress = new Set<number>();

  constructor(
    private syms: Syms,
    private tys: Tys,
  ) {}

  /**
   * Lowers `stmt` to MIR, or returns the previously lowered result.
   *
   * Lowering a body also lowers every function it references by name
   * (`FnLowerer` calls back into this method). The `Fn` object is only
   * created once the whole body has been lowered. A function that refers to
   * itself, directly or through other functions, would therefore re-enter
   * here with nothing cached and recurse until the stack overflows. That
   * case is reported as an explicit error instead.
   *
   * @throws Error if `stmt` is reached again while its own body is being lowered.
   */
  lowerFn(stmt: ast.FnStmt): Fn {
    const cached = this.fns.get(stmt.id);
    if (cached) {
      return cached;
    }
    if (this.inProgress.has(stmt.id)) {
      // TODO: support recursion by registering the Fn before lowering its body.
      throw new Error(
        `recursive function (stmt id ${stmt.id}) not supported yet`,
      );
    }
    this.inProgress.add(stmt.id);
    try {
      const fn = new FnLowerer(this, this.syms, this.tys, stmt).lower();
      this.fns.set(stmt.id, fn);
      return fn;
    } finally {
      this.inProgress.delete(stmt.id);
    }
  }

  /** Returns every function lowered so far, in the order lowering finished. */
  allFns(): Fn[] {
    return Array.from(this.fns.values());
  }
}
|
|
|
|
/**
 * Lowers the body of a single `FnStmt` into a list of MIR basic blocks.
 *
 * Instructions are always appended to the last block in `bbs`. Control-flow
 * statements open new blocks and wire them together with `Jump`/`Branch`
 * terminators once every block involved exists. The `Alloca` for each `let`
 * binding is collected in `allocs` and hoisted to the start of the entry
 * block when lowering finishes.
 */
class FnLowerer {
  /** `Alloca`s for let bindings; moved into `bbs[0]` by `lower()`. */
  private allocs: Inst[] = [];
  /** Basic blocks in emission order; `bbs[0]` is the entry block. */
  private bbs: BasicBlock[] = [new BasicBlock()];
  /** Let-binding param id -> the `Alloca` that stores the binding. */
  private localMap = new Map<number, Inst>();
  /** While-statement id -> block that control continues at after the loop. */
  private loopEndMap = new Map<number, BasicBlock>();

  constructor(
    private lowerer: MiddleLowerer,
    private syms: Syms,
    private tys: Tys,
    private stmt: ast.FnStmt,
  ) {}

  /**
   * Lowers the whole function body and builds the MIR `Fn`.
   *
   * A `Return` of `Void` is appended after the body, so control that reaches
   * the end of the body returns.
   */
  lower(): Fn {
    const ty = this.tys.fnStmt(this.stmt);
    this.lowerBlock(this.stmt.kind.body.as("Block"));
    this.pushInst(Ty.Void, "Return", { source: this.makeVoid() });
    // Place every stack slot at the very top of the entry block.
    this.bbs[0].insts.unshift(...this.allocs);
    return new Fn(this.stmt, ty, this.bbs);
  }

  /** Lowers each statement of `block` in order into the current block chain. */
  private lowerBlock(block: ast.Block) {
    for (const stmt of block.kind.stmts) {
      this.lowerStmt(stmt);
    }
  }

  /** Dispatches a statement node to its lowering routine. */
  private lowerStmt(stmt: ast.Node) {
    if (stmt.is("LetStmt")) {
      return this.lowerLetStmt(stmt);
    }
    if (stmt.is("ReturnStmt")) {
      return this.lowerReturnStmt(stmt);
    }
    if (stmt.is("IfStmt")) {
      return this.lowerIfStmt(stmt);
    }
    if (stmt.is("WhileStmt")) {
      return this.lowerWhileStmt(stmt);
    }
    if (stmt.is("BreakStmt")) {
      return this.lowerBreakStmt(stmt);
    }
    if (stmt.is("AssignStmt")) {
      return this.lowerAssignStmt(stmt);
    }
    if (stmt.is("ExprStmt")) {
      return this.lowerExpr(stmt.kind.expr);
    }
    throw new Error(`'${stmt.kind.tag}' not handled`);
  }

  /**
   * Emits a `Return` of the optional expression, or of `Void` if there is none.
   * Any code after the return goes into a fresh block, so the `Return` stays
   * the last instruction of its block.
   */
  private lowerReturnStmt(stmt: ast.NodeWithKind<"ReturnStmt">) {
    const source = stmt.kind.expr
      ? this.lowerExpr(stmt.kind.expr)
      : this.makeVoid();
    this.pushInst(Ty.Void, "Return", { source });
    this.bbs.push(new BasicBlock());
  }

  /**
   * Lowers `if`/`else`. The condition block branches to the truthy block, or
   * to the falsy block (to `done` when there is no `else`). Both arms end by
   * jumping to `done`.
   */
  private lowerIfStmt(stmt: ast.NodeWithKind<"IfStmt">) {
    const cond = this.lowerExpr(stmt.kind.cond);
    const condBlock = this.bbs.at(-1)!;

    this.bbs.push(new BasicBlock());
    const truthy = this.bbs.at(-1)!;
    this.lowerBlock(stmt.kind.truthy.as("Block"));
    const truthyEnd = this.bbs.at(-1)!;

    let falsy: BasicBlock | null = null;
    let falsyEnd: BasicBlock | null = null;
    if (stmt.kind.falsy) {
      this.bbs.push(new BasicBlock());
      falsy = this.bbs.at(-1)!;
      this.lowerBlock(stmt.kind.falsy.as("Block"));
      falsyEnd = this.bbs.at(-1)!;
    }

    this.bbs.push(new BasicBlock());
    const done = this.bbs.at(-1)!;

    condBlock.insts.push(
      new Inst(Ty.Void, {
        tag: "Branch",
        cond,
        truthy,
        falsy: falsy ?? done,
      }),
    );
    // The then-branch must skip over the else-branch: jump straight to `done`.
    truthyEnd.insts.push(
      new Inst(Ty.Void, { tag: "Jump", target: done }),
    );
    falsyEnd?.insts.push(
      new Inst(Ty.Void, { tag: "Jump", target: done }),
    );
  }

  /**
   * Lowers `while`.
   *
   * Block layout: the current block jumps to `condBlock`. The body is emitted
   * first, then `condBlock`, which branches back into the body or on to
   * `after`. `break` inside the body targets `after`, which it finds through
   * `loopEndMap`.
   */
  private lowerWhileStmt(stmt: ast.NodeWithKind<"WhileStmt">) {
    const before = this.bbs.at(-1)!;

    this.bbs.push(new BasicBlock());
    const body = this.bbs.at(-1)!;

    // Created up front so `break` statements in the body can target it.
    const after = new BasicBlock();
    this.loopEndMap.set(stmt.id, after);
    this.lowerBlock(stmt.kind.body.as("Block"));
    const bodyEnd = this.bbs.at(-1)!;

    this.bbs.push(new BasicBlock());
    const condBlock = this.bbs.at(-1)!;
    const cond = this.lowerExpr(stmt.kind.cond);
    const condBlockEnd = this.bbs.at(-1)!;

    this.bbs.push(after);

    before.insts.push(
      new Inst(Ty.Void, { tag: "Jump", target: condBlock }),
    );
    condBlockEnd.insts.push(
      new Inst(Ty.Void, {
        tag: "Branch",
        cond: cond,
        truthy: body,
        falsy: after,
      }),
    );
    bodyEnd.insts.push(
      new Inst(Ty.Void, { tag: "Jump", target: condBlock }),
    );
  }

  /**
   * Jumps to the exit block of the loop that this `break` resolves to.
   * Code after the break goes into a fresh block.
   */
  private lowerBreakStmt(stmt: ast.NodeWithKind<"BreakStmt">) {
    const sym = this.syms.get(stmt);
    if (sym.tag !== "Loop") {
      throw new Error(`break resolved to '${sym.tag}', expected 'Loop'`);
    }
    const loopEnd = this.loopEndMap.get(sym.stmt.id);
    if (!loopEnd) {
      throw new Error(`break target loop (stmt id ${sym.stmt.id}) not lowered`);
    }
    this.pushInst(Ty.Void, "Jump", { target: loopEnd });
    this.bbs.push(new BasicBlock());
  }

  /**
   * Lowers `let`. The initializer is evaluated, a stack slot is allocated
   * (hoisted later), and the value is stored into it. The slot is recorded
   * under the binding's param id.
   */
  private lowerLetStmt(stmt: ast.NodeWithKind<"LetStmt">) {
    const ty = this.tys.param(stmt.kind.param.as("Param"));
    const expr = this.lowerExpr(stmt.kind.expr);
    const local = new Inst(
      Ty.create("PtrMut", { ty }),
      { tag: "Alloca" },
    );
    this.allocs.push(local);
    this.pushInst(Ty.Void, "Store", {
      target: local,
      source: expr,
    });
    this.localMap.set(stmt.kind.param.id, local);
  }

  /** Evaluates the right-hand side, then the place, and stores into it. */
  private lowerAssignStmt(stmt: ast.NodeWithKind<"AssignStmt">) {
    const source = this.lowerExpr(stmt.kind.expr);
    const target = this.lowerPlace(stmt.kind.place);
    this.pushInst(Ty.Void, "Store", { target, source });
  }

  /**
   * Returns the stack slot recorded for the let binding whose param has id
   * `paramId`.
   *
   * @throws Error if no slot was recorded for that binding.
   */
  private lookupLocal(paramId: number): Inst {
    const local = this.localMap.get(paramId);
    if (!local) {
      throw new Error(`no local for let binding (param id ${paramId})`);
    }
    return local;
  }

  /**
   * Lowers a place expression to the most direct pointer to its storage.
   *
   * Handles let locals, function params, `*ptr`, and indexing into arrays
   * or slices by an int (element pointer) or by a range literal (slice).
   */
  private lowerPlace(place: ast.Node): Inst {
    // NOTE(review): the result was never used; the call is kept in case
    // `Tys.place` validates or caches. Confirm, then drop.
    this.tys.place(place);

    if (place.is("IdentExpr")) {
      const sym = this.syms.get(place);
      if (sym.tag === "Let") {
        return this.lookupLocal(sym.param.id);
      }
      if (sym.tag === "FnParam") {
        // The param's value itself is used as the place. Presumably only
        // meaningful for pointer-typed params (TODO confirm).
        return this.lowerExpr(place);
      }
      throw new Error(`'${sym.tag}' not handled`);
    }

    if (place.is("UnaryExpr") && place.kind.op === "Deref") {
      // `*p` as a place is just the pointer value `p`.
      return this.lowerExpr(place.kind.expr);
    }

    if (place.is("IndexExpr")) {
      const value = place.kind.value;
      const valueTy = this.tys.place(value);
      const arg = place.kind.arg;
      const argTy = this.tys.expr(arg);
      if (valueTy.is("Array") || valueTy.is("Slice")) {
        const valueInst = this.lowerPlace(place.kind.value);
        if (argTy.is("Int")) {
          const argInst = this.lowerExpr(arg);
          return this.pushInst(
            Ty.create("PtrMut", { ty: valueTy.kind.ty }),
            "GetElemPtr",
            { base: valueInst, offset: argInst },
          );
        }
        if (argTy.is("Range")) {
          if (!arg.is("RangeExpr")) {
            throw new Error("not supported yet");
          }
          // Omitted bounds stay as-is (falsy) in the Slice instruction.
          const begin = arg.kind.begin &&
            this.lowerExpr(arg.kind.begin);
          const end = arg.kind.end &&
            this.lowerExpr(arg.kind.end);
          return this.pushInst(
            Ty.create("PtrMut", {
              ty: Ty.create("Slice", { ty: valueTy.kind.ty }),
            }),
            "Slice",
            { value: valueInst, begin, end },
          );
        }
      }
      throw new Error(
        `${place.kind.tag} with arg ${argTy.pretty()} not handled`,
      );
    }

    throw new Error(`'${place.kind.tag}' not handled`);
  }

  /** Lowers an expression to the instruction producing its value. */
  private lowerExpr(expr: ast.Node): Inst {
    const ty = this.tys.expr(expr);
    if (expr.is("IdentExpr")) {
      return this.lowerIdentExpr(expr, ty);
    }
    if (expr.is("IntExpr")) {
      return this.pushInst(ty, "Int", {
        value: expr.kind.value,
        intTy: expr.kind.intTy,
      });
    }
    if (expr.is("StrExpr")) {
      return this.pushInst(ty, "Str", { value: expr.kind.value });
    }
    if (expr.is("ArrayExpr")) {
      const values = expr.kind.values
        .map((value) => this.lowerExpr(value));
      return this.pushInst(ty, "Array", { values });
    }
    if (expr.is("IndexExpr")) {
      const source = this.lowerPlace(expr);
      return this.pushInst(ty, "Load", { source });
    }
    if (expr.is("CallExpr")) {
      return this.lowerCallExpr(expr, ty);
    }
    if (expr.is("UnaryExpr")) {
      return this.lowerUnaryExpr(expr);
    }
    if (expr.is("BinaryExpr")) {
      return this.lowerBinaryExpr(expr, ty);
    }
    throw new Error(`'${expr.kind.tag}' not handled`);
  }

  /** Lowers a name used as a value, based on the symbol it resolves to. */
  private lowerIdentExpr(expr: ast.NodeWithKind<"IdentExpr">, ty: Ty): Inst {
    const sym = this.syms.get(expr);
    if (sym.tag === "Fn") {
      const fn = this.lowerer.lowerFn(sym.stmt);
      return this.pushInst(fn.ty, "Fn", { fn });
    }
    if (sym.tag === "FnParam") {
      const paramTy = this.tys.expr(sym.param);
      return this.pushInst(paramTy, "Param", { idx: sym.idx });
    }
    if (sym.tag === "Builtin") {
      // Builtin callees are intercepted by `lowerCallExpr`.
      throw new Error("handle elsewhere");
    }
    if (sym.tag === "Let") {
      const source = this.lowerPlace(expr);
      return this.pushInst(ty, "Load", { source });
    }
    if (sym.tag === "Bool") {
      return this.pushInst(Ty.Bool, "Bool", { value: sym.value });
    }
    throw new Error(`'${sym.tag}' not handled`);
  }

  /**
   * Lowers a call. Arguments are lowered before the callee. A callee that
   * names a builtin (`len`, `print`) becomes a dedicated instruction instead
   * of a `Call`.
   */
  private lowerCallExpr(expr: ast.NodeWithKind<"CallExpr">, ty: Ty): Inst {
    const args = expr.kind.args
      .map((arg) => this.lowerExpr(arg));

    if (expr.kind.value.is("IdentExpr")) {
      const sym = this.syms.get(expr.kind.value);
      if (sym.tag === "Builtin") {
        if (sym.id === "len") {
          return this.pushInst(ty, "Len", { source: args[0] });
        }
        if (sym.id === "print") {
          return this.pushInst(ty, "DebugPrint", { args });
        }
        throw new Error(`builtin '${sym.id}' not handled`);
      }
    }

    const callee = this.lowerExpr(expr.kind.value);
    return this.pushInst(ty, "Call", { callee, args });
  }

  /**
   * Lowers a binary expression. The MIR op is taken from the first entry of
   * `binaryOpTests` that accepts the operand and result types. That check
   * runs before either operand is lowered.
   */
  private lowerBinaryExpr(expr: ast.NodeWithKind<"BinaryExpr">, ty: Ty): Inst {
    const leftTy = this.tys.expr(expr.kind.left);
    const rightTy = this.tys.expr(expr.kind.right);
    const binaryOp = binaryOpTests
      .map((test) => test(expr.kind.op, leftTy, rightTy, ty))
      .filter((tested) => tested)
      .at(0);
    if (!binaryOp) {
      throw new Error(
        `'${expr.kind.op}' with '${ty.pretty()}' not handled`,
      );
    }
    const left = this.lowerExpr(expr.kind.left);
    const right = this.lowerExpr(expr.kind.right);
    return this.pushInst(ty, binaryOp, { left, right });
  }

  /**
   * Lowers unary operators:
   * - `Neg`: on `i32` only.
   * - `Not`: on `bool`.
   * - `Ref`/`RefMut`: on let locals and index expressions.
   * - `Deref`: loads through the pointer.
   */
  private lowerUnaryExpr(expr: ast.NodeWithKind<"UnaryExpr">) {
    const resultTy = this.tys.expr(expr);
    const operandTy = this.tys.expr(expr.kind.expr);
    if (
      expr.kind.op === "Neg" &&
      operandTy.resolvableWith(Ty.I32) &&
      resultTy.resolvableWith(Ty.I32)
    ) {
      const operand = this.lowerExpr(expr.kind.expr);
      return this.pushInst(Ty.I32, "Negate", { source: operand });
    }
    if (
      expr.kind.op === "Not" &&
      operandTy.resolvableWith(Ty.Bool) &&
      resultTy.resolvableWith(Ty.Bool)
    ) {
      const operand = this.lowerExpr(expr.kind.expr);
      return this.pushInst(Ty.Bool, "Not", { source: operand });
    }
    if (expr.kind.op === "Ref" || expr.kind.op === "RefMut") {
      const place = expr.kind.expr;
      if (place.is("IdentExpr")) {
        const sym = this.syms.get(place);
        if (sym.tag === "Let") {
          // A reference to a local is its stack slot.
          return this.lookupLocal(sym.param.id);
        }
        throw new Error(
          `${expr.kind.op} with sym ${sym.tag} not handled`,
        );
      }
      if (place.is("IndexExpr")) {
        const placeTy = this.tys.expr(place);
        const placeInst = this.lowerPlace(place);
        if (placeTy.is("Slice")) {
          return placeInst;
        }
        // NOTE(review): for a non-slice element this loads the element
        // value instead of returning its address (the Let case above returns
        // the pointer). Looks suspicious for `&a[i]`; confirm intent.
        return this.pushInst(placeTy, "Load", {
          source: placeInst,
        });
      }
      throw new Error(
        `${expr.kind.op} with place ${place.kind.tag} not handled`,
      );
    }
    if (expr.kind.op === "Deref") {
      const source = this.lowerExpr(expr.kind.expr);
      return this.pushInst(resultTy, "Load", { source });
    }
    throw new Error(
      `'${expr.kind.op}' with '${resultTy.pretty()}' not handled`,
    );
  }

  /** Emits a `Void` value instruction into the current block. */
  private makeVoid(): Inst {
    return this.pushInst(Ty.Void, "Void", {});
  }

  /**
   * Creates an instruction of type `ty` with the given tag and fields,
   * appends it to the current (last) block, and returns it.
   */
  private pushInst<
    Tag extends InstKind["tag"],
  >(
    ty: Ty,
    tag: Tag,
    kind: Omit<InstKind & { tag: Tag }, "tag">,
  ): Inst {
    const inst = new Inst(ty, { tag, ...kind } as InstKind);
    this.bbs.at(-1)!.insts.push(inst);
    return inst;
  }
}
|
|
|
|
/**
 * A single rule mapping an AST binary operator to a MIR `BinaryOp`.
 *
 * Receives the operator together with the operand and result types. Yields
 * the MIR operator when the rule applies to that combination, or `null` to
 * let the next rule decide.
 */
type BinaryOpTest = (
  operator: ast.BinaryOp,
  lhsTy: Ty,
  rhsTy: Ty,
  resultTy: Ty,
) => BinaryOp | null;
|
|
|
|
/** AST operators lowered to same-typed integer arithmetic. */
const arithmeticOps: readonly ast.BinaryOp[] = [
  "Add",
  "Sub",
  "Mul",
  "Div",
  "Rem",
];

/**
 * AST operators lowered to integer comparisons yielding `Bool`.
 * Typed like `arithmeticOps` so that a misspelled operator is a compile error.
 */
const comparisonOps: readonly ast.BinaryOp[] = [
  "Eq",
  "Ne",
  "Lt",
  "Gt",
  "Lte",
  "Gte",
];

/**
 * Ordered rules for lowering binary expressions. The lowering code picks the
 * first rule that returns a non-null op. The AST operator name is reused
 * as the MIR `BinaryOp`.
 */
const binaryOpTests: BinaryOpTest[] = [
  // Integer arithmetic: int operands of compatible types, result resolvable
  // with the left operand's type.
  (op, left, right, result) => {
    if (
      arithmeticOps.includes(op) &&
      left.is("Int") &&
      left.resolvableWith(right) &&
      result.resolvableWith(left)
    ) {
      return op as BinaryOp;
    }
    return null;
  },
  // Integer comparison: int operands of compatible types, `Bool` result.
  (op, left, right, result) => {
    if (
      comparisonOps.includes(op) &&
      left.is("Int") &&
      left.resolvableWith(right) &&
      result.is("Bool")
    ) {
      return op as BinaryOp;
    }
    return null;
  },
];
|