optimize ir

This commit is contained in:
sfja 2026-06-04 05:41:34 +02:00
parent ee00599362
commit 54e82d4a8d
2 changed files with 229 additions and 33 deletions

View File

@ -141,13 +141,12 @@ export class Cx {
runSimulation() { runSimulation() {
const comp = this.board.toIr(); const comp = this.board.toIr();
console.log("Before optimizing"); console.log("Before optimizing");
console.log(new ir.ComponentPrinter().stringify(comp)); console.log(...new ir.ComponentPrinter().stringifyToConsole(comp));
const optimizer = new ir.ComponentOptimizer(comp); new ir.ComponentOptimizer(comp).optimize();
optimizer.optimize();
console.log("After optimizing"); console.log("After optimizing");
console.log(new ir.ComponentPrinter().stringify(comp)); console.log(...new ir.ComponentPrinter().stringifyToConsole(comp));
} }
} }

View File

@ -10,6 +10,68 @@ export class Component {
export class Stmt { export class Stmt {
constructor(public kind: StmtKind) {} constructor(public kind: StmtKind) {}
sources(): Stmt[] {
const k = this.kind;
switch (k.tag) {
case "Null":
return [];
case "Input":
return [];
case "Output":
return [k.src];
case "GetState":
return [];
case "SetState":
return [k.src];
case "Not":
return [k.op];
case "And":
case "Or":
return [k.lhs, k.rhs];
case "Component":
return [...k.inputs];
}
}
replaceSource(oldStmt: Stmt, newStmt: Stmt) {
const k = this.kind;
switch (k.tag) {
case "Null":
case "Input":
case "GetState":
break;
case "Output":
case "SetState":
if (k.src === oldStmt) k.src = newStmt;
break;
case "Not":
if (k.op === oldStmt) k.op = newStmt;
break;
case "And":
case "Or":
if (k.lhs === oldStmt) k.lhs = newStmt;
if (k.rhs === oldStmt) k.rhs = newStmt;
break;
case "Component":
k.inputs = k.inputs.map((stmt) =>
stmt === oldStmt ? newStmt : oldStmt,
);
break;
}
}
replaceState(oldState: State, newState: State) {
const k = this.kind;
switch (k.tag) {
case "GetState":
case "SetState":
if (k.state === oldState) k.state = newState;
break;
default:
break;
}
}
} }
export type StmtKind = export type StmtKind =
@ -82,34 +144,35 @@ export class ComponentBuilder {
class StmtsMutater { class StmtsMutater {
constructor(private comp: Component) {} constructor(private comp: Component) {}
replaceStmt(oldStmt: Stmt, newStmt: Stmt) { [Symbol.iterator](): Iterator<Stmt> {
for (const stmt of this.comp.stmts) { return this.comp.stmts[Symbol.iterator]();
const k = stmt.kind;
switch (k.tag) {
case "Null":
case "Input":
case "GetState":
break;
case "Output":
case "SetState":
if (k.src === oldStmt) k.src = newStmt;
break;
case "Not":
if (k.op === oldStmt) k.op = newStmt;
break;
case "And":
case "Or":
if (k.lhs === oldStmt) k.lhs = newStmt;
if (k.rhs === oldStmt) k.rhs = newStmt;
break;
case "Component":
throw new Error("not implemented");
} }
replaceSource(oldStmt: Stmt, newStmt: Stmt) {
for (const stmt of this.comp.stmts) {
stmt.replaceSource(oldStmt, newStmt);
} }
} }
removeStmtAt(i: number) { replaceState(oldState: State, newState: State) {
this.comp.stmts.splice(i, 1); for (const stmt of this.comp.stmts) {
stmt.replaceState(oldState, newState);
}
}
removeStmt(stmt: Stmt) {
this.comp.stmts = this.comp.stmts.filter((s) => s !== stmt);
}
removeStmtAt(i: number): Stmt {
return this.comp.stmts.splice(i, 1)[0];
}
insertStmtAt(i: number, stmt: Stmt) {
this.comp.stmts = [
...this.comp.stmts.slice(0, i),
stmt,
...this.comp.stmts.slice(i),
];
} }
} }
@ -117,13 +180,19 @@ export class ComponentOptimizer {
constructor(private comp: Component) {} constructor(private comp: Component) {}
optimize() { optimize() {
let lengthBefore: number; const score = () => this.comp.stmts.length * 100 + this.comp.states.length;
let scoreBefore: number;
do { do {
lengthBefore = this.comp.stmts.length; scoreBefore = score();
this.eliminateRedundantState(); this.eliminateRedundantState();
this.hoistInputs(); this.hoistInputs();
} while (this.comp.stmts.length !== lengthBefore); this.moveSetStateToSource();
this.collapseStates();
this.eliminateUnusedStates();
this.eliminateRedundantSetState();
} while (score() !== scoreBefore);
} }
eliminateRedundantState() { eliminateRedundantState() {
@ -136,7 +205,7 @@ export class ComponentOptimizer {
case "GetState": { case "GetState": {
const candidate = immediatelyReadStateStmt.get(k.state); const candidate = immediatelyReadStateStmt.get(k.state);
if (candidate) { if (candidate) {
mut.replaceStmt(stmt, candidate); mut.replaceSource(stmt, candidate);
mut.removeStmtAt(i); mut.removeStmtAt(i);
} }
break; break;
@ -155,6 +224,106 @@ export class ComponentOptimizer {
); );
this.comp.stmts = [...inputs, ...notInputs]; this.comp.stmts = [...inputs, ...notInputs];
} }
moveSetStateToSource() {
const mut = new StmtsMutater(this.comp);
for (const [baseIdx, stmt] of this.comp.stmts.entries()) {
const indices = this.indexMap();
const sourceIndices = stmt.sources().map((stmt) => indices.get(stmt)!);
if (sourceIndices.length == 0) {
continue;
}
const lastSourceIndex = sourceIndices.reduce((p, v) => Math.max(p, v));
if (lastSourceIndex >= baseIdx - 1) {
continue;
}
mut.removeStmt(stmt);
mut.insertStmtAt(lastSourceIndex + 1, stmt);
}
}
collapseStates() {
const mut = new StmtsMutater(this.comp);
const sourceStates = new MultiMap<Stmt, State>();
for (const stmt of this.comp.stmts) {
if (stmt.kind.tag !== "SetState") continue;
sourceStates.add(stmt.kind.src, stmt.kind.state);
}
for (const [_stmt, [newState, ...oldStates]] of sourceStates) {
for (const oldState of oldStates) {
mut.replaceState(oldState, newState);
}
}
}
eliminateUnusedStates() {
const mut = new StmtsMutater(this.comp);
const usedStates = new Set<State>();
for (const stmt of mut) {
const k = stmt.kind;
switch (k.tag) {
case "GetState":
case "SetState":
usedStates.add(k.state);
break;
default:
break;
}
}
this.comp.states = this.comp.states.filter((state) =>
usedStates.has(state),
);
}
eliminateRedundantSetState() {
const mut = new StmtsMutater(this.comp);
for (let i = this.comp.stmts.length - 1; i > 0; --i) {
const [first, second] = this.comp.stmts.slice(i - 1, i + 1);
if (
first.kind.tag === "SetState" &&
second.kind.tag === first.kind.tag &&
first.kind.state === second.kind.state &&
first.kind.src === second.kind.src
) {
mut.removeStmt(second);
}
}
}
private indexMap(): Map<Stmt, number> {
return new Map(this.comp.stmts.map((stmt, i) => [stmt, i]));
}
}
class MultiMap<Key, Value> {
private map = new Map<Key, Value[]>();
add(key: Key, ...values: Value[]) {
if (!this.map.has(key)) {
this.map.set(key, []);
}
this.map.get(key)!.push(...values);
}
get(key: Key): Value[] {
return this.map.get(key) ?? [];
}
[Symbol.iterator](): Iterator<[Key, Value[]]> {
return this.map[Symbol.iterator]();
}
} }
export class ComponentPrinter { export class ComponentPrinter {
@ -164,11 +333,39 @@ export class ComponentPrinter {
stringify(comp: Component): string { stringify(comp: Component): string {
return ( return (
`component ${comp.label} ${comp.inputs} ${comp.outputs} {\n` + `component ${comp.label} ${comp.inputs} ${comp.outputs} {\n` +
// ` states [ ${comp.states.map((state) => this.stateId(state)).join(", ")} ]\n` + ` state [ ${comp.states.map((state) => this.stateId(state)).join(", ")} ]\n` +
`${comp.stmts.map((stmt) => ` ${this.stringifyStmt(stmt)}\n`).join("")}}` `${comp.stmts.map((stmt) => ` ${this.stringifyStmt(stmt)}\n`).join("")}}`
); );
} }
stringifyToConsole(comp: Component): string[] {
let fmt = this.stringify(comp)
.replaceAll(/%\d+/g, "\\c(color: cyan)$&\\c")
.replaceAll(/#\d+/g, "\\c(color: lightgreen)$&\\c")
.replaceAll(/ \d+/g, "\\c(color: #bf8bf0)$&\\c")
.replaceAll(
/(?:component)|(?:state)/g,
"\\c(color: #d44949; font-weight: bold)$&\\c",
)
.replaceAll(
/(?:Null)|(?:Input)|(?:Output)|(?:GetState)|(?:SetState)|(?:Not)|(?:And)|(?:Or)|(?:Component)/g,
"\\c(color: orange)$&\\c",
);
const selectors: string[] = [];
let match;
while ((match = fmt.match(/\\c(?:\((.*?)\))?/))) {
fmt = fmt.replace(/\\c(?:\(.*?\))?/, "\r%c");
selectors.push(match[1]);
}
fmt += "%c";
selectors.push("");
return [fmt, ...selectors];
}
private stmtId(stmt: Stmt): string { private stmtId(stmt: Stmt): string {
if (!this.stmtIds.has(stmt)) { if (!this.stmtIds.has(stmt)) {
this.stmtIds.set(stmt, this.stmtIds.size); this.stmtIds.set(stmt, this.stmtIds.size);