add visitor, apply on resolver

This commit is contained in:
sfja 2024-12-13 23:22:08 +01:00
parent a88c502529
commit b6c7f09f8b
4 changed files with 295 additions and 187 deletions

View File

@ -119,33 +119,3 @@ export type Anno = {
values: Expr[]; values: Expr[];
pos: Pos; pos: Pos;
}; };
export function stmtToString(stmt: Stmt): string {
const body = (() => {
switch (stmt.kind.type) {
case "assign":
return `{ subject: ${exprToString(stmt.kind.subject)}, value: ${
exprToString(stmt.kind.value)
} }`;
}
return "(<not implemented>)";
})();
const { line } = stmt.pos;
return `${stmt.kind.type}:${line}${body}`;
}
export function exprToString(expr: Expr): string {
const body = (() => {
switch (expr.kind.type) {
case "binary":
return `(${
exprToString(expr.kind.left)
} ${expr.kind.binaryType} ${exprToString(expr.kind.right)})`;
case "sym":
return `(${expr.kind.ident})`;
}
return "(<not implemented>)";
})();
const { line } = expr.pos;
return `${expr.kind.type}:${line}${body}`;
}

232
compiler/ast_visitor.ts Normal file
View File

@ -0,0 +1,232 @@
import { EType, Expr, Param, Stmt } from "./ast.ts";
export type VisitRes = "stop" | void;
export interface AstVisitor<Args extends unknown[] = []> {
visitStmts?(stmts: Stmt[], ...args: Args): VisitRes;
visitStmt?(stmt: Stmt, ...args: Args): VisitRes;
visitErrorStmt?(stmt: Stmt, ...args: Args): VisitRes;
visitImportStmt?(stmt: Stmt, ...args: Args): VisitRes;
visitBreakStmt?(stmt: Stmt, ...args: Args): VisitRes;
visitReturnStmt?(stmt: Stmt, ...args: Args): VisitRes;
visitFnStmt?(stmt: Stmt, ...args: Args): VisitRes;
visitLetStmt?(stmt: Stmt, ...args: Args): VisitRes;
visitAssignStmt?(stmt: Stmt, ...args: Args): VisitRes;
visitExprStmt?(stmt: Stmt, ...args: Args): VisitRes;
visitExpr?(expr: Expr, ...args: Args): VisitRes;
visitErrorExpr?(expr: Expr, ...args: Args): VisitRes;
visitIntExpr?(expr: Expr, ...args: Args): VisitRes;
visitStringExpr?(expr: Expr, ...args: Args): VisitRes;
visitIdentExpr?(expr: Expr, ...args: Args): VisitRes;
visitGroupExpr?(expr: Expr, ...args: Args): VisitRes;
visitFieldExpr?(expr: Expr, ...args: Args): VisitRes;
visitIndexExpr?(expr: Expr, ...args: Args): VisitRes;
visitCallExpr?(expr: Expr, ...args: Args): VisitRes;
visitUnaryExpr?(expr: Expr, ...args: Args): VisitRes;
visitBinaryExpr?(expr: Expr, ...args: Args): VisitRes;
visitIfExpr?(expr: Expr, ...args: Args): VisitRes;
visitBoolExpr?(expr: Expr, ...args: Args): VisitRes;
visitNullExpr?(expr: Expr, ...args: Args): VisitRes;
visitLoopExpr?(expr: Expr, ...args: Args): VisitRes;
visitBlockExpr?(expr: Expr, ...args: Args): VisitRes;
visitSymExpr?(expr: Expr, ...args: Args): VisitRes;
visitParam?(param: Param, ...args: Args): VisitRes;
visitEType?(etype: EType, ...args: Args): VisitRes;
visitErrorEType?(etype: EType, ...args: Args): VisitRes;
visitIdentEType?(etype: EType, ...args: Args): VisitRes;
visitArrayEType?(etype: EType, ...args: Args): VisitRes;
visitStructEType?(etype: EType, ...args: Args): VisitRes;
visitAnno?(etype: EType, ...args: Args): VisitRes;
}
export function visitStmts<Args extends unknown[] = []>(
stmts: Stmt[],
v: AstVisitor<Args>,
...args: Args
) {
if (v.visitStmts?.(stmts, ...args) === "stop") return;
stmts.map((stmt) => visitStmt(stmt, v, ...args));
}
export function visitStmt<Args extends unknown[] = []>(
stmt: Stmt,
v: AstVisitor<Args>,
...args: Args
) {
if (v.visitStmt?.(stmt, ...args) == "stop") return;
switch (stmt.kind.type) {
case "error":
if (v.visitErrorStmt?.(stmt, ...args) == "stop") return;
break;
case "import":
if (v.visitImportStmt?.(stmt, ...args) == "stop") return;
visitExpr(stmt.kind.path, v, ...args);
break;
case "break":
if (v.visitBreakStmt?.(stmt, ...args) == "stop") return;
if (stmt.kind.expr) visitExpr(stmt.kind.expr, v, ...args);
break;
case "return":
if (v.visitReturnStmt?.(stmt, ...args) == "stop") return;
if (stmt.kind.expr) visitExpr(stmt.kind.expr, v, ...args);
break;
case "fn":
if (v.visitFnStmt?.(stmt, ...args) == "stop") return;
stmt.kind.params.map((param) => visitParam(param, v, ...args));
if (stmt.kind.returnType) {
visitEType(stmt.kind.returnType, v, ...args);
}
visitExpr(stmt.kind.body, v, ...args);
break;
case "let":
if (v.visitLetStmt?.(stmt, ...args) == "stop") return;
visitParam(stmt.kind.param, v, ...args);
visitExpr(stmt.kind.value, v, ...args);
break;
case "assign":
if (v.visitAssignStmt?.(stmt, ...args) == "stop") return;
visitExpr(stmt.kind.subject, v, ...args);
visitExpr(stmt.kind.value, v, ...args);
break;
case "expr":
if (v.visitExprStmt?.(stmt, ...args) == "stop") return;
visitExpr(stmt.kind.expr, v, ...args);
break;
}
}
export function visitExpr<Args extends unknown[] = []>(
expr: Expr,
v: AstVisitor<Args>,
...args: Args
) {
if (v.visitExpr?.(expr, ...args) == "stop") return;
switch (expr.kind.type) {
case "error":
if (v.visitErrorExpr?.(expr, ...args) == "stop") return;
break;
case "string":
if (v.visitStringExpr?.(expr, ...args) == "stop") return;
break;
case "int":
if (v.visitIntExpr?.(expr, ...args) == "stop") return;
break;
case "ident":
if (v.visitIdentExpr?.(expr, ...args) == "stop") return;
break;
case "group":
if (v.visitGroupExpr?.(expr, ...args) == "stop") return;
visitExpr(expr.kind.expr, v, ...args);
break;
case "field":
if (v.visitFieldExpr?.(expr, ...args) == "stop") return;
visitExpr(expr.kind.subject, v, ...args);
break;
case "index":
if (v.visitIndexExpr?.(expr, ...args) == "stop") return;
visitExpr(expr.kind.subject, v, ...args);
visitExpr(expr.kind.value, v, ...args);
break;
case "call":
if (v.visitCallExpr?.(expr, ...args) == "stop") return;
visitExpr(expr.kind.subject, v, ...args);
expr.kind.args.map((arg) => visitExpr(arg, v, ...args));
break;
case "unary":
if (v.visitUnaryExpr?.(expr, ...args) == "stop") return;
visitExpr(expr.kind.subject, v, ...args);
break;
case "binary":
if (v.visitBinaryExpr?.(expr, ...args) == "stop") return;
visitExpr(expr.kind.left, v, ...args);
visitExpr(expr.kind.right, v, ...args);
break;
case "if":
if (v.visitIfExpr?.(expr, ...args) == "stop") return;
visitExpr(expr.kind.cond, v, ...args);
visitExpr(expr.kind.truthy, v, ...args);
if (expr.kind.falsy) visitExpr(expr.kind.falsy, v, ...args);
break;
case "bool":
if (v.visitBoolExpr?.(expr, ...args) == "stop") return;
break;
case "null":
if (v.visitNullExpr?.(expr, ...args) == "stop") return;
break;
case "loop":
if (v.visitLoopExpr?.(expr, ...args) == "stop") return;
visitExpr(expr.kind.body, v, ...args);
break;
case "block":
if (v.visitBlockExpr?.(expr, ...args) == "stop") return;
expr.kind.stmts.map((stmt) => visitStmt(stmt, v, ...args));
if (expr.kind.expr) visitExpr(expr.kind.expr, v, ...args);
break;
case "sym":
if (v.visitSymExpr?.(expr, ...args) == "stop") return;
break;
}
}
export function visitParam<Args extends unknown[] = []>(
param: Param,
v: AstVisitor<Args>,
...args: Args
) {
if (v.visitParam?.(param, ...args) == "stop") return;
if (param.etype) visitEType(param.etype, v, ...args);
}
export function visitEType<Args extends unknown[] = []>(
etype: EType,
v: AstVisitor<Args>,
...args: Args
) {
if (v.visitEType?.(etype, ...args) == "stop") return;
switch (etype.kind.type) {
case "error":
if (v.visitErrorEType?.(etype, ...args) == "stop") return;
break;
case "ident":
if (v.visitIdentEType?.(etype, ...args) == "stop") return;
break;
case "array":
if (v.visitArrayEType?.(etype, ...args) == "stop") return;
if (etype.kind.inner) visitEType(etype.kind.inner, v, ...args);
break;
case "struct":
if (v.visitStructEType?.(etype, ...args) == "stop") return;
etype.kind.fields.map((field) => visitParam(field, v, ...args));
break;
}
}
export function stmtToString(stmt: Stmt): string {
const body = (() => {
switch (stmt.kind.type) {
case "assign":
return `{ subject: ${exprToString(stmt.kind.subject)}, value: ${
exprToString(stmt.kind.value)
} }`;
}
return "(<not implemented>)";
})();
const { line } = stmt.pos;
return `${stmt.kind.type}:${line}${body}`;
}
export function exprToString(expr: Expr): string {
const body = (() => {
switch (expr.kind.type) {
case "binary":
return `(${
exprToString(expr.kind.left)
} ${expr.kind.binaryType} ${exprToString(expr.kind.right)})`;
case "sym":
return `(${expr.kind.ident})`;
}
return "(<not implemented>)";
})();
const { line } = expr.pos;
return `${expr.kind.type}:${line}${body}`;
}

