compiler: investigate pattern lower

This commit is contained in:
sfja 2025-02-05 01:35:46 +01:00
parent f9e14e5056
commit a5502810cd
5 changed files with 56 additions and 19 deletions

View File

@ -224,31 +224,36 @@ export function visitItem<
item: Item,
...p: P
) {
visitIdent(v, item.ident, ...p);
const kind = item.kind;
switch (kind.tag) {
case "error":
if (v.visitErrorItem?.(item, ...p) === "stop") return;
visitIdent(v, item.ident, ...p);
return;
case "mod_block":
if (v.visitModBlockItem?.(item, kind, ...p) === "stop") return;
visitIdent(v, item.ident, ...p);
visitBlock(v, kind.block, ...p);
return;
case "mod_file":
if (v.visitModFileItem?.(item, kind, ...p) === "stop") return;
visitIdent(v, item.ident, ...p);
return;
case "enum":
if (v.visitEnumItem?.(item, kind, ...p) === "stop") return;
visitIdent(v, item.ident, ...p);
for (const variant of kind.variants) {
visitVariant(v, variant, ...p);
}
return;
case "struct":
if (v.visitStructItem?.(item, kind, ...p) === "stop") return;
visitIdent(v, item.ident, ...p);
visitVariantData(v, kind.data, ...p);
return;
case "fn":
if (v.visitFnItem?.(item, kind, ...p) === "stop") return;
visitIdent(v, item.ident, ...p);
for (const param of kind.params) {
visitParam(v, param, ...p);
}
@ -256,9 +261,11 @@ export function visitItem<
return;
case "use":
if (v.visitUseItem?.(item, kind, ...p) === "stop") return;
visitIdent(v, item.ident, ...p);
return;
case "type_alias":
if (v.visitTypeAliasItem?.(item, kind, ...p) === "stop") return;
visitIdent(v, item.ident, ...p);
visitTy(v, kind.ty, ...p);
return;
}

View File

@ -128,7 +128,7 @@ export class Checker {
return resu.val;
}
case "local": {
const ty = this.exprTy(expr);
const patResu = this.resols.patRes();
const resu = this.resolveTys(ty, expected);
if (!resu.ok) {
this.report(resu.val, expr.span);

View File

@ -2,6 +2,7 @@
"workspace": ["./ast", "./check", "./middle", "./parse", "./resolve", "./ty", "./common"],
"lint": {
"rules": {
"tags": ["recommended"],
"exclude": [
"verbatim-module-syntax",
"no-unused-vars"

View File

@ -1,5 +1,5 @@
import * as ast from "@slige/ast";
import { IdBase, IdentId, IdMap, Res } from "@slige/common";
import { AstId, IdBase, IdentId, IdMap, Res } from "@slige/common";
export interface Syms {
getVal(ident: ast.Ident): Resolve;
@ -14,13 +14,21 @@ export type Resolve = {
kind: ResolveKind;
};
export type LocalId = IdBase & { readonly _: unique symbol };
export type ResolveKind =
| { tag: "error" }
| { tag: "fn"; item: ast.Item; kind: ast.FnItem }
| { tag: "local"; id: LocalId };
export type PatResolve =
| { tag: "param"; paramIdx: number }
| { tag: "let"; stmt: ast.Stmt; kind: ast.LetStmt };
export type LocalId = IdBase & { readonly _: unique symbol };
export type Local =
| { tag: "param"; paramIdx: number }
| { tag: "let"; stmt: ast.Stmt; kind: ast.LetStmt };
export const ResolveError = (ident: ast.Ident): Resolve => ({
ident,
kind: { tag: "error" },

View File

@ -4,6 +4,7 @@ import {
FnSyms,
LocalId,
LocalSyms,
PatResolve,
Resolve,
ResolveError,
RootSyms,
@ -14,6 +15,7 @@ export { type LocalId } from "./cx.ts";
export class Resols {
public constructor(
private exprResols: IdMap<AstId, Resolve>,
private patResols: IdMap<AstId, PatResolve>,
) {}
public exprRes(id: AstId): Resolve {
@ -22,6 +24,13 @@ export class Resols {
}
return this.exprResols.get(id)!;
}
public patRes(id: AstId): PatResolve {
if (!this.patResols.has(id)) {
throw new Error();
}
return this.patResols.get(id)!;
}
}
export class Resolver implements ast.Visitor {
@ -30,6 +39,9 @@ export class Resolver implements ast.Visitor {
private syms: Syms = this.rootSyms;
private exprResols = new IdMap<AstId, Resolve>();
private patResols = new IdMap<AstId, PatResolve>();
private patResolveStack: PatResolve[] = [];
private localIds = new Ids<LocalId>();
@ -42,13 +54,23 @@ export class Resolver implements ast.Visitor {
ast.visitFile(this, this.entryFileAst);
return new Resols(
this.exprResols,
this.patResols,
);
}
visitFile(file: ast.File): ast.VisitRes {
this.currentFile = file.file;
this.fnBodiesToCheck.push([]);
ast.visitStmts(this, file.stmts);
this.visitFnBodies();
this.popAndVisitFnBodies();
return "stop";
}
visitBlock(block: ast.Block): ast.VisitRes {
this.fnBodiesToCheck.push([]);
ast.visitStmts(this, block.stmts);
this.popAndVisitFnBodies();
block.expr && ast.visitExpr(this, block.expr);
return "stop";
}
@ -56,7 +78,9 @@ export class Resolver implements ast.Visitor {
kind.ty && ast.visitTy(this, kind.ty);
kind.expr && ast.visitExpr(this, kind.expr);
this.syms = new LocalSyms(this.syms);
this.patResolveStack.push({ tag: "let", stmt, kind });
ast.visitPat(this, kind.pat);
this.patResolveStack.pop();
return "stop";
}
@ -77,25 +101,28 @@ export class Resolver implements ast.Visitor {
todo();
}
private fnBodiesToCheck: [ast.Item, ast.FnItem][] = [];
private fnBodiesToCheck: [ast.Item, ast.FnItem][][] = [];
visitFnItem(item: ast.Item, kind: ast.FnItem): ast.VisitRes {
this.syms.defVal(item.ident, { tag: "fn", item, kind });
this.fnBodiesToCheck.push([item, kind]);
this.fnBodiesToCheck.at(-1)!.push([item, kind]);
return "stop";
}
private visitFnBodies() {
for (const [_item, kind] of this.fnBodiesToCheck) {
private popAndVisitFnBodies() {
for (const [_item, kind] of this.fnBodiesToCheck.at(-1)!) {
const outerSyms = this.syms;
this.syms = new FnSyms(this.syms);
this.syms = new LocalSyms(this.syms);
for (const param of kind.params) {
for (const [paramIdx, param] of kind.params.entries()) {
this.patResolveStack.push({ tag: "param", paramIdx });
ast.visitParam(this, param);
this.patResolveStack.pop();
}
ast.visitBlock(this, kind.body!);
this.syms = outerSyms;
}
this.fnBodiesToCheck = [];
this.fnBodiesToCheck.pop();
}
visitPathExpr(expr: ast.Expr, kind: ast.PathExpr): ast.VisitRes {
@ -137,6 +164,7 @@ export class Resolver implements ast.Visitor {
}
visitBindPat(pat: ast.Pat, kind: ast.BindPat): ast.VisitRes {
this.patResols.set(pat.id, this.patResolveStack.at(-1)!);
const res = this.syms.defVal(kind.ident, {
tag: "local",
id: this.localIds.nextThenStep(),
@ -157,13 +185,6 @@ export class Resolver implements ast.Visitor {
todo(pat, kind);
}
visitBlock(block: ast.Block): ast.VisitRes {
ast.visitStmts(this, block.stmts);
this.visitFnBodies();
block.expr && ast.visitExpr(this, block.expr);
return "stop";
}
visitPath(_path: ast.Path): ast.VisitRes {
throw new Error("should not be reached");
}