From ee00599362f6d4a53ded5dd745f84a212fb4619f Mon Sep 17 00:00:00 2001 From: sfja Date: Wed, 3 Jun 2026 08:32:54 +0200 Subject: [PATCH] lower to ir --- editor/src/editor/Board.ts | 261 +++++++++++++++++++++++++++++++++--- editor/src/editor/Cx.ts | 13 ++ editor/src/editor/ir.ts | 210 +++++++++++++++++++++++++++++ editor/src/editor/states.ts | 4 + 4 files changed, 473 insertions(+), 15 deletions(-) create mode 100644 editor/src/editor/ir.ts diff --git a/editor/src/editor/Board.ts b/editor/src/editor/Board.ts index 63a2cfc..f24b75d 100644 --- a/editor/src/editor/Board.ts +++ b/editor/src/editor/Board.ts @@ -7,6 +7,7 @@ import { v2, V2, } from "./V2"; +import * as ir from "./ir"; export class Board { private components: Component[] = []; @@ -192,6 +193,135 @@ export class Board { ); this.wires = this.wires.filter((wire) => !wire.isSelected(selection)); } + + toIr(): ir.Component { + console.log("Lowering to IR"); + + for (const comp of this.components) { + comp.markedWiresConnected = []; + } + for (const joint of this.joints) { + joint.markedWiresConnected = []; + } + for (const wire of this.wires) { + wire.markConnections(); + } + + const inputs = this.components.filter( + (comp) => comp.kind.label === "input", + ); + const outputs = this.components.filter( + (comp) => comp.kind.label === "output", + ); + + const inputIdcs = new Map(); + for (const [i, input] of inputs.entries()) { + inputIdcs.set(input, i); + } + const outputIdcs = new Map(); + for (const [i, output] of outputs.entries()) { + outputIdcs.set(output, i); + } + + const b = new ir.ComponentBuilder(inputs.length, outputs.length, "main"); + + const wireStates = new Map(); + for (const wire of this.wires) { + wireStates.set(wire, b.makeState()); + } + + const compSet = new Set(); + const jointSet = new Set(); + const wireSet = new Set(); + + const visitor: BoardVisitor = { + visitComponent: (comp) => { + if (compSet.has(comp)) return "break"; + compSet.add(comp); + + const inputStates = new Map(); + for (const [wire, connection] of comp.markedWiresConnected) { + if (connection.tag === "InputPin") { + inputStates.set(connection.i, wireStates.get(wire)!); + } + } + + const inputStmt = (i: number) => { + return inputStates.has(i) + ? b.makeGetState(inputStates.get(i)!) + : b.makeNull(); + }; + + const stmt = (() => { + switch (comp.kind.label) { + case "input": + return b.makeInput(inputIdcs.get(comp)!); + case "output": + return b.makeOutput(outputIdcs.get(comp)!, inputStmt(0)); + case "not": + return b.makeNot(inputStmt(0)); + case "and": + return b.makeBinary("And", inputStmt(0), inputStmt(1)); + case "or": + return b.makeBinary("Or", inputStmt(0), inputStmt(1)); + default: + throw new Error("not implemented"); + } + })(); + + for (const [wire, connection] of comp.markedWiresConnected) { + if (connection.tag === "OutputPin") { + b.makeSetState(wireStates.get(wire)!, stmt); + } + } + }, + visitJoint: (joint) => { + if (jointSet.has(joint)) return "break"; + jointSet.add(joint); + + const visited = joint.markedWiresConnected.filter(([wire]) => + wireSet.has(wire), + ); + if (visited.length > 1) { + throw new Error("joint has more than 1 input"); + } + + const notVisited = joint.markedWiresConnected.filter( + ([wire]) => !wireSet.has(wire), + ); + + const sourceState = wireStates.get(visited[0][0]); + if (!sourceState) { + throw new Error("assert"); + } + const src = b.makeGetState(sourceState); + + for (const [wire] of notVisited) { + const dst = wireStates.get(wire); + if (!dst) { + throw new Error("assert"); + } + b.makeSetState(dst, src); + } + }, + visitWire: (wire) => { + if (wireSet.has(wire)) return "break"; + wireSet.add(wire); + }, + }; + + for (const comp of this.components) { + comp.visitForward(visitor); + } + + return b.build(); + } +} + +export interface BoardVisitor { + visitComponent(comp: Component): void | "break"; + visitJoint(joint: Joint): void | "break"; + visitWire(wire: Wire): void | "break"; } export class ComponentRepo { @@ -221,6 +351,8 @@ export class ComponentRepo { } export class Component { + public markedWiresConnected: [Wire, WireConnection][] = []; + constructor( public kind: ComponentKind, public pos: V2, @@ -264,6 +396,17 @@ export class Component { outputPinPos(i: number): V2 { return this.pos.add(v2(this.kind.size.x, this.kind.outputPinOffsets()[i])); } + + visitForward(visitor: BoardVisitor) { + if (visitor.visitComponent(this) === "break") return; + for (const [wire, connection] of this.markedWiresConnected) { + switch (connection.tag) { + case "OutputPin": + wire.visitForward(visitor, connection); + break; + } + } + } } type ComponentMouseOverResult = @@ -291,11 +434,23 @@ export class ComponentKind { } export class Joint { + public markedWiresConnected: [Wire, WireConnection][] = []; + constructor(public pos: V2) {} isMouseOver(pos: V2): boolean { return this.pos.distance(pos) < 6; } + + visitForward(visitor: BoardVisitor, entryWire: Wire) { + if (visitor.visitJoint(this) === "break") return; + for (const [wire, connection] of this.markedWiresConnected) { + if (wire === entryWire) { + continue; + } + wire.visitForward(visitor, connection); + } + } } export class Wire { @@ -304,6 +459,26 @@ export class Wire { private end: WireConnection, ) {} + isInput(): boolean { + return this.mapConns((connection) => connection.tag === "InputPin").some( + (v) => v, + ); + } + + markConnections() { + this.mapConns((connection) => { + switch (connection.tag) { + case "InputPin": + case "OutputPin": + connection.comp.markedWiresConnected.push([this, connection]); + break; + case "Joint": + connection.joint.markedWiresConnected.push([this, connection]); + break; + } + }); + } + isMouseOver(pos: V2): boolean { const distance = lineSegmentPointDistance( this.beginPos(), @@ -314,23 +489,66 @@ export class Wire { } isSelected(selection: Selection): boolean { - return ( - this.connectionIsSelected(this.begin, selection) || - this.connectionIsSelected(this.end, selection) - ); + return this.mapConns((connection) => { + switch (connection.tag) { + case "InputPin": + case "OutputPin": + return selection.isComponentSelected(connection.comp); + case "Joint": + return selection.isJointSelected(connection.joint); + } + }).some((v) => v); } - private connectionIsSelected( - connection: WireConnection, - selection: Selection, - ): boolean { - switch (connection.tag) { - case "InputPin": - case "OutputPin": - return selection.isComponentSelected(connection.comp); - case "Joint": - return selection.isJointSelected(connection.joint); - } + connectedToComponent(comp: Component): boolean { + return this.mapConns((connection) => { + switch (connection.tag) { + case "InputPin": + case "OutputPin": + return connection.comp === comp; + case "Joint": + return false; + } + }).some((v) => v); + } + connectedToJoint(joint: Joint): boolean { + return this.mapConns((connection) => { + switch (connection.tag) { + case "InputPin": + case "OutputPin": + return false; + case "Joint": + return connection.joint === joint; + } + }).some((v) => v); + } + + connectedComponents(): Component[] { + return this.mapConns((connection) => { + switch (connection.tag) { + case "InputPin": + case "OutputPin": + return [connection.comp]; + case "Joint": + return []; + } + }).flat(); + } + + connectedJoints(): Joint[] { + return this.mapConns((connection) => { + switch (connection.tag) { + case "InputPin": + case "OutputPin": + return []; + case "Joint": + return [connection.joint]; + } + }).flat(); + } + + private mapConns(mapper: (connection: WireConnection) => R): [R, R] { + return [mapper(this.begin), mapper(this.end)]; } beginPos(): V2 { @@ -351,6 +569,19 @@ export class Wire { return connection.joint.pos; } } + + visitForward(visitor: BoardVisitor, prev: WireConnection) { + if (visitor.visitWire(this) === "break") return; + const connection = this.begin === prev ? this.end : this.begin; + switch (connection.tag) { + case "InputPin": + connection.comp.visitForward(visitor); + break; + case "Joint": + connection.joint.visitForward(visitor, this); + break; + } + } } export type WireConnection = diff --git a/editor/src/editor/Cx.ts b/editor/src/editor/Cx.ts index 33263d8..1b4d878 100644 --- a/editor/src/editor/Cx.ts +++ b/editor/src/editor/Cx.ts @@ -8,6 +8,7 @@ import { import { Renderer } from "./Renderer"; import * as states from "./states"; import { v2, V2 } from "./V2"; +import * as ir from "./ir"; export type Tool = string; @@ -136,6 +137,18 @@ export class Cx { const absY = pos.y - this.offset.y; return v2(absX, absY); } + + runSimulation() { + const comp = this.board.toIr(); + console.log("Before optimizing"); + console.log(new ir.ComponentPrinter().stringify(comp)); + + const optimizer = new ir.ComponentOptimizer(comp); + optimizer.optimize(); + + console.log("After optimizing"); + console.log(new ir.ComponentPrinter().stringify(comp)); + } } export class SelectionBox { diff --git a/editor/src/editor/ir.ts b/editor/src/editor/ir.ts new file mode 100644 index 0000000..6c62ddf --- /dev/null +++ b/editor/src/editor/ir.ts @@ -0,0 +1,210 @@ +export class Component { + constructor( + public stmts: Stmt[], + public states: State[], + public inputs: number, + public outputs: number, + public label: string, + ) {} +} + +export class Stmt { + constructor(public kind: StmtKind) {} +} + +export type StmtKind = + | { tag: "Null" } + | { tag: "Input"; i: number } + | { tag: "Output"; i: number; src: Stmt } + | { tag: "GetState"; state: State } + | { tag: "SetState"; state: State; src: Stmt } + | { tag: "Not"; op: Stmt } + | { tag: "And" | "Or"; lhs: Stmt; rhs: Stmt } + | { tag: "Component"; comp: Component; inputs: Stmt[]; outputs: Stmt[] }; + +export class State {} + +export class ComponentBuilder { + private stmts: Stmt[] = []; + private states: State[] = []; + + constructor( + private inputs: number, + private outputs: number, + private label: string, + ) {} + + makeState(): State { + const state = new State(); + this.states.push(state); + return state; + } + + makeNull(): Stmt { + return this.makeStmt({ tag: "Null" }); + } + makeInput(i: number): Stmt { + return this.makeStmt({ tag: "Input", i }); + } + makeOutput(i: number, src: Stmt): Stmt { + return this.makeStmt({ tag: "Output", i, src }); + } + makeGetState(state: State): Stmt { + return this.makeStmt({ tag: "GetState", state }); + } + makeSetState(state: State, src: Stmt): Stmt { + return this.makeStmt({ tag: "SetState", state, src }); + } + makeNot(op: Stmt): Stmt { + return this.makeStmt({ tag: "Not", op }); + } + makeBinary(tag: "And" | "Or", lhs: Stmt, rhs: Stmt): Stmt { + return this.makeStmt({ tag, lhs, rhs }); + } + + private makeStmt(kind: StmtKind): Stmt { + const stmt = new Stmt(kind); + this.stmts.push(stmt); + return stmt; + } + + build(): Component { + return new Component( + this.stmts, + this.states, + this.inputs, + this.outputs, + this.label, + ); + } +} + +class StmtsMutater { + constructor(private comp: Component) {} + + replaceStmt(oldStmt: Stmt, newStmt: Stmt) { + for (const stmt of this.comp.stmts) { + const k = stmt.kind; + switch (k.tag) { + case "Null": + case "Input": + case "GetState": + break; + case "Output": + case "SetState": + if (k.src === oldStmt) k.src = newStmt; + break; + case "Not": + if (k.op === oldStmt) k.op = newStmt; + break; + case "And": + case "Or": + if (k.lhs === oldStmt) k.lhs = newStmt; + if (k.rhs === oldStmt) k.rhs = newStmt; + break; + case "Component": + throw new Error("not implemented"); + } + } + } + + removeStmtAt(i: number) { + this.comp.stmts.splice(i, 1); + } +} + +export class ComponentOptimizer { + constructor(private comp: Component) {} + + optimize() { + let lengthBefore: number; + do { + lengthBefore = this.comp.stmts.length; + + this.eliminateRedundantState(); + this.hoistInputs(); + } while (this.comp.stmts.length !== lengthBefore); + } + + eliminateRedundantState() { + const mut = new StmtsMutater(this.comp); + const immediatelyReadStateStmt = new Map(); + + for (const [i, stmt] of this.comp.stmts.entries()) { + const k = stmt.kind; + switch (k.tag) { + case "GetState": { + const candidate = immediatelyReadStateStmt.get(k.state); + if (candidate) { + mut.replaceStmt(stmt, candidate); + mut.removeStmtAt(i); + } + break; + } + case "SetState": + immediatelyReadStateStmt.set(k.state, k.src); + break; + } + } + } + + hoistInputs() { + const inputs = this.comp.stmts.filter((stmt) => stmt.kind.tag === "Input"); + const notInputs = this.comp.stmts.filter( + (stmt) => stmt.kind.tag !== "Input", + ); + this.comp.stmts = [...inputs, ...notInputs]; + } +} + +export class ComponentPrinter { + private stmtIds = new Map(); + private stateIds = new Map(); + + stringify(comp: Component): string { + return ( + `component ${comp.label} ${comp.inputs} ${comp.outputs} {\n` + + // ` states [ ${comp.states.map((state) => this.stateId(state)).join(", ")} ]\n` + + `${comp.stmts.map((stmt) => ` ${this.stringifyStmt(stmt)}\n`).join("")}}` + ); + } + + private stmtId(stmt: Stmt): string { + if (!this.stmtIds.has(stmt)) { + this.stmtIds.set(stmt, this.stmtIds.size); + } + return `%${this.stmtIds.get(stmt)!}`; + } + private stateId(state: State): string { + if (!this.stateIds.has(state)) { + this.stateIds.set(state, this.stateIds.size); + } + return `#${this.stateIds.get(state)!}`; + } + + private stringifyStmt(stmt: Stmt) { + const stmtId = (stmt: Stmt) => this.stmtId(stmt); + const stateId = (state: State) => this.stateId(state); + + const k = stmt.kind; + switch (k.tag) { + case "Null": + return `${stmtId(stmt)} = Null`; + case "Input": + return `${stmtId(stmt)} = Input ${k.i}`; + case "Output": + return `Output ${k.i}, ${stmtId(k.src)}`; + case "GetState": + return `${stmtId(stmt)} = GetState ${stateId(k.state)}`; + case "SetState": + return `SetState ${stateId(k.state)}, ${stmtId(k.src)}`; + case "Not": + return `${stmtId(stmt)} = Not ${stmtId(k.op)}`; + case "And": + case "Or": + return `${stmtId(stmt)} = ${k.tag} ${stmtId(k.lhs)}, ${stmtId(k.rhs)}`; + case "Component": + return `Component <...>`; + } + } +} diff --git a/editor/src/editor/states.ts b/editor/src/editor/states.ts index f5fcc22..c5867f2 100644 --- a/editor/src/editor/states.ts +++ b/editor/src/editor/states.ts @@ -20,6 +20,10 @@ export class Normal implements State { constructor(private cx: Cx) {} + enterState(): void { + this.cx.runSimulation(); + } + onMouseDown(pos: V2): void { if ( this.cx.board.handleMouseClick(pos.sub(this.cx.offset), {