From ef269442f1cffd3cbe4c4e592981a2cc810527e4 Mon Sep 17 00:00:00 2001 From: sfja Date: Tue, 10 Mar 2026 12:48:54 +0100 Subject: [PATCH] middle works --- program.ethlang | 9 +- src/ast.ts | 31 ++++-- src/front.ts | 209 ++++++++++++++++++++++++++++++++------ src/main.ts | 8 +- src/middle.ts | 257 ++++++++++++++++++++++++++++++++++++++++++++++- src/root_syms.ts | 23 ----- src/ty.ts | 52 ++++++++++ 7 files changed, 520 insertions(+), 69 deletions(-) delete mode 100644 src/root_syms.ts diff --git a/program.ethlang b/program.ethlang index 726f906..8ad6b2d 100644 --- a/program.ethlang +++ b/program.ethlang @@ -1,9 +1,14 @@ -fn add(a: int, b: int) -> int { +fn add(a: int, b: int) -> int +{ return __add(a, b); } -fn main() -> int { +fn main() +{ let sum = add(2, 3); print_int(sum); } + +// vim: syntax=rust + diff --git a/src/ast.ts b/src/ast.ts index daa0a79..a6fdde5 100644 --- a/src/ast.ts +++ b/src/ast.ts @@ -19,6 +19,27 @@ export class Node { public kind: NodeKind, ) {} + as< + Tag extends NodeKind["tag"], + >(tag: Tag): NodeWithKind { + this.assertIs(tag); + return this; + } + + assertIs< + Tag extends NodeKind["tag"], + >(tag: Tag): asserts this is NodeWithKind { + if (this.kind.tag !== tag) { + throw new Error(); + } + } + + is< + Tag extends NodeKind["tag"], + >(tag: Tag): this is NodeWithKind { + return this.kind.tag === tag; + } + visit(v: Visitor) { if (v.visit(this) === "break") { return; @@ -95,6 +116,10 @@ export type NodeWithKind< Tag extends NodeKind["tag"], > = Node & { kind: { tag: Tag } }; +export type Block = NodeWithKind<"Block">; +export type FnStmt = NodeWithKind<"FnStmt">; +export type Param = NodeWithKind<"Param">; + export function assertNodeWithKind< Tag extends NodeKind["tag"], >(node: Node, tag: Tag): asserts node is NodeWithKind { @@ -102,9 +127,3 @@ export function assertNodeWithKind< throw new Error(); } } - -export function isNodeWithKind< - Tag extends NodeKind["tag"], ->(node: Node, tag: Tag): node is NodeWithKind { - return node.kind.tag === tag; -} diff --git a/src/front.ts b/src/front.ts index eb68cea..68f0f7a 100644 --- a/src/front.ts +++ b/src/front.ts @@ -1,7 +1,23 @@ import * as ast from "./ast.ts"; -import { rootSyms } from "./root_syms.ts"; import { Ty } from "./ty.ts"; +const rootSyms = [ + { + id: "print_int", + ty: Ty.create("Fn", { + params: [Ty.Int], + retTy: Ty.Void, + }), + }, + { + id: "__add", + ty: Ty.create("Fn", { + params: [Ty.Int, Ty.Int], + retTy: Ty.Int, + }), + }, +]; + export class Checker { private nodeTys = new Map(); @@ -24,12 +40,65 @@ export class Checker { private checkNode(node: ast.Node): Ty { const k = node.kind; - if (ast.isNodeWithKind(node, "FnStmt")) { + if (node.is("FnStmt")) { return this.checkFnStmt(node); } - if (ast.isNodeWithKind(node, "IdentTy")) { + if (node.is("Param")) { + const sym = this.resols.get(node); + + if (sym.tag === "Let") { + const exprTy = this.check(sym.stmt.kind.expr); + if (node.kind.ty) { + const explicitTy = this.check(node.kind.ty); + this.assertCompatible( + exprTy, + explicitTy, + sym.stmt.kind.expr.line, + ); + } + return exprTy; + } + if (sym.tag === "FnParam") { + if (!node.kind.ty) { + this.error(node.line, `parameter must have a type`); + this.fail(); + } + return this.check(node.kind.ty); + } + + throw new Error(`'${sym.tag}' not handled`); + } + + if (node.is("IdentExpr")) { + const sym = this.resols.get(node); + if (sym.tag === "Fn") { + return this.check(sym.stmt); + } + if (sym.tag === "Builtin") { + return rootSyms.find((s) => s.id === sym.id)!.ty; + } + if (sym.tag === "FnParam") { + return this.check(sym.param); + } + if (sym.tag === "Let") { + return this.check(sym.param); + } + throw new Error(`'${sym.tag}' not handled`); + } + + if (node.is("IntExpr")) { + return Ty.Int; + } + + if (node.is("CallExpr")) { + return this.checkCall(node); + } + + if (node.is("IdentTy")) { switch (node.kind.ident) { + case "void": + return Ty.Void; case "int": return Ty.Int; default: @@ -37,7 +106,7 @@ export class Checker { } } - throw new Error(`'${k.tag}' not checked`); + throw new Error(`'${k.tag}' not unhandled`); } private checkFnStmt(stmt: ast.NodeWithKind<"FnStmt">): Ty { @@ -48,10 +117,20 @@ export class Checker { k.body.visit({ visit: (node) => { - if (ast.isNodeWithKind(node, "ReturnStmt")) { - if (node.kind.expr) { - const ty = this.check(node.kind.expr); - } else { + if (node.is("ReturnStmt")) { + const ty = node.kind.expr + ? this.check(node.kind.expr) + : Ty.Void; + if (!ty.compatibleWith(retTy)) { + this.error( + node.line, + `type '${ty.pretty()}' not compatible with return type '${retTy.pretty()}'`, + ); + this.info( + stmt.kind.retTy?.line ?? stmt.line, + `return type '${retTy}' defined here`, + ); + this.fail(); } } }, @@ -61,18 +140,73 @@ export class Checker { return Ty.create("FnStmt", { stmt, ty }); } - private typesCompatible(a: Ty, b: Ty): boolean { - const ak = a.kind; - const bk = b.kind; - if (ak.tag === "Error") { - return false; + private checkCall(node: ast.NodeWithKind<"CallExpr">): Ty { + const calleeTy = this.check(node.kind.expr); + + const callableTy = calleeTy.isKind("Fn") + ? calleeTy + : calleeTy.isKind("FnStmt") + ? calleeTy.kind.ty as Ty & { kind: { tag: "Fn" } } + : null; + + if (!callableTy) { + this.error( + node.line, + `type '${calleeTy.pretty()}' not callable`, + ); + this.fail(); } - if (ak.tag === "Void") { - return bk.tag === "Void"; + + const args = node.kind.args + .map((arg) => this.check(arg)); + const params = callableTy.kind.params; + if (args.length !== params.length) { + this.error( + node.line, + `incorrect amount of arguments. got ${args.length} expected ${params.length}`, + ); + if (calleeTy.isKind("FnStmt")) { + this.info( + calleeTy.kind.stmt.line, + "function defined here", + ); + } + this.fail(); + } + for (const i of args.keys()) { + if (!args[i].compatibleWith(params[i])) { + this.error( + node.kind.args[i].line, + `type '${args[i].pretty()}' not compatible with type '${ + params[i] + }', for argument ${i}`, + ); + if (calleeTy.isKind("FnStmt")) { + this.info( + calleeTy.kind.stmt.kind.params[i].line, + `parameter '${ + calleeTy.kind.stmt.kind.params[i] + .as("Param").kind.ident + }' defined here`, + ); + } + this.fail(); + } + } + return callableTy.kind.retTy; + } + + private assertCompatible(left: Ty, right: Ty, line: number): void { + if (!left.compatibleWith(right)) { + this.error( + line, + `type '${left.pretty()}' not compatible with type '${right.pretty()}'`, + ); + this.fail(); } } - private error(line: number, message: string): never { + private error(line: number, message: string) { printDiagnostics( this.filename, line, @@ -80,6 +214,19 @@ export class Checker { message, this.text, ); + } + + private info(line: number, message: string) { + printDiagnostics( + this.filename, + line, + "info", + message, + this.text, + ); + } + + private fail(): never { Deno.exit(1); } } @@ -92,6 +239,7 @@ export type Sym = tag: "FnParam"; stmt: ast.NodeWithKind<"FnStmt">; param: ast.NodeWithKind<"Param">; + idx: number; } | { tag: "Let"; @@ -106,7 +254,7 @@ export class ResolveMap { get(node: ast.Node): Sym { if (!this.resols.has(node.id)) { - throw new Error("not resolved"); + throw new Error(`'${node.kind.tag}' not resolved`); } return this.resols.get(node.id)!; } @@ -166,10 +314,8 @@ export function resolve( if (k.tag === "File" || k.tag === "Block") { syms = ResolverSyms.forkFrom(syms); for (const stmt of k.stmts) { - const k = stmt.kind; - if (k.tag === "FnStmt") { - ast.assertNodeWithKind(stmt, "FnStmt"); - syms.define(k.ident, { tag: "Fn", stmt }); + if (stmt.is("FnStmt")) { + syms.define(stmt.kind.ident, { tag: "Fn", stmt }); } } node.visitBelow(this); @@ -180,13 +326,11 @@ export function resolve( if (k.tag === "FnStmt") { ast.assertNodeWithKind(node, "FnStmt"); syms = ResolverSyms.forkFrom(syms); - for (const param of k.params) { + for (const [idx, param] of k.params.entries()) { ast.assertNodeWithKind(param, "Param"); - syms.define(param.kind.ident, { - tag: "FnParam", - stmt: node, - param, - }); + const sym: Sym = { tag: "FnParam", stmt: node, param, idx }; + syms.define(param.kind.ident, sym); + resols.set(param.id, sym); } node.visitBelow(this); syms = syms.parent!; @@ -196,7 +340,9 @@ export function resolve( if (k.tag === "LetStmt") { const stmt = node as ast.NodeWithKind<"LetStmt">; const param = k.param as ast.NodeWithKind<"Param">; - syms.define(param.kind.ident, { tag: "Let", stmt, param }); + const sym: Sym = { tag: "Let", stmt, param }; + syms.define(param.kind.ident, sym); + resols.set(param.id, sym); } if (k.tag === "IdentExpr") { @@ -480,18 +626,19 @@ export function tokenize(text: string): Tok[] { export function printDiagnostics( filename: string, line: number, - severity: "error", + severity: "error" | "info", message: string, text?: string, ) { const severityColor = ({ "error": "red", + "info": "blue", } as { [Key in typeof severity]: string })[severity]; console.error( `%c${severity}%c: ${message}\n %c--> ${filename}:${line}%c`, `color: ${severityColor}; font-weight: bold;`, - "color: white; font-weight: bold;", + "color: lightwhite; font-weight: bold;", "color: gray;", "", ); @@ -512,7 +659,7 @@ export function printDiagnostics( `${" ".repeat(lineNumberText.length)}%c|` + `%c${"~".repeat(lineText.length)}%c`, "color: cyan;", - "color: white;", + "color: lightwhite;", "color: cyan;", `color: ${severityColor};`, "", diff --git a/src/main.ts b/src/main.ts index 6648e29..9461669 100644 --- a/src/main.ts +++ b/src/main.ts @@ -1,5 +1,6 @@ -import * as front from "./front.ts"; import * as ast from "./ast.ts"; +import * as front from "./front.ts"; +import * as middle from "./middle.ts"; const filename = Deno.args[0]; const text = await Deno.readTextFile(filename); @@ -29,6 +30,7 @@ if (!mainFn) { Deno.exit(1); } -const mainTy = checker.check(mainFn); +const m = new middle.MiddleLowerer(resols, checker); +const mainMiddleFn = m.lowerFn(mainFn); -console.log({ ast: fileAst, resols }); +console.log(mainMiddleFn.pretty()); diff --git a/src/middle.ts b/src/middle.ts index d6d9ae0..0768ac7 100644 --- a/src/middle.ts +++ b/src/middle.ts @@ -1,13 +1,262 @@ +import * as ast from "./ast.ts"; +import { Checker, ResolveMap } from "./front.ts"; +import { Ty } from "./ty.ts"; + +export class MiddleLowerer { + private fns = new Map(); + + constructor( + private resols: ResolveMap, + private checker: Checker, + ) {} + + lowerFn(stmt: ast.FnStmt): Fn { + if (this.fns.has(stmt.id)) { + return this.fns.get(stmt.id)!; + } + const fn = new FnLowerer(this, this.resols, this.checker, stmt).lower(); + this.fns.set(stmt.id, fn); + return fn; + } +} + +class FnLowerer { + private bbs: BasicBlock[] = [new BasicBlock([])]; + private localMap = new Map(); + + constructor( + private lowerer: MiddleLowerer, + private resols: ResolveMap, + private checker: Checker, + private stmt: ast.FnStmt, + ) {} + + lower(): Fn { + const ty = this.checker.check(this.stmt); + this.lowerBlock(this.stmt.kind.body.as("Block")); + return new Fn(this.stmt, ty, this.bbs); + } + + private lowerBlock(block: ast.Block) { + for (const stmt of block.kind.stmts) { + this.lowerStmt(stmt); + } + } + + private lowerStmt(stmt: ast.Node) { + if (stmt.is("LetStmt")) { + const ty = this.checker.check(stmt.kind.param); + const expr = this.lowerExpr(stmt.kind.expr); + const local = this.pushInst(ty, "AllocLocal", {}); + this.pushInst(Ty.Void, "LocalStore", { + target: local, + source: expr, + }); + this.localMap.set(stmt.kind.param.id, local); + return; + } + if (stmt.is("ReturnStmt")) { + const source = stmt.kind.expr + ? this.lowerExpr(stmt.kind.expr) + : this.makeVoid(); + this.pushInst(Ty.Void, "Return", { source }); + return; + } + if (stmt.is("ExprStmt")) { + this.lowerExpr(stmt.kind.expr); + return; + } + throw new Error(`'${stmt.kind.tag}' not handled`); + } + + private lowerExpr(expr: ast.Node): Inst { + if (expr.is("IdentExpr")) { + const sym = this.resols.get(expr); + if (sym.tag === "Fn") { + const fn = this.lowerer.lowerFn(sym.stmt); + return this.pushInst(fn.ty, "Fn", { fn }); + } + if (sym.tag === "FnParam") { + const ty = this.checker.check(sym.param); + return this.pushInst(ty, "Param", { idx: sym.idx }); + } + if (sym.tag === "Builtin") { + throw new Error("handle elsewhere"); + } + if (sym.tag === "Let") { + const local = this.localMap.get(sym.param.id); + if (!local) { + throw new Error(); + } + return this.pushInst(local.ty, "LocalLoad", { source: local }); + } + throw new Error(`'${sym.tag}' not handled`); + } + if (expr.is("IntExpr")) { + return this.pushInst(Ty.Int, "Int", { value: expr.kind.value }); + } + if (expr.is("CallExpr")) { + const ty = this.checker.check(expr); + const args = expr.kind.args + .map((arg) => this.lowerExpr(arg)); + + if (expr.kind.expr.is("IdentExpr")) { + const sym = this.resols.get(expr.kind.expr); + if (sym.tag === "Builtin") { + if (sym.id === "__add") { + const [left, right] = args; + return this.pushInst(ty, "Add", { left, right }); + } + if (sym.id === "print_int") { + return this.pushInst(ty, "DebugPrint", { args }); + } + throw new Error(`builtin '${sym.id}' not handled`); + } + } + + const callee = this.lowerExpr(expr.kind.expr); + return this.pushInst(ty, "Call", { callee, args }); + } + throw new Error(`'${expr.kind.tag}' not handled`); + } + + private makeVoid(): Inst { + return this.pushInst(Ty.Void, "Void", {}); + } + + private pushInst< + Tag extends InstKind["tag"], + >( + ty: Ty, + tag: Tag, + kind: Omit, + ): Inst { + const inst = new Inst(ty, { tag, ...kind } as InstKind); + this.bbs.at(-1)!.insts.push(inst); + return inst; + } +} + +export class Fn { + constructor( + public stmt: ast.FnStmt, + public ty: Ty, + public bbs: BasicBlock[], + ) {} + + pretty(): string { + const fnTy = this.ty.isKind("FnStmt") && this.ty.kind.ty.isKind("Fn") + ? this.ty.kind.ty + : null; + if (!fnTy) { + throw new Error(); + } + const cx = new PrettyCx(); + return `fn ${this.stmt.kind.ident}(${ + fnTy.kind.params + .map((ty, idx) => `${idx}: ${ty.pretty()}`) + .join(", ") + }) -> ${fnTy.kind.retTy.pretty()}\n{\n${ + this.bbs + .map((bb) => bb.pretty(cx)) + .join("\n") + }\n}`; + } +} + +class IdMap { + private map = new Map(); + private counter = 0; + + id(val: T): number { + if (!this.map.has(val)) { + this.map.set(val, this.counter++); + } + return this.map.get(val)!; + } +} + +class PrettyCx { + private bbIds = new IdMap(); + private regIds = new IdMap(); + + bbId(bb: BasicBlock): number { + return this.bbIds.id(bb); + } + regId(reg: Inst): number { + return this.regIds.id(reg); + } +} + export class BasicBlock { constructor( - public instructions: Inst[], + public insts: Inst[], ) {} + + pretty(cx: PrettyCx): string { + return `bb${cx.bbId(this)}:\n${ + this.insts + .map((inst) => inst.pretty(cx)) + .map((line) => ` ${line}`) + .join("\n") + }`; + } } export class Inst { - constructor() {} + constructor( + public ty: Ty, + public kind: InstKind, + ) {} + + pretty(cx: PrettyCx): string { + const r = (v: Inst) => `_${cx.regId(v)}`; + + return `${r(this)}: ${this.ty.pretty()} = ${this.kind.tag}${ + (() => { + const k = this.kind; + switch (k.tag) { + case "Error": + return ""; + case "Void": + return ""; + case "Int": + return ` ${k.value}`; + case "Fn": + return ` ${k.fn.stmt.kind.ident}`; + case "Param": + return ` ${k.idx}`; + case "Call": + return ` ${r(k.callee)} (${k.args.map(r).join(", ")})`; + case "AllocLocal": + return ""; + case "LocalLoad": + return ` ${r(k.source)}`; + case "LocalStore": + return ` ${r(k.target)}, ${r(k.source)}`; + case "Return": + return ` ${r(k.source)}`; + case "Add": + return ` ${r(k.left)} ${r(k.right)}`; + case "DebugPrint": + return ` ${k.args.map(r).join(", ")}`; + } + const _: never = k; + })() + }`; + } } -export type InsKind = +export type InstKind = | { tag: "Error" } - | { tag: "Call" }; + | { tag: "Void" } + | { tag: "Int"; value: number } + | { tag: "Fn"; fn: Fn } + | { tag: "Param"; idx: number } + | { tag: "Call"; callee: Inst; args: Inst[] } + | { tag: "AllocLocal" } + | { tag: "LocalLoad"; source: Inst } + | { tag: "LocalStore"; target: Inst; source: Inst } + | { tag: "Return"; source: Inst } + | { tag: "Add"; left: Inst; right: Inst } + | { tag: "DebugPrint"; args: Inst[] }; diff --git a/src/root_syms.ts b/src/root_syms.ts deleted file mode 100644 index a497a32..0000000 --- a/src/root_syms.ts +++ /dev/null @@ -1,23 +0,0 @@ -import { Ty } from "./ty.ts"; - -export type RootSym = { - id: string; - ty: Ty; -}; - -export const rootSyms: RootSym[] = [ - { - id: "print_int", - ty: Ty.create("Fn", { - params: [], - retTy: Ty.Void, - }), - }, - { - id: "__add", - ty: Ty.create("Fn", { - params: [Ty.Int, Ty.Int], - retTy: Ty.Int, - }), - }, -]; diff --git a/src/ty.ts b/src/ty.ts index 37bec0b..05cd2f1 100644 --- a/src/ty.ts +++ b/src/ty.ts @@ -51,7 +51,59 @@ export class Ty { return other.isKind("Int"); } if (this.isKind("Fn")) { + if (!other.isKind("Fn")) { + return false; + } + for (const i of this.kind.params.keys()) { + if (!this.kind.params[i].compatibleWith(other.kind.params[i])) { + return false; + } + } + if (!this.kind.retTy.compatibleWith(other.kind.retTy)) { + return false; + } + return true; } + if (this.isKind("FnStmt")) { + if (!other.isKind("FnStmt")) { + return false; + } + if (!this.kind.ty.compatibleWith(other.kind.ty)) { + return false; + } + // redundant; sanity check + if (this.kind.stmt.id !== other.kind.stmt.id) { + return false; + } + return true; + } + throw new Error(`'${this.kind.tag}' not handled`); + } + + pretty(): string { + if (this.isKind("Error")) { + return ""; + } + if (this.isKind("Void")) { + return "void"; + } + if (this.isKind("Int")) { + return "int"; + } + if (this.isKind("Fn")) { + return `fn (${ + this.kind.params.map((param) => param.pretty()).join(", ") + }) -> ${this.kind.retTy.pretty()}`; + } + if (this.isKind("FnStmt")) { + if (!this.kind.ty.isKind("Fn")) throw new Error(); + return `fn ${this.kind.stmt.kind.ident}(${ + this.kind.ty.kind.params.map((param) => param.pretty()).join( + ", ", + ) + }) -> ${this.kind.ty.kind.retTy.pretty()}`; + } + throw new Error("unhandled"); } }