View File

@ -1,5 +1,5 @@
import { Builtins } from "./arch.ts"; import { Builtins } from "./arch.ts";
import { Expr, Stmt, stmtToString } from "./ast.ts"; import { Expr, Stmt } from "./ast.ts";
import { LocalLeaf, Locals, LocalsFnRoot } from "./lowerer_locals.ts"; import { LocalLeaf, Locals, LocalsFnRoot } from "./lowerer_locals.ts";
import { Ops } from "./mod.ts"; import { Ops } from "./mod.ts";
import { Assembler, Label } from "./assembler.ts"; import { Assembler, Label } from "./assembler.ts";
@ -224,7 +224,7 @@ export class Lowerer {
case "field": case "field":
break; break;
case "index": case "index":
return this.lowerIndexExpr(expr); return this.lowerIndexExpr(expr);
case "call": case "call":
return this.lowerCallExpr(expr); return this.lowerCallExpr(expr);
case "unary": case "unary":
@ -245,8 +245,8 @@ export class Lowerer {
if (expr.kind.type !== "index") { if (expr.kind.type !== "index") {
throw new Error(); throw new Error();
} }
this.lowerExpr(expr.kind.subject) this.lowerExpr(expr.kind.subject);
this.lowerExpr(expr.kind.value) this.lowerExpr(expr.kind.value);
if (expr.kind.subject.vtype?.type == "array") { if (expr.kind.subject.vtype?.type == "array") {
this.program.add(Ops.Builtin, Builtins.ArrayAt); this.program.add(Ops.Builtin, Builtins.ArrayAt);

View File

@ -1,5 +1,5 @@
import { Builtins } from "./arch.ts";
import { Expr, Stmt } from "./ast.ts"; import { Expr, Stmt } from "./ast.ts";
import { AstVisitor, visitExpr, VisitRes, visitStmts } from "./ast_visitor.ts";
import { printStackTrace, Reporter } from "./info.ts"; import { printStackTrace, Reporter } from "./info.ts";
import { import {
FnSyms, FnSyms,
@ -10,23 +10,37 @@ import {
} from "./resolver_syms.ts"; } from "./resolver_syms.ts";
import { Pos } from "./token.ts"; import { Pos } from "./token.ts";
export class Resolver { export class Resolver implements AstVisitor<[Syms]> {
private root = new GlobalSyms(); private root = new GlobalSyms();
public constructor(private reporter: Reporter) { public constructor(private reporter: Reporter) {
this.root.define("print", {
type: "builtin",
ident: "print",
builtinId: Builtins.Print,
});
} }
public resolve(stmts: Stmt[]) { public resolve(stmts: Stmt[]): VisitRes {
const scopeSyms = new StaticSyms(this.root); const scopeSyms = new StaticSyms(this.root);
this.scoutFnStmts(stmts, scopeSyms); this.scoutFnStmts(stmts, scopeSyms);
for (const stmt of stmts) { visitStmts(stmts, this, scopeSyms);
this.resolveStmt(stmt, scopeSyms); return "stop";
}
visitLetStmt(stmt: Stmt, syms: Syms): VisitRes {
if (stmt.kind.type !== "let") {
throw new Error("expected let statement");
} }
visitExpr(stmt.kind.value, this, syms);
const ident = stmt.kind.param.ident;
if (syms.definedLocally(ident)) {
this.reportAlreadyDefined(ident, stmt.pos, syms);
return;
}
syms.define(ident, {
ident,
type: "let",
pos: stmt.kind.param.pos,
stmt,
param: stmt.kind.param,
});
return "stop";
} }
private scoutFnStmts(stmts: Stmt[], syms: Syms) { private scoutFnStmts(stmts: Stmt[], syms: Syms) {
@ -48,148 +62,7 @@ export class Resolver {
} }
} }
private resolveExpr(expr: Expr, syms: Syms) { visitFnStmt(stmt: Stmt, syms: Syms): VisitRes {
if (expr.kind.type === "error") {
return;
}
if (expr.kind.type === "ident") {
this.resolveIdentExpr(expr, syms);
return;
}
if (expr.kind.type === "binary") {
this.resolveExpr(expr.kind.left, syms);
this.resolveExpr(expr.kind.right, syms);
return;
}
if (expr.kind.type === "block") {
const childSyms = new LeafSyms(syms);
this.scoutFnStmts(expr.kind.stmts, childSyms);
for (const stmt of expr.kind.stmts) {
this.resolveStmt(stmt, childSyms);
}
if (expr.kind.expr) {
this.resolveExpr(expr.kind.expr, childSyms);
}
return;
}
if (expr.kind.type === "group") {
this.resolveExpr(expr.kind.expr, syms);
return;
}
if (expr.kind.type === "field") {
this.resolveExpr(expr.kind.subject, syms);
return;
}
if (expr.kind.type === "index") {
this.resolveExpr(expr.kind.subject, syms);
this.resolveExpr(expr.kind.value, syms);
return;
}
if (expr.kind.type === "call") {
this.resolveExpr(expr.kind.subject, syms);
for (const e of expr.kind.args) {
this.resolveExpr(e, syms);
}
return;
}
if (expr.kind.type === "unary") {
this.resolveExpr(expr.kind.subject, syms);
return;
}
if (expr.kind.type === "if") {
this.resolveExpr(expr.kind.cond, syms);
this.resolveExpr(expr.kind.truthy, syms);
if (expr.kind.falsy !== undefined) {
this.resolveExpr(expr.kind.falsy, syms);
}
return;
}
if (expr.kind.type === "loop") {
this.resolveExpr(expr.kind.body, syms);
return;
}
if (
expr.kind.type === "int" || expr.kind.type === "bool" ||
expr.kind.type === "null" || expr.kind.type === "string" ||
expr.kind.type === "sym"
) {
return;
}
}
private resolveIdentExpr(expr: Expr, syms: Syms) {
if (expr.kind.type !== "ident") {
throw new Error("expected ident");
}
const ident = expr.kind;
const symResult = syms.get(ident.value);
if (!symResult.ok) {
this.reportUseOfUndefined(ident.value, expr.pos, syms);
return;
}
const sym = symResult.sym;
expr.kind = {
type: "sym",
ident: ident.value,
sym,
};
}
private resolveStmt(stmt: Stmt, syms: Syms) {
if (stmt.kind.type === "error") {
return;
}
if (stmt.kind.type === "let") {
this.resolveLetStmt(stmt, syms);
return;
}
if (stmt.kind.type === "fn") {
this.resolveFnStmt(stmt, syms);
return;
}
if (stmt.kind.type === "return") {
if (stmt.kind.expr) {
this.resolveExpr(stmt.kind.expr, syms);
}
return;
}
if (stmt.kind.type === "break") {
if (stmt.kind.expr !== undefined) {
this.resolveExpr(stmt.kind.expr, syms);
}
return;
}
if (stmt.kind.type === "assign") {
this.resolveExpr(stmt.kind.subject, syms);
this.resolveExpr(stmt.kind.value, syms);
return;
}
if (stmt.kind.type === "expr") {
this.resolveExpr(stmt.kind.expr, syms);
return;
}
}
private resolveLetStmt(stmt: Stmt, syms: Syms) {
if (stmt.kind.type !== "let") {
throw new Error("expected let statement");
}
this.resolveExpr(stmt.kind.value, syms);
const ident = stmt.kind.param.ident;
if (syms.definedLocally(ident)) {
this.reportAlreadyDefined(ident, stmt.pos, syms);
return;
}
syms.define(ident, {
ident,
type: "let",
pos: stmt.kind.param.pos,
stmt,
param: stmt.kind.param,
});
}
private resolveFnStmt(stmt: Stmt, syms: Syms) {
if (stmt.kind.type !== "fn") { if (stmt.kind.type !== "fn") {
throw new Error("expected fn statement"); throw new Error("expected fn statement");
} }
@ -206,7 +79,40 @@ export class Resolver {
param, param,
}); });
} }
this.resolveExpr(stmt.kind.body, fnScopeSyms); visitExpr(stmt.kind.body, this, fnScopeSyms);
return "stop";
}
visitIdentExpr(expr: Expr, syms: Syms): VisitRes {
if (expr.kind.type !== "ident") {
throw new Error("expected ident");
}
const ident = expr.kind;
const symResult = syms.get(ident.value);
if (!symResult.ok) {
this.reportUseOfUndefined(ident.value, expr.pos, syms);
return;
}
const sym = symResult.sym;
expr.kind = {
type: "sym",
ident: ident.value,
sym,
};
return "stop";
}
visitBlockExpr(expr: Expr, syms: Syms): VisitRes {
if (expr.kind.type !== "block") {
throw new Error();
}
const childSyms = new LeafSyms(syms);
this.scoutFnStmts(expr.kind.stmts, childSyms);
visitStmts(expr.kind.stmts, this, childSyms);
if (expr.kind.expr) {
visitExpr(expr.kind.expr, this, childSyms);
}
return "stop";
} }
private reportUseOfUndefined(ident: string, pos: Pos, _syms: Syms) { private reportUseOfUndefined(ident: string, pos: Pos, _syms: Syms) {