diff --git a/src/mir.ts b/src/mir.ts index 683557b..a2f4b7a 100644 --- a/src/mir.ts +++ b/src/mir.ts @@ -1,5 +1,6 @@ import * as ast from "./ast.ts"; import { Ty } from "./ty.ts"; +import * as stringify from "./stringify.ts"; export interface Visitor { visitFn?(fn: Fn): void; @@ -22,46 +23,7 @@ export class Fn { } pretty(): string { - const fnTy = this.ty.is("FnStmt") && this.ty.kind.ty.is("Fn") - ? this.ty.kind.ty - : null; - if (!fnTy) { - throw new Error(); - } - const cx = new PrettyCx(); - return `fn ${this.stmt.kind.ident}(${ - fnTy.kind.params - .map((ty, idx) => `${idx}: ${ty.pretty()}`) - .join(", ") - }) -> ${fnTy.kind.retTy.pretty()}\n{\n${ - this.bbs - .map((bb) => bb.pretty(cx)) - .join("\n") - }\n}`; - } -} - -class IdMap { - private map = new Map(); - private counter = 0; - - id(val: T): number { - if (!this.map.has(val)) { - this.map.set(val, this.counter++); - } - return this.map.get(val)!; - } -} - -export class PrettyCx { - private bbIds = new IdMap(); - private regIds = new IdMap(); - - bbId(bb: BasicBlock): number { - return this.bbIds.id(bb); - } - regId(reg: Inst): number { - return this.regIds.id(reg); + return stringify.mirFnPretty(this); } } @@ -74,18 +36,6 @@ export class BasicBlock { inst.visit(v); } } - - pretty(cx: PrettyCx): string { - const consts = ["Void", "Int", "Bool", "Array"]; - - return `bb${cx.bbId(this)}:\n${ - this.insts - .filter((inst) => !consts.includes(inst.kind.tag)) - .map((inst) => inst.pretty(cx)) - .map((line) => ` ${line}`) - .join("\n") - }`; - } } export class Inst { @@ -97,88 +47,6 @@ export class Inst { visit(v: Visitor) { v.visitInst?.(this); } - - pretty(cx: PrettyCx): string { - const valueless = ["Store", "Jump", "Branch", "Return", "DebugPrint"]; - const valueType = `%${cx.regId(this)} (${this.ty.pretty()}) = `; - return `${ - !valueless.includes(this.kind.tag) ? valueType : "" - }${this.kind.tag} ${this.prettyArgs(cx)}`; - } - - private prettyArgs(cx: PrettyCx): string { - const consts = ["Void", "Int", "Bool", "Array"]; - - const r = (v: Inst) => - consts.includes(v.kind.tag) ? v.prettyArgs(cx) : `%${cx.regId(v)}`; - - const k = this.kind; - switch (k.tag) { - case "Error": - return ""; - case "Void": - return ""; - case "Int": - return `${k.value}${k.intTy}`; - case "Bool": - return `${k.value}`; - case "Str": - return `${JSON.stringify(k.value)}`; - case "Array": - return `[${k.values.map(r).join(", ")}]`; - case "Fn": - return `${k.fn.stmt.kind.ident}`; - case "Param": - return `${k.idx}`; - case "GetElemPtr": - return `&[ptr ${r(k.base)}][${r(k.offset)}]`; - case "Slice": - return `&[ptr ${r(k.value)}][${k.begin ? r(k.begin) : ""}..${ - k.end ? r(k.end) : "" - }]`; - case "Call": - return `${r(k.callee)} (${k.args.map(r).join(", ")})`; - case "Alloca": - return ""; - case "Load": - return `[ptr ${r(k.source)}]`; - case "Store": - return `[ptr ${r(k.target)}] = ${r(k.source)}`; - case "Jump": - return `bb${cx.bbId(k.target)}`; - case "Branch": - return `if ${r(k.cond)}: bb${cx.bbId(k.truthy)}, else: bb${ - cx.bbId(k.falsy) - }`; - case "Return": - return `${r(k.source)}`; - case "Not": - case "Negate": - return `${r(k.source)}`; - case "Eq": - case "Ne": - case "Lt": - case "Gt": - case "Lte": - case "Gte": - case "BitOr": - case "BitXor": - case "BitAnd": - case "Shl": - case "Shr": - case "Add": - case "Sub": - case "Mul": - case "Div": - case "Rem": - return `${r(k.left)} ${r(k.right)}`; - case "Len": - return `${r(k.source)}`; - case "DebugPrint": - return `${k.args.map(r).join(", ")}`; - } - k satisfies never; - } } export type InstKind = diff --git a/src/stringify.ts b/src/stringify.ts new file mode 100644 index 0000000..b5a768b --- /dev/null +++ b/src/stringify.ts @@ -0,0 +1,367 @@ +import * as ty from "./ty.ts"; +import * as mir from "./mir.ts"; + +export function tyPretty(ty: ty.Ty): string { + switch (ty.kind.tag) { + case "Error": + return ""; + case "Void": + return "void"; + case "IntLiteral": + return "{integer}"; + case "Int": + return `${ty.kind.intTy}`; + case "Bool": + return "bool"; + case "Ptr": + return `*${ty.kind.ty.pretty()}`; + case "PtrMut": + return `*mut ${ty.kind.ty.pretty()}`; + case "Array": + return `[${ty.kind.ty.pretty()}; ${ty.kind.length}]`; + case "Slice": + return `[${ty.kind.ty.pretty()}]`; + case "Range": + return `Range`; + case "Fn": + return `fn (${ + ty.kind.params.map((param) => param.pretty()).join(", ") + }) -> ${ty.kind.retTy.pretty()}`; + case "FnStmt": + if (!ty.kind.ty.is("Fn")) { + throw new Error(); + } + return `fn ${ty.kind.stmt.kind.ident}(${ + ty.kind.ty.kind.params.map((param) => param.pretty()) + .join( + ", ", + ) + }) -> ${ty.kind.ty.kind.retTy.pretty()}`; + } + return ""; +} + +export function mirFnPretty(fn: mir.Fn): string { + // return mirFnPrettyNew(fn); + return mirFnPrettyOld(fn); +} +export function mirFnPrettyNew(fn: mir.Fn): string { + return new MirFnPrettyStringifier(fn).stringify(); +} +export function mirFnPrettyOld(fn: mir.Fn): string { + class IdMap { + private map = new Map(); + private counter = 0; + + id(val: T): number { + if (!this.map.has(val)) { + this.map.set(val, this.counter++); + } + return this.map.get(val)!; + } + } + + class PrettyCx { + private bbIds = new IdMap(); + private regIds = new IdMap(); + + bbId(bb: mir.BasicBlock): number { + return this.bbIds.id(bb); + } + regId(reg: mir.Inst): number { + return this.regIds.id(reg); + } + } + + function mirBasicBlockPretty(bb: mir.BasicBlock, cx: PrettyCx): string { + const consts = ["Void", "Int", "Bool", "Array"]; + + return `bb${cx.bbId(bb)}:\n${ + bb.insts + .filter((inst) => !consts.includes(inst.kind.tag)) + .map((inst) => mirInstPretty(inst, cx)) + .map((line) => ` ${line}`) + .join("\n") + }`; + } + + function mirInstPretty(inst: mir.Inst, cx: PrettyCx): string { + const valueless = ["Store", "Jump", "Branch", "Return", "DebugPrint"]; + const valueType = `%${cx.regId(inst)} (${inst.ty.pretty()}) = `; + return `${ + !valueless.includes(inst.kind.tag) ? valueType : "" + }${inst.kind.tag} ${mirInstArgsPretty(inst, cx)}`; + } + + function mirInstArgsPretty(inst: mir.Inst, cx: PrettyCx): string { + const consts = ["Void", "Int", "Bool", "Array"]; + + const r = (v: mir.Inst) => + consts.includes(v.kind.tag) + ? mirInstArgsPretty(v, cx) + : `%${cx.regId(v)}`; + + const k = inst.kind; + switch (k.tag) { + case "Error": + return ""; + case "Void": + return ""; + case "Int": + return `${k.value}${k.intTy}`; + case "Bool": + return `${k.value}`; + case "Str": + return `${JSON.stringify(k.value)}`; + case "Array": + return `[${k.values.map(r).join(", ")}]`; + case "Fn": + return `${k.fn.stmt.kind.ident}`; + case "Param": + return `${k.idx}`; + case "GetElemPtr": + return `&[ptr ${r(k.base)}][${r(k.offset)}]`; + case "Slice": + return `&[ptr ${r(k.value)}][${k.begin ? r(k.begin) : ""}..${ + k.end ? r(k.end) : "" + }]`; + case "Call": + return `${r(k.callee)} (${k.args.map(r).join(", ")})`; + case "Alloca": + return ""; + case "Load": + return `[ptr ${r(k.source)}]`; + case "Store": + return `[ptr ${r(k.target)}] = ${r(k.source)}`; + case "Jump": + return `bb${cx.bbId(k.target)}`; + case "Branch": + return `if ${r(k.cond)}: bb${cx.bbId(k.truthy)}, else: bb${ + cx.bbId(k.falsy) + }`; + case "Return": + return `${r(k.source)}`; + case "Not": + case "Negate": + return `${r(k.source)}`; + case "Eq": + case "Ne": + case "Lt": + case "Gt": + case "Lte": + case "Gte": + case "BitOr": + case "BitXor": + case "BitAnd": + case "Shl": + case "Shr": + case "Add": + case "Sub": + case "Mul": + case "Div": + case "Rem": + return `${r(k.left)} ${r(k.right)}`; + case "Len": + return `${r(k.source)}`; + case "DebugPrint": + return `${k.args.map(r).join(", ")}`; + } + return ""; + } + + const fnTy = fn.ty.is("FnStmt") && fn.ty.kind.ty.is("Fn") + ? fn.ty.kind.ty + : null; + if (!fnTy) { + throw new Error(); + } + const cx = new PrettyCx(); + return `fn ${fn.stmt.kind.ident}(${ + fnTy.kind.params + .map((ty, idx) => `${idx}: ${ty.pretty()}`) + .join(", ") + }) -> ${fnTy.kind.retTy.pretty()}\n{\n${ + fn.bbs + .map((bb) => mirBasicBlockPretty(bb, cx)) + .join("\n") + }\n}`; +} + +class MirFnPrettyStringifier { + private bbIds = new Map(); + private instIds = new Map(); + private result = ""; + + constructor( + private fn: mir.Fn, + ) {} + + stringify(): string { + const fnTy = this.fn.ty.is("FnStmt") && this.fn.ty.kind.ty.is("Fn") + ? this.fn.ty.kind.ty + : null; + if (!fnTy) { + throw new Error(); + } + + const ident = this.fn.stmt.kind.ident; + + const params = fnTy.kind.params + .map((ty, idx) => `${idx}: ${ty.pretty()}`) + .join(", "); + const retTy = fnTy.kind.retTy.pretty(); + + this.result += `fn ${ident}(${params}) -> ${retTy} {\n`; + + for (const bb of this.fn.bbs) { + this.basicBlock(bb); + } + + return this.result; + } + + private basicBlock(bb: mir.BasicBlock) { + this.result += `bb${this.bbId(bb)}:\n`; + + for (const inst of bb.insts) { + this.inst(inst); + } + } + + private instsPrinted = new Set(); + + private inst(inst: mir.Inst) { + if (this.instsPrinted.has(inst)) { + return; + } + + // inst is a value + const expr = (tail: string) => { + this.result += ` %${ + this.instId(inst) + } (${inst.ty.pretty()}) = ${inst.kind.tag}${tail}\n`; + }; + // inst is a statement + const stmt = (tail: string) => { + this.result += ` ${inst.kind.tag}${tail}\n`; + }; + + this.instTail(inst, expr, stmt); + + this.instsPrinted.add(inst); + } + + private instArg(inst: mir.Inst): string { + const inline: mir.InstKind["tag"][] = ["Void", "Int", "Bool", "Array"]; + if ( + inline.includes(inst.kind.tag) && + !this.instsPrinted.has(inst) + ) { + let str = `${inst.kind.tag}`; + this.instTail(inst, (s) => { + str += s; + }, () => {}); + return str; + } + this.inst(inst); + return `%${this.instId(inst)}`; + } + + private instTail( + inst: mir.Inst, + expr: (str: string) => void, + stmt: (str: string) => void, + ) { + const r = (inst: mir.Inst) => this.instArg(inst); + + const k = inst.kind; + switch (k.tag) { + case "Error": + return expr(``); + case "Void": + return expr(``); + case "Int": + return expr(` ${k.value}${k.intTy}`); + case "Bool": + return expr(` ${k.value}`); + case "Str": + return expr(` ${JSON.stringify(k.value)}`); + case "Array": + return expr(` [${k.values.map(r).join(", ")}]`); + case "Fn": + return expr(` ${k.fn.stmt.kind.ident}`); + case "Param": + return expr(` ${k.idx}`); + case "GetElemPtr": + return expr(` &(ptr ${r(k.base)})[${r(k.offset)}]`); + case "Slice": + return expr( + ` &[ptr ${r(k.value)}][${k.begin ? r(k.begin) : ""}..${ + k.end ? r(k.end) : "" + }]`, + ); + case "Call": + return expr(` ${r(k.callee)} (${k.args.map(r).join(", ")})`); + case "Alloca": + return expr(``); + case "Load": + return expr(` [ptr ${r(k.source)}]`); + case "Store": + return stmt(` [ptr ${r(k.target)}] = ${r(k.source)}`); + case "Jump": + return stmt(` bb${this.bbId(k.target)}`); + case "Branch": + return stmt( + ` if ${r(k.cond)}: bb${this.bbId(k.truthy)}, else: bb${ + this.bbId(k.falsy) + }`, + ); + case "Return": + return stmt(` ${r(k.source)}`); + case "Not": + case "Negate": + return expr(`${r(k.source)}`); + case "Eq": + case "Ne": + case "Lt": + case "Gt": + case "Lte": + case "Gte": + case "BitOr": + case "BitXor": + case "BitAnd": + case "Shl": + case "Shr": + case "Add": + case "Sub": + case "Mul": + case "Div": + case "Rem": + return expr(` ${r(k.left)}, ${r(k.right)}`); + case "Len": + return expr(` ${r(k.source)}`); + case "DebugPrint": + return stmt(` ${k.args.map(r).join(", ")}`); + } + } + + private bbId(bb: mir.BasicBlock) { + const found = this.bbIds.get(bb); + if (found !== undefined) { + return found; + } + const id = this.bbIds.size; + this.bbIds.set(bb, id); + return id; + } + + private instId(inst: mir.Inst) { + const found = this.instIds.get(inst); + if (found !== undefined) { + return found; + } + const id = this.instIds.size; + this.instIds.set(inst, id); + return id; + } +} diff --git a/src/ty.ts b/src/ty.ts index 9bb5ba9..ac029aa 100644 --- a/src/ty.ts +++ b/src/ty.ts @@ -1,4 +1,5 @@ import * as ast from "./ast.ts"; +import * as stringify from "./stringify.ts"; export class Ty { private static idCounter = 0; @@ -145,45 +146,7 @@ export class Ty { } pretty(): string { - switch (this.kind.tag) { - case "Error": - return ""; - case "Void": - return "void"; - case "IntLiteral": - return "{integer}"; - case "Int": - return `${this.kind.intTy}`; - case "Bool": - return "bool"; - case "Ptr": - return `*${this.kind.ty.pretty()}`; - case "PtrMut": - return `*mut ${this.kind.ty.pretty()}`; - case "Array": - return `[${this.kind.ty.pretty()}; ${this.kind.length}]`; - case "Slice": - return `[${this.kind.ty.pretty()}]`; - case "Range": - return `Range`; - case "Fn": - return `fn (${ - this.kind.params.map((param) => param.pretty()).join(", ") - }) -> ${this.kind.retTy.pretty()}`; - case "FnStmt": - if (!this.kind.ty.is("Fn")) { - throw new Error(); - } - return `fn ${this.kind.stmt.kind.ident}(${ - this.kind.ty.kind.params.map((param) => param.pretty()) - .join( - ", ", - ) - }) -> ${this.kind.ty.kind.retTy.pretty()}`; - default: - this.kind satisfies never; - } - throw new Error("unhandled"); + return stringify.tyPretty(this); } }