ethos/src/middle.ts
sfja a42917b485
All checks were successful
Check / Explore-Gitea-Actions (push) Successful in 8s
add debug_print
2026-03-17 18:55:54 +01:00

437 lines
14 KiB
TypeScript

import * as ast from "./ast.ts";
import { Syms, Tys } from "./front/mod.ts";
import { Ty } from "./ty.ts";
import { BasicBlock, BinaryOp, Fn, Inst, InstKind } from "./mir.ts";
/**
 * Lowers function statements to MIR `Fn`s and memoises the result per
 * statement id, so every function is lowered at most once.
 */
export class MiddleLowerer {
    /** Already-lowered functions, keyed by the id of their `FnStmt`. */
    private fns = new Map<number, Fn>();

    constructor(
        private syms: Syms,
        private tys: Tys,
    ) {}

    /**
     * Returns the lowered form of `stmt`, lowering it on first request and
     * serving the cached `Fn` afterwards.
     *
     * NOTE(review): the cache entry is only written after the body has been
     * lowered completely, so a function whose body refers to itself would
     * re-enter this method without hitting the cache — confirm recursion is
     * handled.
     */
    lowerFn(stmt: ast.FnStmt): Fn {
        const cached = this.fns.get(stmt.id);
        if (cached !== undefined) {
            return cached;
        }
        const lowered = new FnLowerer(this, this.syms, this.tys, stmt).lower();
        this.fns.set(stmt.id, lowered);
        return lowered;
    }

    /** Returns every function lowered so far, in the order they were cached. */
    allFns(): Fn[] {
        return Array.from(this.fns.values());
    }
}
/**
 * Lowers the body of one function statement into a list of MIR basic blocks.
 *
 * Instructions are appended to the last block in `bbs`. Control-flow
 * statements push fresh blocks and attach their `Branch`/`Jump` instructions
 * once the nested blocks have been lowered. `Alloca` instructions for `let`
 * bindings are collected on the side and prepended to the first block when
 * `lower()` finishes.
 */
class FnLowerer {
    /** `Alloca` instructions; moved to the front of the first block by `lower()`. */
    private allocs: Inst[] = [];
    /** Blocks in emission order; the last one is the current insertion point. */
    private bbs: BasicBlock[] = [new BasicBlock()];
    /** Maps the id of a `let` parameter node to the `Alloca` backing it. */
    private localMap = new Map<number, Inst>();
    /** Maps the id of a `while` statement to the block that follows the loop. */
    private loopEndMap = new Map<number, BasicBlock>();

    constructor(
        private lowerer: MiddleLowerer,
        private syms: Syms,
        private tys: Tys,
        private stmt: ast.FnStmt,
    ) {}

    /**
     * Lowers the whole function and returns the resulting `Fn`.
     *
     * A trailing `Return` of a `Void` value is always appended after the
     * body, and all collected `Alloca`s are hoisted to the first block.
     */
    lower(): Fn {
        const ty = this.tys.fnStmt(this.stmt);
        this.lowerBlock(this.stmt.kind.body.as("Block"));
        // `makeVoid()` runs first, so the `Void` value precedes the `Return`.
        this.pushInst(Ty.Void, "Return", { source: this.makeVoid() });
        this.bbs[0].insts.unshift(...this.allocs);
        return new Fn(this.stmt, ty, this.bbs);
    }

    /** Lowers every statement of `block` in source order. */
    private lowerBlock(block: ast.Block) {
        for (const stmt of block.kind.stmts) {
            this.lowerStmt(stmt);
        }
    }

    /**
     * Dispatches a statement node to its lowering routine.
     *
     * @throws Error if the node kind is not a supported statement.
     */
    private lowerStmt(stmt: ast.Node) {
        if (stmt.is("LetStmt")) {
            return this.lowerLetStmt(stmt);
        }
        if (stmt.is("ReturnStmt")) {
            return this.lowerReturnStmt(stmt);
        }
        if (stmt.is("IfStmt")) {
            return this.lowerIfStmt(stmt);
        }
        if (stmt.is("WhileStmt")) {
            return this.lowerWhileStmt(stmt);
        }
        if (stmt.is("BreakStmt")) {
            return this.lowerBreakStmt(stmt);
        }
        if (stmt.is("AssignStmt")) {
            return this.lowerAssignStmt(stmt);
        }
        if (stmt.is("ExprStmt")) {
            return this.lowerExpr(stmt.kind.expr);
        }
        throw new Error(`'${stmt.kind.tag}' not handled`);
    }

