394 lines
14 KiB
TypeScript
394 lines
14 KiB
TypeScript
import * as mir from "./mir.ts";
|
|
|
|
/**
 * Interprets a single invocation of a `mir.Fn`.
 *
 * Execution starts at the function's first basic block and proceeds one
 * instruction at a time. The result of every value-producing instruction is
 * recorded in `regs`, keyed by the instruction object itself, so later
 * instructions can read their operands back. `Jump`/`Branch` switch the
 * current block, `Return` ends evaluation, and `Call` recursively evaluates
 * the callee with a fresh interpreter.
 */
export class FnInterpreter {
  /** Value produced by each instruction evaluated so far in this invocation. */
  private regs = new Map<mir.Inst, Val>();

  /** Basic block currently being executed. */
  private bb: mir.BasicBlock;

  /** Index (within `bb.insts`) of the next instruction to execute. */
  private instIdx = 0;

  /**
   * @param fn   function to interpret; execution begins at `fn.bbs[0]`.
   * @param args argument values, read by `Param` instructions via their `idx`.
   * @throws Error if `fn` has no basic blocks.
   */
  constructor(
    private fn: mir.Fn,
    private args: Val[],
  ) {
    const entry = this.fn.bbs[0];
    if (entry === undefined) {
      throw new Error("cannot interpret a function with no basic blocks");
    }
    this.bb = entry;
  }

  /**
   * Runs the function until a `Return` executes or the current block runs
   * out of instructions.
   *
   * @returns the value of the executed `Return`'s source, or `Val.Void` when
   *   execution falls off the end of a block.
   * @throws Error on an `Error` instruction, operand kind mismatches,
   *   out-of-bounds indexing/slicing, division by zero and similar faults.
   */
  eval(): Val {
    while (this.instIdx < this.bb.insts.length) {
      const inst = this.bb.insts[this.instIdx];
      this.instIdx += 1;

      // Uncomment to trace execution:
      // console.log(
      //   `[${this.instIdx.toString().padStart(2, " ")}] ${
      //     inst.pretty(new mir.PrettyCx())
      //   }`,
      // );

      const k = inst.kind;
      switch (k.tag) {
        case "Error":
          throw new Error("reached an 'Error' instruction");
        case "Void":
        case "Int":
        case "Bool":
        case "Fn":
          // These instruction kinds are used directly as the value's kind.
          this.regs.set(inst, new Val(k));
          break;
        case "Array":
          this.regs.set(
            inst,
            new Val({
              tag: "Array",
              values: k.values.map((element) => this.reg(element)),
            }),
          );
          break;
        case "Param": {
          const arg = this.args[k.idx];
          if (arg === undefined) {
            throw new Error(
              `parameter ${k.idx} out of range (${this.args.length} argument(s) passed)`,
            );
          }
          this.regs.set(inst, arg);
          break;
        }
        case "Index":
          this.regs.set(inst, this.evalIndex(k));
          break;
        case "Slice":
          this.regs.set(inst, this.evalSlice(k));
          break;
        case "Call": {
          const callee = this.reg(k.callee);
          if (callee.kind.tag !== "Fn") {
            throw new Error(
              `cannot call a value of kind '${callee.kind.tag}'`,
            );
          }
          const args = k.args.map((arg) => this.reg(arg));
          this.regs.set(
            inst,
            new FnInterpreter(callee.kind.fn, args).eval(),
          );
          break;
        }
        case "Alloca":
          // A fresh slot holds `Null` until the first `Store` into it.
          this.regs.set(
            inst,
            new Val({
              tag: "Ptr",
              mutable: true,
              value: new Val({ tag: "Null" }),
            }),
          );
          break;
        case "Load": {
          const source = this.reg(k.source);
          if (source.kind.tag === "Ptr") {
            this.regs.set(inst, source.kind.value);
          } else if (source.kind.tag === "ArrayElemPtr") {
            this.regs.set(inst, source.kind.values[source.kind.idx]);
          } else {
            throw new Error(
              `cannot load through a value of kind '${source.kind.tag}'`,
            );
          }
          break;
        }
        case "Store": {
          const target = this.reg(k.target);
          const value = this.reg(k.source);
          if (target.kind.tag === "Ptr") {
            target.kind.value = value;
          } else if (target.kind.tag === "ArrayElemPtr") {
            // Writes into the array's shared element list.
            target.kind.values[target.kind.idx] = value;
          } else {
            throw new Error(
              `cannot store through a value of kind '${target.kind.tag}'`,
            );
          }
          break;
        }
        case "Jump":
          this.bb = k.target;
          this.instIdx = 0;
          break;
        case "Branch": {
          const cond = this.reg(k.cond);
          if (cond.kind.tag !== "Bool") {
            throw new Error(
              `branch condition must be Bool, got '${cond.kind.tag}'`,
            );
          }
          this.bb = cond.kind.value ? k.truthy : k.falsy;
          this.instIdx = 0;
          break;
        }
        case "Return":
          return this.reg(k.source);
        case "Not": {
          const source = this.reg(k.source);
          if (source.kind.tag !== "Bool") {
            throw new Error(
              `'Not' operand must be Bool, got '${source.kind.tag}'`,
            );
          }
          this.regs.set(
            inst,
            new Val({ tag: "Bool", value: !source.kind.value }),
          );
          break;
        }
        case "Negate": {
          const value = this.expectInt(this.reg(k.source), "'Negate' operand");
          this.regs.set(inst, new Val({ tag: "Int", value: -value }));
          break;
        }
        case "Eq":
        case "Ne":
        case "Lt":
        case "Gt":
        case "Lte":
        case "Gte":
        case "BitOr":
        case "BitXor":
        case "BitAnd":
        case "Shl":
        case "Shr":
        case "Add":
        case "Sub":
        case "Mul":
        case "Div":
        case "Rem":
          this.evalBinaryOp(inst, k);
          break;
        case "DebugPrint":
          console.log(
            k.args.map((arg) => this.reg(arg).pretty()).join(", "),
          );
          break;
        default:
          k satisfies never;
      }
    }

    return Val.Void;
  }

