From b6c7f09f8bdf4a3ce466fab82937902acd92c6bd Mon Sep 17 00:00:00 2001 From: sfja Date: Fri, 13 Dec 2024 23:22:08 +0100 Subject: [PATCH] add visitor, apply on resolver --- compiler/ast.ts | 30 ------ compiler/ast_visitor.ts | 232 ++++++++++++++++++++++++++++++++++++++++ compiler/lowerer.ts | 8 +- compiler/resolver.ts | 212 ++++++++++-------------------------- 4 files changed, 295 insertions(+), 187 deletions(-) create mode 100644 compiler/ast_visitor.ts diff --git a/compiler/ast.ts b/compiler/ast.ts index 4a88117..38bfe0a 100644 --- a/compiler/ast.ts +++ b/compiler/ast.ts @@ -119,33 +119,3 @@ export type Anno = { values: Expr[]; pos: Pos; }; - -export function stmtToString(stmt: Stmt): string { - const body = (() => { - switch (stmt.kind.type) { - case "assign": - return `{ subject: ${exprToString(stmt.kind.subject)}, value: ${ - exprToString(stmt.kind.value) - } }`; - } - return "()"; - })(); - const { line } = stmt.pos; - return `${stmt.kind.type}:${line}${body}`; -} - -export function exprToString(expr: Expr): string { - const body = (() => { - switch (expr.kind.type) { - case "binary": - return `(${ - exprToString(expr.kind.left) - } ${expr.kind.binaryType} ${exprToString(expr.kind.right)})`; - case "sym": - return `(${expr.kind.ident})`; - } - return "()"; - })(); - const { line } = expr.pos; - return `${expr.kind.type}:${line}${body}`; -} diff --git a/compiler/ast_visitor.ts b/compiler/ast_visitor.ts new file mode 100644 index 0000000..25e85b1 --- /dev/null +++ b/compiler/ast_visitor.ts @@ -0,0 +1,232 @@ +import { EType, Expr, Param, Stmt } from "./ast.ts"; + +export type VisitRes = "stop" | void; + +export interface AstVisitor { + visitStmts?(stmts: Stmt[], ...args: Args): VisitRes; + visitStmt?(stmt: Stmt, ...args: Args): VisitRes; + visitErrorStmt?(stmt: Stmt, ...args: Args): VisitRes; + visitImportStmt?(stmt: Stmt, ...args: Args): VisitRes; + visitBreakStmt?(stmt: Stmt, ...args: Args): VisitRes; + visitReturnStmt?(stmt: Stmt, ...args: Args): VisitRes; + visitFnStmt?(stmt: Stmt, ...args: Args): VisitRes; + visitLetStmt?(stmt: Stmt, ...args: Args): VisitRes; + visitAssignStmt?(stmt: Stmt, ...args: Args): VisitRes; + visitExprStmt?(stmt: Stmt, ...args: Args): VisitRes; + visitExpr?(expr: Expr, ...args: Args): VisitRes; + visitErrorExpr?(expr: Expr, ...args: Args): VisitRes; + visitIntExpr?(expr: Expr, ...args: Args): VisitRes; + visitStringExpr?(expr: Expr, ...args: Args): VisitRes; + visitIdentExpr?(expr: Expr, ...args: Args): VisitRes; + visitGroupExpr?(expr: Expr, ...args: Args): VisitRes; + visitFieldExpr?(expr: Expr, ...args: Args): VisitRes; + visitIndexExpr?(expr: Expr, ...args: Args): VisitRes; + visitCallExpr?(expr: Expr, ...args: Args): VisitRes; + visitUnaryExpr?(expr: Expr, ...args: Args): VisitRes; + visitBinaryExpr?(expr: Expr, ...args: Args): VisitRes; + visitIfExpr?(expr: Expr, ...args: Args): VisitRes; + visitBoolExpr?(expr: Expr, ...args: Args): VisitRes; + visitNullExpr?(expr: Expr, ...args: Args): VisitRes; + visitLoopExpr?(expr: Expr, ...args: Args): VisitRes; + visitBlockExpr?(expr: Expr, ...args: Args): VisitRes; + visitSymExpr?(expr: Expr, ...args: Args): VisitRes; + visitParam?(param: Param, ...args: Args): VisitRes; + visitEType?(etype: EType, ...args: Args): VisitRes; + visitErrorEType?(etype: EType, ...args: Args): VisitRes; + visitIdentEType?(etype: EType, ...args: Args): VisitRes; + visitArrayEType?(etype: EType, ...args: Args): VisitRes; + visitStructEType?(etype: EType, ...args: Args): VisitRes; + visitAnno?(etype: EType, ...args: Args): VisitRes; +} + +export function visitStmts( + stmts: Stmt[], + v: AstVisitor, + ...args: Args +) { + if (v.visitStmts?.(stmts, ...args) === "stop") return; + stmts.map((stmt) => visitStmt(stmt, v, ...args)); +} + +export function visitStmt( + stmt: Stmt, + v: AstVisitor, + ...args: Args +) { + if (v.visitStmt?.(stmt, ...args) == "stop") return; + switch (stmt.kind.type) { + case "error": + if (v.visitErrorStmt?.(stmt, ...args) == "stop") return; + break; + case "import": + if (v.visitImportStmt?.(stmt, ...args) == "stop") return; + visitExpr(stmt.kind.path, v, ...args); + break; + case "break": + if (v.visitBreakStmt?.(stmt, ...args) == "stop") return; + if (stmt.kind.expr) visitExpr(stmt.kind.expr, v, ...args); + break; + case "return": + if (v.visitReturnStmt?.(stmt, ...args) == "stop") return; + if (stmt.kind.expr) visitExpr(stmt.kind.expr, v, ...args); + break; + case "fn": + if (v.visitFnStmt?.(stmt, ...args) == "stop") return; + stmt.kind.params.map((param) => visitParam(param, v, ...args)); + if (stmt.kind.returnType) { + visitEType(stmt.kind.returnType, v, ...args); + } + visitExpr(stmt.kind.body, v, ...args); + break; + case "let": + if (v.visitLetStmt?.(stmt, ...args) == "stop") return; + visitParam(stmt.kind.param, v, ...args); + visitExpr(stmt.kind.value, v, ...args); + break; + case "assign": + if (v.visitAssignStmt?.(stmt, ...args) == "stop") return; + visitExpr(stmt.kind.subject, v, ...args); + visitExpr(stmt.kind.value, v, ...args); + break; + case "expr": + if (v.visitExprStmt?.(stmt, ...args) == "stop") return; + visitExpr(stmt.kind.expr, v, ...args); + break; + } +} + +export function visitExpr( + expr: Expr, + v: AstVisitor, + ...args: Args +) { + if (v.visitExpr?.(expr, ...args) == "stop") return; + switch (expr.kind.type) { + case "error": + if (v.visitErrorExpr?.(expr, ...args) == "stop") return; + break; + case "string": + if (v.visitStringExpr?.(expr, ...args) == "stop") return; + break; + case "int": + if (v.visitIntExpr?.(expr, ...args) == "stop") return; + break; + case "ident": + if (v.visitIdentExpr?.(expr, ...args) == "stop") return; + break; + case "group": + if (v.visitGroupExpr?.(expr, ...args) == "stop") return; + visitExpr(expr.kind.expr, v, ...args); + break; + case "field": + if (v.visitFieldExpr?.(expr, ...args) == "stop") return; + visitExpr(expr.kind.subject, v, ...args); + break; + case "index": + if (v.visitIndexExpr?.(expr, ...args) == "stop") return; + visitExpr(expr.kind.subject, v, ...args); + visitExpr(expr.kind.value, v, ...args); + break; + case "call": + if (v.visitCallExpr?.(expr, ...args) == "stop") return; + visitExpr(expr.kind.subject, v, ...args); + expr.kind.args.map((arg) => visitExpr(arg, v, ...args)); + break; + case "unary": + if (v.visitUnaryExpr?.(expr, ...args) == "stop") return; + visitExpr(expr.kind.subject, v, ...args); + break; + case "binary": + if (v.visitBinaryExpr?.(expr, ...args) == "stop") return; + visitExpr(expr.kind.left, v, ...args); + visitExpr(expr.kind.right, v, ...args); + break; + case "if": + if (v.visitIfExpr?.(expr, ...args) == "stop") return; + visitExpr(expr.kind.cond, v, ...args); + visitExpr(expr.kind.truthy, v, ...args); + if (expr.kind.falsy) visitExpr(expr.kind.falsy, v, ...args); + break; + case "bool": + if (v.visitBoolExpr?.(expr, ...args) == "stop") return; + break; + case "null": + if (v.visitNullExpr?.(expr, ...args) == "stop") return; + break; + case "loop": + if (v.visitLoopExpr?.(expr, ...args) == "stop") return; + visitExpr(expr.kind.body, v, ...args); + break; + case "block": + if (v.visitBlockExpr?.(expr, ...args) == "stop") return; + expr.kind.stmts.map((stmt) => visitStmt(stmt, v, ...args)); + if (expr.kind.expr) visitExpr(expr.kind.expr, v, ...args); + break; + case "sym": + if (v.visitSymExpr?.(expr, ...args) == "stop") return; + break; + } +} + +export function visitParam( + param: Param, + v: AstVisitor, + ...args: Args +) { + if (v.visitParam?.(param, ...args) == "stop") return; + if (param.etype) visitEType(param.etype, v, ...args); +} + +export function visitEType( + etype: EType, + v: AstVisitor, + ...args: Args +) { + if (v.visitEType?.(etype, ...args) == "stop") return; + switch (etype.kind.type) { + case "error": + if (v.visitErrorEType?.(etype, ...args) == "stop") return; + break; + case "ident": + if (v.visitIdentEType?.(etype, ...args) == "stop") return; + break; + case "array": + if (v.visitArrayEType?.(etype, ...args) == "stop") return; + if (etype.kind.inner) visitEType(etype.kind.inner, v, ...args); + break; + case "struct": + if (v.visitStructEType?.(etype, ...args) == "stop") return; + etype.kind.fields.map((field) => visitParam(field, v, ...args)); + break; + } +} + +export function stmtToString(stmt: Stmt): string { + const body = (() => { + switch (stmt.kind.type) { + case "assign": + return `{ subject: ${exprToString(stmt.kind.subject)}, value: ${ + exprToString(stmt.kind.value) + } }`; + } + return "()"; + })(); + const { line } = stmt.pos; + return `${stmt.kind.type}:${line}${body}`; +} + +export function exprToString(expr: Expr): string { + const body = (() => { + switch (expr.kind.type) { + case "binary": + return `(${ + exprToString(expr.kind.left) + } ${expr.kind.binaryType} ${exprToString(expr.kind.right)})`; + case "sym": + return `(${expr.kind.ident})`; + } + return "()"; + })(); + const { line } = expr.pos; + return `${expr.kind.type}:${line}${body}`; +} diff --git a/compiler/lowerer.ts b/compiler/lowerer.ts index aea880e..f0c41da 100644 --- a/compiler/lowerer.ts +++ b/compiler/lowerer.ts @@ -1,5 +1,5 @@ import { Builtins } from "./arch.ts"; -import { Expr, Stmt, stmtToString } from "./ast.ts"; +import { Expr, Stmt } from "./ast.ts"; import { LocalLeaf, Locals, LocalsFnRoot } from "./lowerer_locals.ts"; import { Ops } from "./mod.ts"; import { Assembler, Label } from "./assembler.ts"; @@ -224,7 +224,7 @@ export class Lowerer { case "field": break; case "index": - return this.lowerIndexExpr(expr); + return this.lowerIndexExpr(expr); case "call": return this.lowerCallExpr(expr); case "unary": @@ -245,8 +245,8 @@ export class Lowerer { if (expr.kind.type !== "index") { throw new Error(); } - this.lowerExpr(expr.kind.subject) - this.lowerExpr(expr.kind.value) + this.lowerExpr(expr.kind.subject); + this.lowerExpr(expr.kind.value); if (expr.kind.subject.vtype?.type == "array") { this.program.add(Ops.Builtin, Builtins.ArrayAt); diff --git a/compiler/resolver.ts b/compiler/resolver.ts index 607e4d3..7483131 100644 --- a/compiler/resolver.ts +++ b/compiler/resolver.ts @@ -1,5 +1,5 @@ -import { Builtins } from "./arch.ts"; import { Expr, Stmt } from "./ast.ts"; +import { AstVisitor, visitExpr, VisitRes, visitStmts } from "./ast_visitor.ts"; import { printStackTrace, Reporter } from "./info.ts"; import { FnSyms, @@ -10,23 +10,37 @@ import { } from "./resolver_syms.ts"; import { Pos } from "./token.ts"; -export class Resolver { +export class Resolver implements AstVisitor<[Syms]> { private root = new GlobalSyms(); public constructor(private reporter: Reporter) { - this.root.define("print", { - type: "builtin", - ident: "print", - builtinId: Builtins.Print, - }); } - public resolve(stmts: Stmt[]) { + public resolve(stmts: Stmt[]): VisitRes { const scopeSyms = new StaticSyms(this.root); this.scoutFnStmts(stmts, scopeSyms); - for (const stmt of stmts) { - this.resolveStmt(stmt, scopeSyms); + visitStmts(stmts, this, scopeSyms); + return "stop"; + } + + visitLetStmt(stmt: Stmt, syms: Syms): VisitRes { + if (stmt.kind.type !== "let") { + throw new Error("expected let statement"); } + visitExpr(stmt.kind.value, this, syms); + const ident = stmt.kind.param.ident; + if (syms.definedLocally(ident)) { + this.reportAlreadyDefined(ident, stmt.pos, syms); + return; + } + syms.define(ident, { + ident, + type: "let", + pos: stmt.kind.param.pos, + stmt, + param: stmt.kind.param, + }); + return "stop"; } private scoutFnStmts(stmts: Stmt[], syms: Syms) { @@ -48,148 +62,7 @@ export class Resolver { } } - private resolveExpr(expr: Expr, syms: Syms) { - if (expr.kind.type === "error") { - return; - } - if (expr.kind.type === "ident") { - this.resolveIdentExpr(expr, syms); - return; - } - if (expr.kind.type === "binary") { - this.resolveExpr(expr.kind.left, syms); - this.resolveExpr(expr.kind.right, syms); - return; - } - if (expr.kind.type === "block") { - const childSyms = new LeafSyms(syms); - this.scoutFnStmts(expr.kind.stmts, childSyms); - for (const stmt of expr.kind.stmts) { - this.resolveStmt(stmt, childSyms); - } - if (expr.kind.expr) { - this.resolveExpr(expr.kind.expr, childSyms); - } - return; - } - if (expr.kind.type === "group") { - this.resolveExpr(expr.kind.expr, syms); - return; - } - if (expr.kind.type === "field") { - this.resolveExpr(expr.kind.subject, syms); - return; - } - if (expr.kind.type === "index") { - this.resolveExpr(expr.kind.subject, syms); - this.resolveExpr(expr.kind.value, syms); - return; - } - if (expr.kind.type === "call") { - this.resolveExpr(expr.kind.subject, syms); - for (const e of expr.kind.args) { - this.resolveExpr(e, syms); - } - return; - } - if (expr.kind.type === "unary") { - this.resolveExpr(expr.kind.subject, syms); - return; - } - if (expr.kind.type === "if") { - this.resolveExpr(expr.kind.cond, syms); - this.resolveExpr(expr.kind.truthy, syms); - if (expr.kind.falsy !== undefined) { - this.resolveExpr(expr.kind.falsy, syms); - } - return; - } - if (expr.kind.type === "loop") { - this.resolveExpr(expr.kind.body, syms); - return; - } - if ( - expr.kind.type === "int" || expr.kind.type === "bool" || - expr.kind.type === "null" || expr.kind.type === "string" || - expr.kind.type === "sym" - ) { - return; - } - } - - private resolveIdentExpr(expr: Expr, syms: Syms) { - if (expr.kind.type !== "ident") { - throw new Error("expected ident"); - } - const ident = expr.kind; - const symResult = syms.get(ident.value); - if (!symResult.ok) { - this.reportUseOfUndefined(ident.value, expr.pos, syms); - return; - } - const sym = symResult.sym; - expr.kind = { - type: "sym", - ident: ident.value, - sym, - }; - } - - private resolveStmt(stmt: Stmt, syms: Syms) { - if (stmt.kind.type === "error") { - return; - } - if (stmt.kind.type === "let") { - this.resolveLetStmt(stmt, syms); - return; - } - if (stmt.kind.type === "fn") { - this.resolveFnStmt(stmt, syms); - return; - } - if (stmt.kind.type === "return") { - if (stmt.kind.expr) { - this.resolveExpr(stmt.kind.expr, syms); - } - return; - } - if (stmt.kind.type === "break") { - if (stmt.kind.expr !== undefined) { - this.resolveExpr(stmt.kind.expr, syms); - } - return; - } - if (stmt.kind.type === "assign") { - this.resolveExpr(stmt.kind.subject, syms); - this.resolveExpr(stmt.kind.value, syms); - return; - } - if (stmt.kind.type === "expr") { - this.resolveExpr(stmt.kind.expr, syms); - return; - } - } - - private resolveLetStmt(stmt: Stmt, syms: Syms) { - if (stmt.kind.type !== "let") { - throw new Error("expected let statement"); - } - this.resolveExpr(stmt.kind.value, syms); - const ident = stmt.kind.param.ident; - if (syms.definedLocally(ident)) { - this.reportAlreadyDefined(ident, stmt.pos, syms); - return; - } - syms.define(ident, { - ident, - type: "let", - pos: stmt.kind.param.pos, - stmt, - param: stmt.kind.param, - }); - } - - private resolveFnStmt(stmt: Stmt, syms: Syms) { + visitFnStmt(stmt: Stmt, syms: Syms): VisitRes { if (stmt.kind.type !== "fn") { throw new Error("expected fn statement"); } @@ -206,7 +79,40 @@ export class Resolver { param, }); } - this.resolveExpr(stmt.kind.body, fnScopeSyms); + visitExpr(stmt.kind.body, this, fnScopeSyms); + return "stop"; + } + + visitIdentExpr(expr: Expr, syms: Syms): VisitRes { + if (expr.kind.type !== "ident") { + throw new Error("expected ident"); + } + const ident = expr.kind; + const symResult = syms.get(ident.value); + if (!symResult.ok) { + this.reportUseOfUndefined(ident.value, expr.pos, syms); + return; + } + const sym = symResult.sym; + expr.kind = { + type: "sym", + ident: ident.value, + sym, + }; + return "stop"; + } + + visitBlockExpr(expr: Expr, syms: Syms): VisitRes { + if (expr.kind.type !== "block") { + throw new Error(); + } + const childSyms = new LeafSyms(syms); + this.scoutFnStmts(expr.kind.stmts, childSyms); + visitStmts(expr.kind.stmts, this, childSyms); + if (expr.kind.expr) { + visitExpr(expr.kind.expr, this, childSyms); + } + return "stop"; } private reportUseOfUndefined(ident: string, pos: Pos, _syms: Syms) {