    /**
     * Emits a `Return` of the statement's expression (or of a `Void` value
     * when there is none) and opens a fresh block for whatever follows.
     */
    private lowerReturnStmt(stmt: ast.NodeWithKind<"ReturnStmt">) {
        const source = stmt.kind.expr
            ? this.lowerExpr(stmt.kind.expr)
            : this.makeVoid();
        this.pushInst(Ty.Void, "Return", { source });
        this.bbs.push(new BasicBlock());
    }

    /**
     * Lowers `if cond { truthy } [else { falsy }]`.
     *
     * Block layout: `cond → truthy… → [falsy… →] done`. The condition block
     * branches to `truthy` or to `falsy` (`done` when there is no else), and
     * the end of each arm jumps to `done`.
     */
    private lowerIfStmt(stmt: ast.NodeWithKind<"IfStmt">) {
        const cond = this.lowerExpr(stmt.kind.cond);
        const condBlock = this.bbs.at(-1)!;

        this.bbs.push(new BasicBlock());
        const truthy = this.bbs.at(-1)!;
        this.lowerBlock(stmt.kind.truthy.as("Block"));
        // Nested control flow may have pushed blocks; the arm ends at the last one.
        const truthyEnd = this.bbs.at(-1)!;

        let falsy: BasicBlock | null = null;
        let falsyEnd: BasicBlock | null = null;
        if (stmt.kind.falsy) {
            this.bbs.push(new BasicBlock());
            falsy = this.bbs.at(-1)!;
            this.lowerBlock(stmt.kind.falsy.as("Block"));
            falsyEnd = this.bbs.at(-1)!;
        }

        this.bbs.push(new BasicBlock());
        const done = this.bbs.at(-1)!;

        condBlock.insts.push(
            new Inst(Ty.Void, {
                tag: "Branch",
                cond,
                truthy,
                falsy: falsy ?? done,
            }),
        );
        // The then-arm must skip the else-arm; jumping to `falsy` here would
        // run both arms whenever the condition is true.
        truthyEnd.insts.push(
            new Inst(Ty.Void, { tag: "Jump", target: done }),
        );
        falsyEnd?.insts.push(
            new Inst(Ty.Void, { tag: "Jump", target: done }),
        );
    }

    /**
     * Lowers `while cond { body }`.
     *
     * Block layout: `before → body… → cond… → after`. `before` jumps straight
     * to the condition, the condition branches to `body` or `after`, and the
     * end of the body jumps back to the condition. `after` is registered in
     * `loopEndMap` before the body is lowered so `break` can target it.
     */
    private lowerWhileStmt(stmt: ast.NodeWithKind<"WhileStmt">) {
        const before = this.bbs.at(-1)!;

        this.bbs.push(new BasicBlock());
        const body = this.bbs.at(-1)!;
        // Created up front but only appended to `bbs` after the condition.
        const after = new BasicBlock();
        this.loopEndMap.set(stmt.id, after);
        this.lowerBlock(stmt.kind.body.as("Block"));
        const bodyEnd = this.bbs.at(-1)!;

        this.bbs.push(new BasicBlock());
        const condBlock = this.bbs.at(-1)!;
        const cond = this.lowerExpr(stmt.kind.cond);
        const condBlockEnd = this.bbs.at(-1)!;

        this.bbs.push(after);

        before.insts.push(
            new Inst(Ty.Void, { tag: "Jump", target: condBlock }),
        );
        condBlockEnd.insts.push(
            new Inst(Ty.Void, {
                tag: "Branch",
                cond: cond,
                truthy: body,
                falsy: after,
            }),
        );
        bodyEnd.insts.push(
            new Inst(Ty.Void, { tag: "Jump", target: condBlock }),
        );
    }

    /**
     * Emits a `Jump` to the end block of the loop the `break` resolves to and
     * opens a fresh block for whatever follows.
     *
     * @throws Error if the symbol is not a loop or the loop has no end block
     *     registered.
     */
    private lowerBreakStmt(stmt: ast.NodeWithKind<"BreakStmt">) {
        const sym = this.syms.get(stmt);
        if (sym.tag !== "Loop") {
            throw new Error(
                `break resolved to '${sym.tag}', expected 'Loop'`,
            );
        }
        const loopEnd = this.loopEndMap.get(sym.stmt.id);
        if (!loopEnd) {
            throw new Error("break targets a loop with no registered end block");
        }
        this.pushInst(Ty.Void, "Jump", { target: loopEnd });
        this.bbs.push(new BasicBlock());
    }