  /**
   * Returns the value computed by `inst` earlier in this invocation.
   *
   * @throws Error if `inst` has not been evaluated yet (presumably malformed
   *   MIR: a use that is not preceded by its definition).
   */
  private reg(inst: mir.Inst): Val {
    const value = this.regs.get(inst);
    if (value === undefined) {
      throw new Error("use of an instruction that has not been evaluated");
    }
    return value;
  }

  /**
   * Returns the numeric payload of an `Int` value.
   *
   * @param what names the operand in the error message.
   * @throws Error if `value` is not an `Int`.
   */
  private expectInt(value: Val, what: string): number {
    if (value.kind.tag !== "Int") {
      throw new Error(`${what} must be Int, got '${value.kind.tag}'`);
    }
    return value.kind.value;
  }

  /**
   * Evaluates `Index`: yields an `ArrayElemPtr` to element `offset` of the
   * array behind the `Ptr` in `base`. The element pointer shares the array's
   * `values` list and inherits the base pointer's `mutable` flag.
   *
   * @throws Error if `offset` is not an `Int` in `[0, length)`, or `base` is
   *   not a pointer to an `Array` (indexing a `Slice` is not supported).
   */
  private evalIndex(k: mir.InstKind & { tag: "Index" }): Val {
    const offset = this.expectInt(this.reg(k.offset), "index");
    const base = this.reg(k.base);
    if (base.kind.tag !== "Ptr") {
      throw new Error(`cannot index into a value of kind '${base.kind.tag}'`);
    }
    const array = base.kind.value;
    if (array.kind.tag !== "Array") {
      throw new Error(
        `cannot index through a pointer to '${array.kind.tag}'`,
      );
    }
    const length = array.kind.values.length;
    // Negative offsets must be rejected too; otherwise a later `Load` would
    // read `values[-1]`, i.e. `undefined`.
    if (offset < 0 || offset >= length) {
      throw new Error(`index ${offset} out of bounds for length ${length}`);
    }
    return new Val({
      tag: "ArrayElemPtr",
      values: array.kind.values,
      idx: offset,
      mutable: base.kind.mutable,
    });
  }

