middle works

This commit is contained in:
sfja 2026-03-10 12:48:54 +01:00
parent d4ced3250b
commit ef269442f1
7 changed files with 520 additions and 69 deletions

View File

@ -1,9 +1,14 @@
fn add(a: int, b: int) -> int {
fn add(a: int, b: int) -> int
{
return __add(a, b);
}
fn main() -> int {
fn main()
{
let sum = add(2, 3);
print_int(sum);
}
// vim: syntax=rust

View File

@ -19,6 +19,27 @@ export class Node {
public kind: NodeKind,
) {}
as<
Tag extends NodeKind["tag"],
>(tag: Tag): NodeWithKind<Tag> {
this.assertIs(tag);
return this;
}
assertIs<
Tag extends NodeKind["tag"],
>(tag: Tag): asserts this is NodeWithKind<Tag> {
if (this.kind.tag !== tag) {
throw new Error();
}
}
is<
Tag extends NodeKind["tag"],
>(tag: Tag): this is NodeWithKind<Tag> {
return this.kind.tag === tag;
}
visit(v: Visitor) {
if (v.visit(this) === "break") {
return;
@ -95,6 +116,10 @@ export type NodeWithKind<
Tag extends NodeKind["tag"],
> = Node & { kind: { tag: Tag } };
export type Block = NodeWithKind<"Block">;
export type FnStmt = NodeWithKind<"FnStmt">;
export type Param = NodeWithKind<"Param">;
export function assertNodeWithKind<
Tag extends NodeKind["tag"],
>(node: Node, tag: Tag): asserts node is NodeWithKind<Tag> {
@ -102,9 +127,3 @@ export function assertNodeWithKind<
throw new Error();
}
}
export function isNodeWithKind<
Tag extends NodeKind["tag"],
>(node: Node, tag: Tag): node is NodeWithKind<Tag> {
return node.kind.tag === tag;
}

View File

@ -1,7 +1,23 @@
import * as ast from "./ast.ts";
import { rootSyms } from "./root_syms.ts";
import { Ty } from "./ty.ts";
const rootSyms = [
{
id: "print_int",
ty: Ty.create("Fn", {
params: [Ty.Int],
retTy: Ty.Void,
}),
},
{
id: "__add",
ty: Ty.create("Fn", {
params: [Ty.Int, Ty.Int],
retTy: Ty.Int,
}),
},
];
export class Checker {
private nodeTys = new Map<number, Ty>();
@ -24,12 +40,65 @@ export class Checker {
private checkNode(node: ast.Node): Ty {
const k = node.kind;
if (ast.isNodeWithKind(node, "FnStmt")) {
if (node.is("FnStmt")) {
return this.checkFnStmt(node);
}
if (ast.isNodeWithKind(node, "IdentTy")) {
if (node.is("Param")) {
const sym = this.resols.get(node);
if (sym.tag === "Let") {
const exprTy = this.check(sym.stmt.kind.expr);
if (node.kind.ty) {
const explicitTy = this.check(node.kind.ty);
this.assertCompatible(
exprTy,
explicitTy,
sym.stmt.kind.expr.line,
);
}
return exprTy;
}
if (sym.tag === "FnParam") {
if (!node.kind.ty) {
this.error(node.line, `parameter must have a type`);
this.fail();
}
return this.check(node.kind.ty);
}
throw new Error(`'${sym.tag}' not handled`);
}
if (node.is("IdentExpr")) {
const sym = this.resols.get(node);
if (sym.tag === "Fn") {
return this.check(sym.stmt);
}
if (sym.tag === "Builtin") {
return rootSyms.find((s) => s.id === sym.id)!.ty;
}
if (sym.tag === "FnParam") {
return this.check(sym.param);
}
if (sym.tag === "Let") {
return this.check(sym.param);
}
throw new Error(`'${sym.tag}' not handled`);
}
if (node.is("IntExpr")) {
return Ty.Int;
}
if (node.is("CallExpr")) {
return this.checkCall(node);
}
if (node.is("IdentTy")) {
switch (node.kind.ident) {
case "void":
return Ty.Void;
case "int":
return Ty.Int;
default:
@ -37,7 +106,7 @@ export class Checker {
}
}
throw new Error(`'${k.tag}' not checked`);
throw new Error(`'${k.tag}' not unhandled`);
}
private checkFnStmt(stmt: ast.NodeWithKind<"FnStmt">): Ty {
@ -48,10 +117,20 @@ export class Checker {
k.body.visit({
visit: (node) => {
if (ast.isNodeWithKind(node, "ReturnStmt")) {
if (node.kind.expr) {
const ty = this.check(node.kind.expr);
} else {
if (node.is("ReturnStmt")) {
const ty = node.kind.expr
? this.check(node.kind.expr)
: Ty.Void;
if (!ty.compatibleWith(retTy)) {
this.error(
node.line,
`type '${ty.pretty()}' not compatible with return type '${retTy.pretty()}'`,
);
this.info(
stmt.kind.retTy?.line ?? stmt.line,
`return type '${retTy}' defined here`,
);
this.fail();
}
}
},
@ -61,18 +140,73 @@ export class Checker {
return Ty.create("FnStmt", { stmt, ty });
}
private typesCompatible(a: Ty, b: Ty): boolean {
const ak = a.kind;
const bk = b.kind;
if (ak.tag === "Error") {
return false;
private checkCall(node: ast.NodeWithKind<"CallExpr">): Ty {
const calleeTy = this.check(node.kind.expr);
const callableTy = calleeTy.isKind("Fn")
? calleeTy
: calleeTy.isKind("FnStmt")
? calleeTy.kind.ty as Ty & { kind: { tag: "Fn" } }
: null;
if (!callableTy) {
this.error(
node.line,
`type '${calleeTy.pretty()}' not callable`,
);
this.fail();
}
if (ak.tag === "Void") {
return bk.tag === "Void";
const args = node.kind.args
.map((arg) => this.check(arg));
const params = callableTy.kind.params;
if (args.length !== params.length) {
this.error(
node.line,
`incorrect amount of arguments. got ${args.length} expected ${params.length}`,
);
if (calleeTy.isKind("FnStmt")) {
this.info(
calleeTy.kind.stmt.line,
"function defined here",
);
}
this.fail();
}
for (const i of args.keys()) {
if (!args[i].compatibleWith(params[i])) {
this.error(
node.kind.args[i].line,
`type '${args[i].pretty()}' not compatible with type '${
params[i]
}', for argument ${i}`,
);
if (calleeTy.isKind("FnStmt")) {
this.info(
calleeTy.kind.stmt.kind.params[i].line,
`parameter '${
calleeTy.kind.stmt.kind.params[i]
.as("Param").kind.ident
}' defined here`,
);
}
this.fail();
}
}
return callableTy.kind.retTy;
}
private assertCompatible(left: Ty, right: Ty, line: number): void {
if (!left.compatibleWith(right)) {
this.error(
line,
`type '${left.pretty()}' not compatible with type '${right.pretty()}'`,
);
this.fail();
}
}
private error(line: number, message: string): never {
private error(line: number, message: string) {
printDiagnostics(
this.filename,
line,
@ -80,6 +214,19 @@ export class Checker {
message,
this.text,
);
}
private info(line: number, message: string) {
printDiagnostics(
this.filename,
line,
"info",
message,
this.text,
);
}
private fail(): never {
Deno.exit(1);
}
}
@ -92,6 +239,7 @@ export type Sym =
tag: "FnParam";
stmt: ast.NodeWithKind<"FnStmt">;
param: ast.NodeWithKind<"Param">;
idx: number;
}
| {
tag: "Let";
@ -106,7 +254,7 @@ export class ResolveMap {
get(node: ast.Node): Sym {
if (!this.resols.has(node.id)) {
throw new Error("not resolved");
throw new Error(`'${node.kind.tag}' not resolved`);
}
return this.resols.get(node.id)!;
}
@ -166,10 +314,8 @@ export function resolve(
if (k.tag === "File" || k.tag === "Block") {
syms = ResolverSyms.forkFrom(syms);
for (const stmt of k.stmts) {
const k = stmt.kind;
if (k.tag === "FnStmt") {
ast.assertNodeWithKind(stmt, "FnStmt");
syms.define(k.ident, { tag: "Fn", stmt });
if (stmt.is("FnStmt")) {
syms.define(stmt.kind.ident, { tag: "Fn", stmt });
}
}
node.visitBelow(this);
@ -180,13 +326,11 @@ export function resolve(
if (k.tag === "FnStmt") {
ast.assertNodeWithKind(node, "FnStmt");
syms = ResolverSyms.forkFrom(syms);
for (const param of k.params) {
for (const [idx, param] of k.params.entries()) {
ast.assertNodeWithKind(param, "Param");
syms.define(param.kind.ident, {
tag: "FnParam",
stmt: node,
param,
});
const sym: Sym = { tag: "FnParam", stmt: node, param, idx };
syms.define(param.kind.ident, sym);
resols.set(param.id, sym);
}
node.visitBelow(this);
syms = syms.parent!;
@ -196,7 +340,9 @@ export function resolve(
if (k.tag === "LetStmt") {
const stmt = node as ast.NodeWithKind<"LetStmt">;
const param = k.param as ast.NodeWithKind<"Param">;
syms.define(param.kind.ident, { tag: "Let", stmt, param });
const sym: Sym = { tag: "Let", stmt, param };
syms.define(param.kind.ident, sym);
resols.set(param.id, sym);
}
if (k.tag === "IdentExpr") {
@ -480,18 +626,19 @@ export function tokenize(text: string): Tok[] {
export function printDiagnostics(
filename: string,
line: number,
severity: "error",
severity: "error" | "info",
message: string,
text?: string,
) {
const severityColor = ({
"error": "red",
"info": "blue",
} as { [Key in typeof severity]: string })[severity];
console.error(
`%c${severity}%c: ${message}\n %c--> ${filename}:${line}%c`,
`color: ${severityColor}; font-weight: bold;`,
"color: white; font-weight: bold;",
"color: lightwhite; font-weight: bold;",
"color: gray;",
"",
);
@ -512,7 +659,7 @@ export function printDiagnostics(
`${" ".repeat(lineNumberText.length)}%c|` +
`%c${"~".repeat(lineText.length)}%c`,
"color: cyan;",
"color: white;",
"color: lightwhite;",
"color: cyan;",
`color: ${severityColor};`,
"",

View File

@ -1,5 +1,6 @@
import * as front from "./front.ts";
import * as ast from "./ast.ts";
import * as front from "./front.ts";
import * as middle from "./middle.ts";
const filename = Deno.args[0];
const text = await Deno.readTextFile(filename);
@ -29,6 +30,7 @@ if (!mainFn) {
Deno.exit(1);
}
const mainTy = checker.check(mainFn);
const m = new middle.MiddleLowerer(resols, checker);
const mainMiddleFn = m.lowerFn(mainFn);
console.log({ ast: fileAst, resols });
console.log(mainMiddleFn.pretty());

View File

@ -1,13 +1,262 @@
import * as ast from "./ast.ts";
import { Checker, ResolveMap } from "./front.ts";
import { Ty } from "./ty.ts";
export class MiddleLowerer {
private fns = new Map<number, Fn>();
constructor(
private resols: ResolveMap,
private checker: Checker,
) {}
lowerFn(stmt: ast.FnStmt): Fn {
if (this.fns.has(stmt.id)) {
return this.fns.get(stmt.id)!;
}
const fn = new FnLowerer(this, this.resols, this.checker, stmt).lower();
this.fns.set(stmt.id, fn);
return fn;
}
}
class FnLowerer {
private bbs: BasicBlock[] = [new BasicBlock([])];
private localMap = new Map<number, Inst>();
constructor(
private lowerer: MiddleLowerer,
private resols: ResolveMap,
private checker: Checker,
private stmt: ast.FnStmt,
) {}
lower(): Fn {
const ty = this.checker.check(this.stmt);
this.lowerBlock(this.stmt.kind.body.as("Block"));
return new Fn(this.stmt, ty, this.bbs);
}
private lowerBlock(block: ast.Block) {
for (const stmt of block.kind.stmts) {
this.lowerStmt(stmt);
}
}
private lowerStmt(stmt: ast.Node) {
if (stmt.is("LetStmt")) {
const ty = this.checker.check(stmt.kind.param);
const expr = this.lowerExpr(stmt.kind.expr);
const local = this.pushInst(ty, "AllocLocal", {});
this.pushInst(Ty.Void, "LocalStore", {
target: local,
source: expr,
});
this.localMap.set(stmt.kind.param.id, local);
return;
}
if (stmt.is("ReturnStmt")) {
const source = stmt.kind.expr
? this.lowerExpr(stmt.kind.expr)
: this.makeVoid();
this.pushInst(Ty.Void, "Return", { source });
return;
}
if (stmt.is("ExprStmt")) {
this.lowerExpr(stmt.kind.expr);
return;
}
throw new Error(`'${stmt.kind.tag}' not handled`);
}
private lowerExpr(expr: ast.Node): Inst {
if (expr.is("IdentExpr")) {
const sym = this.resols.get(expr);
if (sym.tag === "Fn") {
const fn = this.lowerer.lowerFn(sym.stmt);
return this.pushInst(fn.ty, "Fn", { fn });
}
if (sym.tag === "FnParam") {
const ty = this.checker.check(sym.param);
return this.pushInst(ty, "Param", { idx: sym.idx });
}
if (sym.tag === "Builtin") {
throw new Error("handle elsewhere");
}
if (sym.tag === "Let") {
const local = this.localMap.get(sym.param.id);
if (!local) {
throw new Error();
}
return this.pushInst(local.ty, "LocalLoad", { source: local });
}
throw new Error(`'${sym.tag}' not handled`);
}
if (expr.is("IntExpr")) {
return this.pushInst(Ty.Int, "Int", { value: expr.kind.value });
}
if (expr.is("CallExpr")) {
const ty = this.checker.check(expr);
const args = expr.kind.args
.map((arg) => this.lowerExpr(arg));
if (expr.kind.expr.is("IdentExpr")) {
const sym = this.resols.get(expr.kind.expr);
if (sym.tag === "Builtin") {
if (sym.id === "__add") {
const [left, right] = args;
return this.pushInst(ty, "Add", { left, right });
}
if (sym.id === "print_int") {
return this.pushInst(ty, "DebugPrint", { args });
}
throw new Error(`builtin '${sym.id}' not handled`);
}
}
const callee = this.lowerExpr(expr.kind.expr);
return this.pushInst(ty, "Call", { callee, args });
}
throw new Error(`'${expr.kind.tag}' not handled`);
}
private makeVoid(): Inst {
return this.pushInst(Ty.Void, "Void", {});
}
private pushInst<
Tag extends InstKind["tag"],
>(
ty: Ty,
tag: Tag,
kind: Omit<InstKind & { tag: Tag }, "tag">,
): Inst {
const inst = new Inst(ty, { tag, ...kind } as InstKind);
this.bbs.at(-1)!.insts.push(inst);
return inst;
}
}
export class Fn {
constructor(
public stmt: ast.FnStmt,
public ty: Ty,
public bbs: BasicBlock[],
) {}
pretty(): string {
const fnTy = this.ty.isKind("FnStmt") && this.ty.kind.ty.isKind("Fn")
? this.ty.kind.ty
: null;
if (!fnTy) {
throw new Error();
}
const cx = new PrettyCx();
return `fn ${this.stmt.kind.ident}(${
fnTy.kind.params
.map((ty, idx) => `${idx}: ${ty.pretty()}`)
.join(", ")
}) -> ${fnTy.kind.retTy.pretty()}\n{\n${
this.bbs
.map((bb) => bb.pretty(cx))
.join("\n")
}\n}`;
}
}
class IdMap<T> {
private map = new Map<T, number>();
private counter = 0;
id(val: T): number {
if (!this.map.has(val)) {
this.map.set(val, this.counter++);
}
return this.map.get(val)!;
}
}
class PrettyCx {
private bbIds = new IdMap<BasicBlock>();
private regIds = new IdMap<Inst>();
bbId(bb: BasicBlock): number {
return this.bbIds.id(bb);
}
regId(reg: Inst): number {
return this.regIds.id(reg);
}
}
export class BasicBlock {
constructor(
public instructions: Inst[],
public insts: Inst[],
) {}
pretty(cx: PrettyCx): string {
return `bb${cx.bbId(this)}:\n${
this.insts
.map((inst) => inst.pretty(cx))
.map((line) => ` ${line}`)
.join("\n")
}`;
}
}
export class Inst {
constructor() {}
constructor(
public ty: Ty,
public kind: InstKind,
) {}
pretty(cx: PrettyCx): string {
const r = (v: Inst) => `_${cx.regId(v)}`;
return `${r(this)}: ${this.ty.pretty()} = ${this.kind.tag}${
(() => {
const k = this.kind;
switch (k.tag) {
case "Error":
return "";
case "Void":
return "";
case "Int":
return ` ${k.value}`;
case "Fn":
return ` ${k.fn.stmt.kind.ident}`;
case "Param":
return ` ${k.idx}`;
case "Call":
return ` ${r(k.callee)} (${k.args.map(r).join(", ")})`;
case "AllocLocal":
return "";
case "LocalLoad":
return ` ${r(k.source)}`;
case "LocalStore":
return ` ${r(k.target)}, ${r(k.source)}`;
case "Return":
return ` ${r(k.source)}`;
case "Add":
return ` ${r(k.left)} ${r(k.right)}`;
case "DebugPrint":
return ` ${k.args.map(r).join(", ")}`;
}
const _: never = k;
})()
}`;
}
}
export type InsKind =
export type InstKind =
| { tag: "Error" }
| { tag: "Call" };
| { tag: "Void" }
| { tag: "Int"; value: number }
| { tag: "Fn"; fn: Fn }
| { tag: "Param"; idx: number }
| { tag: "Call"; callee: Inst; args: Inst[] }
| { tag: "AllocLocal" }
| { tag: "LocalLoad"; source: Inst }
| { tag: "LocalStore"; target: Inst; source: Inst }
| { tag: "Return"; source: Inst }
| { tag: "Add"; left: Inst; right: Inst }
| { tag: "DebugPrint"; args: Inst[] };

View File

@ -1,23 +0,0 @@
import { Ty } from "./ty.ts";
export type RootSym = {
id: string;
ty: Ty;
};
export const rootSyms: RootSym[] = [
{
id: "print_int",
ty: Ty.create("Fn", {
params: [],
retTy: Ty.Void,
}),
},
{
id: "__add",
ty: Ty.create("Fn", {
params: [Ty.Int, Ty.Int],
retTy: Ty.Int,
}),
},
];

View File

@ -51,7 +51,59 @@ export class Ty {
return other.isKind("Int");
}
if (this.isKind("Fn")) {
if (!other.isKind("Fn")) {
return false;
}
for (const i of this.kind.params.keys()) {
if (!this.kind.params[i].compatibleWith(other.kind.params[i])) {
return false;
}
}
if (!this.kind.retTy.compatibleWith(other.kind.retTy)) {
return false;
}
return true;
}
if (this.isKind("FnStmt")) {
if (!other.isKind("FnStmt")) {
return false;
}
if (!this.kind.ty.compatibleWith(other.kind.ty)) {
return false;
}
// redundant; sanity check
if (this.kind.stmt.id !== other.kind.stmt.id) {
return false;
}
return true;
}
throw new Error(`'${this.kind.tag}' not handled`);
}
pretty(): string {
if (this.isKind("Error")) {
return "<error>";
}
if (this.isKind("Void")) {
return "void";
}
if (this.isKind("Int")) {
return "int";
}
if (this.isKind("Fn")) {
return `fn (${
this.kind.params.map((param) => param.pretty()).join(", ")
}) -> ${this.kind.retTy.pretty()}`;
}
if (this.isKind("FnStmt")) {
if (!this.kind.ty.isKind("Fn")) throw new Error();
return `fn ${this.kind.stmt.kind.ident}(${
this.kind.ty.kind.params.map((param) => param.pretty()).join(
", ",
)
}) -> ${this.kind.ty.kind.retTy.pretty()}`;
}
throw new Error("unhandled");
}
}