    /**
     * Lowers `let param = expr`: evaluates the initialiser, creates an
     * `Alloca` of type `PtrMut<ty>` (hoisted later), stores the value into it
     * and records the slot under the parameter's node id.
     */
    private lowerLetStmt(stmt: ast.NodeWithKind<"LetStmt">) {
        const ty = this.tys.expr(stmt.kind.param);
        const expr = this.lowerExpr(stmt.kind.expr);
        const local = new Inst(
            Ty.create("PtrMut", { ty }),
            { tag: "Alloca" },
        );
        this.allocs.push(local);
        this.pushInst(Ty.Void, "Store", {
            target: local,
            source: expr,
        });
        this.localMap.set(stmt.kind.param.id, local);
    }

    /** Lowers `place = expr`; the value is evaluated before the place. */
    private lowerAssignStmt(stmt: ast.NodeWithKind<"AssignStmt">) {
        const source = this.lowerExpr(stmt.kind.expr);
        const target = this.lowerPlace(stmt.kind.place);
        this.pushInst(Ty.Void, "Store", { target, source });
    }

    /**
     * Lowers a place expression to the most direct pointer for it.
     *
     * - `let` identifier: its `Alloca`.
     * - function parameter: the parameter value itself.
     * - `*expr`: the value of `expr`.
     * - `value[int]`: a `GetElemPtr` on the place of `value`.
     * - `value[range]`: a `Slice` on the place of `value`.
     *
     * @throws Error for any other node or symbol kind, for an unknown local,
     *     or for indexing something that is neither an array nor a slice.
     */
    private lowerPlace(place: ast.Node): Inst {
        // Result unused; the lookup is kept as in the original code —
        // presumably `tys.place` rejects nodes that are not places.
        this.tys.place(place);

        if (place.is("IdentExpr")) {
            const sym = this.syms.get(place);
            if (sym.tag === "Let") {
                const local = this.localMap.get(sym.param.id);
                if (!local) {
                    throw new Error(
                        `no local lowered for let binding ${sym.param.id}`,
                    );
                }
                return local;
            }
            if (sym.tag === "FnParam") {
                // No slot exists for parameters; presumably a parameter used
                // as a place is already pointer- or slice-typed.
                return this.lowerExpr(place);
            }
            throw new Error(`'${sym.tag}' not handled`);
        }

        if (place.is("UnaryExpr") && place.kind.op === "Deref") {
            return this.lowerExpr(place.kind.expr);
        }

        if (place.is("IndexExpr")) {
            const value = place.kind.value;
            const valueTy = this.tys.place(value);
            const arg = place.kind.arg;
            const argTy = this.tys.expr(arg);
            if (!valueTy.is("Array") && !valueTy.is("Slice")) {
                throw new Error(
                    `cannot index place of type '${valueTy.pretty()}'`,
                );
            }
            const valueInst = this.lowerPlace(place.kind.value);
            if (argTy.is("Int")) {
                const argInst = this.lowerExpr(arg);
                return this.pushInst(
                    Ty.create("PtrMut", { ty: valueTy.kind.ty }),
                    "GetElemPtr",
                    { base: valueInst, offset: argInst },
                );
            }
            if (argTy.is("Range")) {
                // Only literal range expressions can be split into bounds here.
                if (!arg.is("RangeExpr")) {
                    throw new Error("not supported yet");
                }
                // A missing bound is passed through as a falsy value.
                const begin = arg.kind.begin &&
                    this.lowerExpr(arg.kind.begin);
                const end = arg.kind.end &&
                    this.lowerExpr(arg.kind.end);
                return this.pushInst(
                    Ty.create("PtrMut", {
                        ty: Ty.create("Slice", { ty: valueTy.kind.ty }),
                    }),
                    "Slice",
                    { value: valueInst, begin, end },
                );
            }
            throw new Error(
                `${place.kind.tag} with arg ${argTy.pretty()} not handled`,
            );
        }

        throw new Error(`'${place.kind.tag}' not handled`);
    }