  /**
   * Evaluates `Slice`: yields a `Slice` view of the array behind the `Ptr` in
   * `value`, from `begin` (inclusive, default 0) to `end` (exclusive, default
   * the array length).
   *
   * @throws Error if a bound is not an `Int`, `value` is not a pointer to an
   *   `Array`, or the bounds violate `0 <= begin <= end <= length`.
   */
  private evalSlice(k: mir.InstKind & { tag: "Slice" }): Val {
    const begin = k.begin ? this.reg(k.begin) : undefined;
    const end = k.end ? this.reg(k.end) : undefined;
    const beginIdx = begin === undefined
      ? 0
      : this.expectInt(begin, "slice start");
    const endInt = end === undefined
      ? undefined
      : this.expectInt(end, "slice end");

    const ptr = this.reg(k.value);
    if (ptr.kind.tag !== "Ptr") {
      throw new Error(`cannot slice a value of kind '${ptr.kind.tag}'`);
    }
    const array = ptr.kind.value;
    if (array.kind.tag !== "Array") {
      throw new Error(`cannot slice through a pointer to '${array.kind.tag}'`);
    }

    const length = array.kind.values.length;
    const endIdx = endInt ?? length;
    // `end` is exclusive (its default is `length`), so `end === length` is valid.
    if (beginIdx < 0 || beginIdx > endIdx || endIdx > length) {
      throw new Error(
        `slice range ${beginIdx}..${endIdx} out of bounds for length ${length}`,
      );
    }

    return new Val({
      tag: "Slice",
      value: array,
      begin: begin ?? new Val({ tag: "Int", value: 0 }),
      end: end ?? new Val({ tag: "Int", value: length }),
    });
  }

  /**
   * Evaluates a binary operator on two `Int` operands and stores the result
   * for `inst`. Comparisons yield `Bool`; arithmetic and bitwise operators
   * yield `Int`.
   *
   * `Div` truncates toward zero so that `(l / r) * r + l % r === l` holds
   * together with JavaScript's `%`, whose result takes the sign of `l`.
   * Bitwise operators and shifts are computed on `BigInt`s (two's-complement
   * semantics) and converted back to `number`.
   *
   * @throws Error if either operand is not an `Int`, on division or remainder
   *   by zero, or on a negative shift amount.
   */
  private evalBinaryOp(
    inst: mir.Inst,
    k: mir.InstKind & { tag: mir.BinaryOp },
  ) {
    const left = this.reg(k.left);
    const right = this.reg(k.right);
    if (left.kind.tag !== "Int" || right.kind.tag !== "Int") {
      throw new Error(
        `'${k.tag}' not handled for '${left.kind.tag}' and '${right.kind.tag}'`,
      );
    }

    const l = left.kind.value;
    const r = right.kind.value;
    const int = (value: number) => new Val({ tag: "Int", value });
    const bool = (value: boolean) => new Val({ tag: "Bool", value });
    const bits = (op: (a: bigint, b: bigint) => bigint) =>
      int(Number(op(BigInt(l), BigInt(r))));

    const value = ((): Val => {
      switch (k.tag) {
        case "Eq":
          return bool(l === r);
        case "Ne":
          return bool(l !== r);
        case "Lt":
          return bool(l < r);
        case "Gt":
          return bool(l > r);
        case "Lte":
          return bool(l <= r);
        case "Gte":
          return bool(l >= r);
        case "BitOr":
          return bits((a, b) => a | b);
        case "BitXor":
          return bits((a, b) => a ^ b);
        case "BitAnd":
          return bits((a, b) => a & b);
        case "Shl":
        case "Shr":
          if (r < 0) {
            throw new Error(`negative shift amount ${r}`);
          }
          return k.tag === "Shl"
            ? bits((a, b) => a << b)
            : bits((a, b) => a >> b);
        case "Add":
          return int(l + r);
        case "Sub":
          return int(l - r);
        case "Mul":
          return int(l * r);
        case "Div":
          if (r === 0) {
            throw new Error("division by zero");
          }
          return int(Math.trunc(l / r));
        case "Rem":
          if (r === 0) {
            throw new Error("remainder by zero");
          }
          return int(l % r);
      }
      throw new Error(`'${k.tag}' not handled`);
    })();

    this.regs.set(inst, value);
  }
}
|
|
|
|
/**
 * A runtime value manipulated by the interpreter.
 *
 * `kind` carries the tagged payload. Instances are shared by reference: a
 * `Ptr` holds another `Val`, and storing through a pointer replaces the
 * pointer's payload in place.
 */