    /**
     * Lowers an expression and returns the instruction producing its value.
     *
     * @throws Error for unsupported node kinds, symbols, builtins or
     *     operator/type combinations.
     */
    private lowerExpr(expr: ast.Node): Inst {
        const ty = this.tys.expr(expr);

        if (expr.is("IdentExpr")) {
            const sym = this.syms.get(expr);
            if (sym.tag === "Fn") {
                // NOTE(review): `lowerFn` caches only after the callee body
                // is fully lowered, so a self-referencing function would
                // re-enter here without a cache hit — confirm recursion is
                // handled.
                const fn = this.lowerer.lowerFn(sym.stmt);
                return this.pushInst(fn.ty, "Fn", { fn });
            }
            if (sym.tag === "FnParam") {
                const paramTy = this.tys.expr(sym.param);
                return this.pushInst(paramTy, "Param", { idx: sym.idx });
            }
            if (sym.tag === "Builtin") {
                // Builtins are only valid as direct call targets (see CallExpr).
                throw new Error("handle elsewhere");
            }
            if (sym.tag === "Let") {
                const source = this.lowerPlace(expr);
                return this.pushInst(ty, "Load", { source });
            }
            if (sym.tag === "Bool") {
                return this.pushInst(Ty.Bool, "Bool", { value: sym.value });
            }
            throw new Error(`'${sym.tag}' not handled`);
        }

        if (expr.is("IntExpr")) {
            return this.pushInst(Ty.Int, "Int", { value: expr.kind.value });
        }

        if (expr.is("ArrayExpr")) {
            const values = expr.kind.values
                .map((value) => this.lowerExpr(value));
            return this.pushInst(ty, "Array", { values });
        }

        if (expr.is("IndexExpr")) {
            const source = this.lowerPlace(expr);
            return this.pushInst(ty, "Load", { source });
        }

        if (expr.is("CallExpr")) {
            // Arguments are lowered before the callee.
            const args = expr.kind.args
                .map((arg) => this.lowerExpr(arg));
            if (expr.kind.value.is("IdentExpr")) {
                const sym = this.syms.get(expr.kind.value);
                if (sym.tag === "Builtin") {
                    if (sym.id === "debug_print") {
                        return this.pushInst(ty, "DebugPrint", { args });
                    }
                    throw new Error(`builtin '${sym.id}' not handled`);
                }
            }
            const callee = this.lowerExpr(expr.kind.value);
            return this.pushInst(ty, "Call", { callee, args });
        }

        if (expr.is("UnaryExpr")) {
            return this.lowerUnaryExpr(expr);
        }

        if (expr.is("BinaryExpr")) {
            const resultTy = ty;
            const leftTy = this.tys.expr(expr.kind.left);
            const rightTy = this.tys.expr(expr.kind.right);
            // A pattern's missing `left` defaults to `result`; a missing
            // `right` defaults to `left`, then `result`.
            const binaryOp = binaryOpPatterns
                .find((pat) =>
                    expr.kind.op === pat.op &&
                    resultTy.compatibleWith(pat.result) &&
                    leftTy.compatibleWith(pat.left ?? pat.result) &&
                    rightTy.compatibleWith(
                        pat.right ?? pat.left ?? pat.result,
                    )
                );
            if (!binaryOp) {
                throw new Error(
                    `'${expr.kind.op}' with '${resultTy.pretty()}' not handled`,
                );
            }
            const left = this.lowerExpr(expr.kind.left);
            const right = this.lowerExpr(expr.kind.right);
            return this.pushInst(resultTy, binaryOp.tag, { left, right });
        }

        throw new Error(`'${expr.kind.tag}' not handled`);
    }

    /**
     * Lowers a unary expression: integer negation, boolean not, taking a
     * (mutable) reference, or dereferencing.
     *
     * @throws Error for unsupported operator/type combinations, unsupported
     *     reference targets, or an unknown local.
     */
    private lowerUnaryExpr(expr: ast.NodeWithKind<"UnaryExpr">) {
        const resultTy = this.tys.expr(expr);
        const operandTy = this.tys.expr(expr.kind.expr);

        if (
            expr.kind.op === "Negate" &&
            operandTy.compatibleWith(Ty.Int) &&
            resultTy.compatibleWith(Ty.Int)
        ) {
            const operand = this.lowerExpr(expr.kind.expr);
            return this.pushInst(Ty.Int, "Negate", { source: operand });
        }

        if (
            expr.kind.op === "Not" &&
            operandTy.compatibleWith(Ty.Bool) &&
            resultTy.compatibleWith(Ty.Bool)
        ) {
            const operand = this.lowerExpr(expr.kind.expr);
            return this.pushInst(Ty.Bool, "Not", { source: operand });
        }

        if (expr.kind.op === "Ref" || expr.kind.op === "RefMut") {
            const place = expr.kind.expr;
            if (place.is("IdentExpr")) {
                const sym = this.syms.get(place);
                if (sym.tag === "Let") {
                    // A reference to a local is its `Alloca` pointer.
                    const local = this.localMap.get(sym.param.id);
                    if (!local) {
                        throw new Error(
                            `no local lowered for let binding ${sym.param.id}`,
                        );
                    }
                    return local;
                }
                throw new Error(
                    `${expr.kind.op} with sym ${sym.tag} not handled`,
                );
            }
            if (place.is("IndexExpr")) {
                const placeTy = this.tys.expr(place);
                const placeInst = this.lowerPlace(place);
                if (placeTy.is("Slice")) {
                    return placeInst;
                }
                // NOTE(review): for a non-slice element this emits a `Load`
                // typed as the element, i.e. the loaded value rather than
                // the element pointer from `lowerPlace` — confirm this is
                // intended for a reference expression.
                return this.pushInst(placeTy, "Load", {
                    source: placeInst,
                });
            }
            throw new Error(
                `${expr.kind.op} with place ${place.kind.tag} not handled`,
            );
        }

        if (expr.kind.op === "Deref") {
            const source = this.lowerExpr(expr.kind.expr);
            return this.pushInst(resultTy, "Load", { source });
        }

        throw new Error(
            `'${expr.kind.op}' with '${resultTy.pretty()}' not handled`,
        );
    }

    /** Emits a `Void` value instruction and returns it. */
    private makeVoid(): Inst {
        return this.pushInst(Ty.Void, "Void", {});
    }

    /**
     * Creates an instruction of the given type and kind, appends it to the
     * current (last) block and returns it.
     *
     * @param ty Type of the instruction's result.
     * @param tag Instruction kind tag.
     * @param kind Remaining fields of the instruction kind for `tag`.
     */
    private pushInst<
        Tag extends InstKind["tag"],
    >(
        ty: Ty,
        tag: Tag,
        kind: Omit<InstKind & { tag: Tag }, "tag">,
    ): Inst {
        const inst = new Inst(ty, { tag, ...kind } as InstKind);
        this.bbs.at(-1)!.insts.push(inst);
        return inst;
    }
}
/**
 * One row of the binary-operator lowering table: an AST operator together
 * with the operand/result types it applies to and the MIR tag to emit.
 */
type BinaryOpPattern = {
    /** AST operator this row applies to. */
    op: ast.BinaryOp;
    /** MIR instruction tag emitted when the row matches. */
    tag: BinaryOp;
    /** Type the expression's result must be compatible with. */
    result: Ty;
    /** Left operand type; the matcher falls back to `result` when omitted. */
    left?: Ty;
    /** Right operand type; the matcher falls back to `left`, then `result`. */
    right?: Ty;
};

/** Builds a row for an operator taking `Int` operands and yielding `Int`. */
const intArithmeticPattern = (
    op: ast.BinaryOp,
    tag: BinaryOp,
): BinaryOpPattern => ({ op, tag, result: Ty.Int, left: Ty.Int });

/** Builds a row for an operator taking `Int` operands and yielding `Bool`. */
const intComparisonPattern = (
    op: ast.BinaryOp,
    tag: BinaryOp,
): BinaryOpPattern => ({ op, tag, result: Ty.Bool, left: Ty.Int });

/** Supported binary operators; rows are searched in order, first match wins. */
const binaryOpPatterns: BinaryOpPattern[] = [
    intArithmeticPattern("Add", "Add"),
    intArithmeticPattern("Subtract", "Sub"),
    intArithmeticPattern("Multiply", "Mul"),
    intArithmeticPattern("Divide", "Div"),
    // No explicit `left`: operands default to the `Int` result type.
    { op: "Remainder", tag: "Rem", result: Ty.Int },
    intComparisonPattern("Eq", "Eq"),
    intComparisonPattern("Ne", "Ne"),
    intComparisonPattern("Lt", "Lt"),
    intComparisonPattern("Gt", "Gt"),
    intComparisonPattern("Lte", "Lte"),
    intComparisonPattern("Gte", "Gte"),
];