class Val {
  /** Shared `Void` value, e.g. what a function yields without a `Return`. */
  static Void = new Val({ tag: "Void" });

  constructor(
    public kind: ValKind,
  ) {}

  /**
   * Human-readable rendering of the value: scalars as their literal text,
   * pointers as `<pointer>`, arrays and slices as `[a, b, ...]`, and
   * functions as `<type>`.
   *
   * @throws Error if a `Slice` does not view an `Array` with `Int` bounds.
   */
  pretty(): string {
    const kind = this.kind;
    const list = (items: Val[]) =>
      items.map((item) => item.pretty()).join(", ");

    if (kind.tag === "Null") return "<null>";
    if (kind.tag === "Void") return "void";
    if (kind.tag === "Int" || kind.tag === "Bool") return String(kind.value);
    if (kind.tag === "Ptr" || kind.tag === "ArrayElemPtr") return "<pointer>";
    if (kind.tag === "Array") return `[${list(kind.values)}]`;
    if (kind.tag === "Fn") return `<${kind.fn.ty.pretty()}>`;
    if (kind.tag === "Slice") {
      const { value: target, begin, end } = kind;
      if (target.kind.tag !== "Array") {
        throw new Error();
      }
      if (begin.kind.tag !== "Int" || end.kind.tag !== "Int") {
        throw new Error();
      }
      // `end` is exclusive, matching `Array.prototype.slice`.
      const visible = target.kind.values.slice(
        begin.kind.value,
        end.kind.value,
      );
      return `[${list(visible)}]`;
    }

    kind satisfies never;
    throw new Error();
  }
}
|
|
|
|
/** Tagged payload of a runtime `Val`, discriminated by `tag`. */
type ValKind =
  /** Contents of a slot created by `Alloca` before anything is stored. */
  | { tag: "Null" }
  /** The unit value. */
  | { tag: "Void" }
  | { tag: "Int"; value: number }
  | { tag: "Bool"; value: boolean }
  /** Pointer to a single value; `Store` replaces `value` in place. */
  | {
    tag: "Ptr";
    mutable: boolean;
    value: Val;
  }
  /** Pointer to `values[idx]`, sharing the pointed-to array's element list. */
  | {
    tag: "ArrayElemPtr";
    mutable: boolean;
    values: Val[];
    idx: number;
  }
  /**
   * View of the `Array` held in `value`; `begin` is inclusive and `end` is
   * exclusive, both expected to be `Int` values.
   */
  | {
    tag: "Slice";
    value: Val;
    begin: Val;
    end: Val;
  }
  | { tag: "Array"; values: Val[] }
  /** A callable function value. */
  | { tag: "Fn"; fn: mir.Fn };
|