add generics

This commit is contained in:
sfja 2026-04-16 13:02:10 +02:00
parent 87f90c6700
commit f3315e9c31
5 changed files with 222 additions and 67 deletions

View File

@ -101,7 +101,7 @@ export class Node {
case "IndexExpr": case "IndexExpr":
return visit(k.value, k.arg); return visit(k.value, k.arg);
case "CallExpr": case "CallExpr":
return visit(k.value, ...k.args); return visit(k.value, ...k.generics ?? [], ...k.args);
case "UnaryExpr": case "UnaryExpr":
return visit(k.expr); return visit(k.expr);
case "BinaryExpr": case "BinaryExpr":
@ -117,6 +117,8 @@ export class Node {
return visit(k.ty, k.length); return visit(k.ty, k.length);
case "SliceTy": case "SliceTy":
return visit(k.ty); return visit(k.ty);
case "Generic":
return visit();
} }
k satisfies never; k satisfies never;
} }
@ -131,6 +133,7 @@ export type NodeKind =
| { | {
tag: "FnStmt"; tag: "FnStmt";
ident: string; ident: string;
genericParams: Node[] | null;
params: Node[]; params: Node[];
retTy: Node | null; retTy: Node | null;
body: Node; body: Node;
@ -146,7 +149,7 @@ export type NodeKind =
| { tag: "StrExpr"; value: string } | { tag: "StrExpr"; value: string }
| { tag: "ArrayExpr"; values: Node[] } | { tag: "ArrayExpr"; values: Node[] }
| { tag: "IndexExpr"; value: Node; arg: Node } | { tag: "IndexExpr"; value: Node; arg: Node }
| { tag: "CallExpr"; value: Node; args: Node[] } | { tag: "CallExpr"; value: Node; generics: Node[] | null; args: Node[] }
| { tag: "UnaryExpr"; op: UnaryOp; expr: Node; tok: string } | { tag: "UnaryExpr"; op: UnaryOp; expr: Node; tok: string }
| { tag: "BinaryExpr"; op: BinaryOp; left: Node; right: Node; tok: string } | { tag: "BinaryExpr"; op: BinaryOp; left: Node; right: Node; tok: string }
| { | {
@ -158,7 +161,8 @@ export type NodeKind =
| { tag: "IdentTy"; ident: string } | { tag: "IdentTy"; ident: string }
| { tag: "PtrTy" | "PtrMutTy"; ty: Node } | { tag: "PtrTy" | "PtrMutTy"; ty: Node }
| { tag: "ArrayTy"; ty: Node; length: Node } | { tag: "ArrayTy"; ty: Node; length: Node }
| { tag: "SliceTy"; ty: Node }; | { tag: "SliceTy"; ty: Node }
| { tag: "Generic"; ident: string };
export type IntTy = export type IntTy =
| "i8" | "i8"

View File

@ -61,6 +61,7 @@ export class CheckedFn {
class TypeChecker { class TypeChecker {
private nodeTys = new Map<ast.Node, Ty>(); private nodeTys = new Map<ast.Node, Ty>();
private generics?: Ty[];
private params!: Ty[]; private params!: Ty[];
constructor( constructor(
@ -71,6 +72,12 @@ class TypeChecker {
) {} ) {}
check(): CheckedFn { check(): CheckedFn {
const generics = this.fn.kind.genericParams
?.map((_node, idx) => {
return Ty.Generic(idx);
});
this.generics = generics;
const params = this.fn.kind.params const params = this.fn.kind.params
.map((node) => { .map((node) => {
const param = node.as("Param"); const param = node.as("Param");
@ -194,6 +201,7 @@ class TypeChecker {
case "PtrMutTy": case "PtrMutTy":
case "ArrayTy": case "ArrayTy":
case "SliceTy": case "SliceTy":
case "Generic":
break; break;
default: default:
k satisfies never; k satisfies never;
@ -205,7 +213,7 @@ class TypeChecker {
const ty = Ty.create("FnStmt", { const ty = Ty.create("FnStmt", {
stmt: this.fn, stmt: this.fn,
ty: Ty.create("Fn", { params, retTy }), ty: Ty.Fn(params, retTy, generics ?? null),
}); });
return new CheckedFn(ty, this.nodeTys); return new CheckedFn(ty, this.nodeTys);
@ -250,24 +258,6 @@ class TypeChecker {
private checkExpr(expr: ast.Node, expected: Ty): Ty { private checkExpr(expr: ast.Node, expected: Ty): Ty {
const k = expr.kind; const k = expr.kind;
switch (k.tag) { switch (k.tag) {
case "Error":
case "File":
case "Block":
case "ExprStmt":
case "AssignStmt":
case "FnStmt":
case "ReturnStmt":
case "LetStmt":
case "IfStmt":
case "WhileStmt":
case "BreakStmt":
case "Param":
case "IdentTy":
case "PtrTy":
case "PtrMutTy":
case "ArrayTy":
case "SliceTy":
throw new Error(`node '${k.tag}' not an expression`);
case "IdentExpr": { case "IdentExpr": {
const sym = this.syms.get(expr); const sym = this.syms.get(expr);
if (sym.tag === "Fn") { if (sym.tag === "Fn") {
@ -453,12 +443,7 @@ class TypeChecker {
: null; : null;
if (sym?.tag === "Builtin") { if (sym?.tag === "Builtin") {
if (sym.id === "len") { if (sym.id === "len") {
checkArgs( checkArgs(Ty.Fn([Ty.Any], Ty.USize, null).as("Fn"));
Ty.create("Fn", {
params: [Ty.Any],
retTy: Ty.USize,
}).as("Fn"),
);
return Ty.USize; return Ty.USize;
} }
if (sym.id === "print") { if (sym.id === "print") {
@ -483,8 +468,60 @@ class TypeChecker {
} }
const callableTy = calleeTy.callableTy(); const callableTy = calleeTy.callableTy();
checkArgs(callableTy);
return callableTy.kind.retTy; if (callableTy.kind.generics !== null) {
let generics: Ty[];
if (k.generics) {
generics = k.generics.map((ty) => this.ty(ty));
if (
generics.length !== callableTy.kind.generics.length
) {
this.reporter.error(
expr.loc,
`expected ${callableTy.kind.generics.length} generic type arguments, got ${generics.length}`,
);
this.reporter.abort();
}
} else {
generics = callableTy.kind.generics
.map((_ty) => Ty.Any);
}
const newCalleeTy = Ty.Fn(
callableTy.kind.params
.map((ty) =>
ty.is("Generic") ? generics[ty.kind.idx] : ty
),
callableTy.kind.retTy.is("Generic")
? generics[callableTy.kind.retTy.kind.idx]
: callableTy.kind.retTy,
null,
).as("Fn");
this.rewriteTree(k.value, newCalleeTy);
checkArgs(
Ty.Fn(
callableTy.kind.params
.map((ty) =>
ty.is("Generic")
? generics[ty.kind.idx]
: ty
),
callableTy.kind.retTy.is("Generic")
? generics[callableTy.kind.retTy.kind.idx]
: callableTy.kind.retTy,
null,
).as("Fn"),
);
return callableTy.kind.retTy;
} else {
if (k.generics) {
this.reporter.error(expr.loc, "no generics expected");
this.reporter.abort();
}
checkArgs(callableTy);
return callableTy.kind.retTy;
}
} }
case "UnaryExpr": { case "UnaryExpr": {
switch (k.op) { switch (k.op) {
@ -578,6 +615,25 @@ class TypeChecker {
} }
return Ty.create("Range", {}); return Ty.create("Range", {});
} }
case "Error":
case "File":
case "Block":
case "ExprStmt":
case "AssignStmt":
case "FnStmt":
case "ReturnStmt":
case "LetStmt":
case "IfStmt":
case "WhileStmt":
case "BreakStmt":
case "Param":
case "IdentTy":
case "PtrTy":
case "PtrMutTy":
case "ArrayTy":
case "SliceTy":
case "Generic":
throw new Error(`node '${k.tag}' not an expression`);
default: default:
k satisfies never; k satisfies never;
throw new Error(); throw new Error();
@ -591,28 +647,6 @@ class TypeChecker {
private checkTy(ty: ast.Node): Ty { private checkTy(ty: ast.Node): Ty {
const k = ty.kind; const k = ty.kind;
switch (k.tag) { switch (k.tag) {
case "Error":
case "File":
case "Block":
case "ExprStmt":
case "AssignStmt":
case "FnStmt":
case "ReturnStmt":
case "LetStmt":
case "IfStmt":
case "WhileStmt":
case "BreakStmt":
case "Param":
case "IdentExpr":
case "IntExpr":
case "StrExpr":
case "ArrayExpr":
case "IndexExpr":
case "CallExpr":
case "UnaryExpr":
case "BinaryExpr":
case "RangeExpr":
throw new Error(`node '${k.tag}' not a type`);
case "IdentTy": { case "IdentTy": {
const sym = this.syms.get(ty); const sym = this.syms.get(ty);
if (sym.tag === "BuiltinTy") { if (sym.tag === "BuiltinTy") {
@ -647,6 +681,12 @@ class TypeChecker {
); );
} }
} }
if (sym.tag === "Generic") {
if (!this.generics) {
throw new Error();
}
return this.generics[sym.idx];
}
this.reporter.error(ty.loc, `symbol is not a type`); this.reporter.error(ty.loc, `symbol is not a type`);
return this.reporter.abort(); return this.reporter.abort();
} }
@ -682,6 +722,29 @@ class TypeChecker {
const ty = this.ty(k.ty); const ty = this.ty(k.ty);
return Ty.create("Slice", { ty }); return Ty.create("Slice", { ty });
} }
case "Error":
case "File":
case "Block":
case "ExprStmt":
case "AssignStmt":
case "FnStmt":
case "ReturnStmt":
case "LetStmt":
case "IfStmt":
case "WhileStmt":
case "BreakStmt":
case "Param":
case "IdentExpr":
case "IntExpr":
case "StrExpr":
case "ArrayExpr":
case "IndexExpr":
case "CallExpr":
case "UnaryExpr":
case "BinaryExpr":
case "RangeExpr":
case "Generic":
throw new Error(`node '${k.tag}' not a type`);
} }
} }

View File

@ -71,6 +71,7 @@ export class Parser {
const loc = this.loc(); const loc = this.loc();
this.step(); this.step();
const ident = this.mustEat("ident").value; const ident = this.mustEat("ident").value;
const genericParams = this.parseGenericParams();
this.mustEat("("); this.mustEat("(");
const params: ast.Node[] = []; const params: ast.Node[] = [];
if (!this.test(")")) { if (!this.test(")")) {
@ -88,7 +89,13 @@ export class Parser {
retTy = this.parseTy(); retTy = this.parseTy();
} }
const body = this.parseBlock(); const body = this.parseBlock();
return ast.Node.create(loc, "FnStmt", { ident, params, retTy, body }); return ast.Node.create(loc, "FnStmt", {
ident,
genericParams,
params,
retTy,
body,
});
} }
parseReturnStmt(): ast.Node { parseReturnStmt(): ast.Node {
@ -259,19 +266,12 @@ export class Parser {
const arg = this.parseExpr(); const arg = this.parseExpr();
this.mustEat("]"); this.mustEat("]");
expr = ast.Node.create(loc, "IndexExpr", { value: expr, arg }); expr = ast.Node.create(loc, "IndexExpr", { value: expr, arg });
} else if (this.test("::<")) {
const generics = this.parseGenericArgs();
this.mustEat("(");
expr = this.parseCallExprTail(expr, loc, generics);
} else if (this.eat("(")) { } else if (this.eat("(")) {
const args: ast.Node[] = []; expr = this.parseCallExprTail(expr, loc, null);
if (!this.test(")")) {
args.push(this.parseExpr());
while (this.eat(",")) {
if (this.done || this.test(")")) {
break;
}
args.push(this.parseExpr());
}
}
this.mustEat(")");
expr = ast.Node.create(loc, "CallExpr", { value: expr, args });
} else { } else {
break; break;
} }
@ -279,6 +279,29 @@ export class Parser {
return expr; return expr;
} }
parseCallExprTail(
expr: ast.Node,
loc: Loc,
generics: ast.Node[] | null,
): ast.Node {
const args: ast.Node[] = [];
if (!this.test(")")) {
args.push(this.parseExpr());
while (this.eat(",")) {
if (this.done || this.test(")")) {
break;
}
args.push(this.parseExpr());
}
}
this.mustEat(")");
return ast.Node.create(loc, "CallExpr", {
value: expr,
generics,
args,
});
}
parseOperand(): ast.Node { parseOperand(): ast.Node {
const loc = this.loc(); const loc = this.loc();
if (this.test("ident")) { if (this.test("ident")) {
@ -361,6 +384,38 @@ export class Parser {
} }
} }
parseGenericArgs(): ast.Node[] | null {
if (!this.eat("::<")) {
return null;
}
const args: ast.Node[] = [];
while (!this.done && !this.test("<")) {
args.push(this.parseTy());
if (!this.eat(",")) {
break;
}
}
this.mustEat(">");
return args;
}
parseGenericParams(): ast.Node[] | null {
if (!this.eat("<")) {
return null;
}
const params: ast.Node[] = [];
while (!this.done && !this.test("<")) {
const loc = this.loc();
const identTok = this.mustEat("ident");
params.push(ast.create(loc, "Generic", { ident: identTok.value }));
if (!this.eat(",")) {
break;
}
}
this.mustEat(">");
return params;
}
private mustEat(type: string, loc = this.loc()): Tok { private mustEat(type: string, loc = this.loc()): Tok {
const tok = this.current; const tok = this.current;
if (tok.type !== type) { if (tok.type !== type) {
@ -423,7 +478,7 @@ const keywordPattern =
/^(?:(?:fn)|(?:return)|(?:let)|(?:if)|(?:else)|(?:while)|(?:break)|(?:or)|(?:and)|(?:not)|(?:mut))/; /^(?:(?:fn)|(?:return)|(?:let)|(?:if)|(?:else)|(?:while)|(?:break)|(?:or)|(?:and)|(?:not)|(?:mut))/;
const operatorPattern2 = const operatorPattern2 =
/((?:\->)|(?:==)|(?:!=)|(?:<=)|(?:>=)|(?:<<)|(?:>>)|(?:\.\*)|(?:\.\.)|(?:\.\.=)|[\n\(\)\{\}\[\]\,\.\;\:\!\=\<\>\&\^\|\+\-\*\/\%])/g; /((?:\->)|(?:==)|(?:!=)|(?:<=)|(?:>=)|(?:\:\:<)|(?:<<)|(?:>>)|(?:\.\*)|(?:\.\.)|(?:\.\.=)|[\n\(\)\{\}\[\]\,\.\;\:\!\=\<\>\&\^\|\+\-\*\/\%])/g;
export function tokenize(text: string, reporter: FileReporter): Tok[] { export function tokenize(text: string, reporter: FileReporter): Tok[] {
return new Lexer() return new Lexer()

View File

@ -19,6 +19,12 @@ export type Sym =
| { tag: "Bool"; value: boolean } | { tag: "Bool"; value: boolean }
| { tag: "Builtin"; id: string } | { tag: "Builtin"; id: string }
| { tag: "Fn"; stmt: ast.NodeWithKind<"FnStmt"> } | { tag: "Fn"; stmt: ast.NodeWithKind<"FnStmt"> }
| {
tag: "Generic";
stmt: ast.Node;
generic: ast.NodeWithKind<"Generic">;
idx: number;
}
| { | {
tag: "FnParam"; tag: "FnParam";
stmt: ast.NodeWithKind<"FnStmt">; stmt: ast.NodeWithKind<"FnStmt">;
@ -61,6 +67,19 @@ export function resolve(
if (k.tag === "FnStmt") { if (k.tag === "FnStmt") {
ast.assertNodeWithKind(node, "FnStmt"); ast.assertNodeWithKind(node, "FnStmt");
syms = ResolverSyms.forkFrom(syms); syms = ResolverSyms.forkFrom(syms);
if (k.genericParams) {
for (const [idx, param] of k.genericParams?.entries()) {
ast.assertNodeWithKind(param, "Generic");
const sym: Sym = {
tag: "Generic",
stmt: node,
generic: param,
idx,
};
syms.define(param.kind.ident, sym);
resols.set(param.id, sym);
}
}
for (const [idx, param] of k.params.entries()) { for (const [idx, param] of k.params.entries()) {
ast.assertNodeWithKind(param, "Param"); ast.assertNodeWithKind(param, "Param");
const sym: Sym = { tag: "FnParam", stmt: node, param, idx }; const sym: Sym = { tag: "FnParam", stmt: node, param, idx };

View File

@ -48,6 +48,15 @@ export class Ty {
static Array(ty: Ty, length: number): Ty { static Array(ty: Ty, length: number): Ty {
return this.create("Array", { ty, length }); return this.create("Array", { ty, length });
} }
static Fn(params: Ty[], retTy: Ty, generics: Ty[] | null): Ty {
return this.create("Fn", { params, retTy, generics });
}
static Generic(idx: number): Ty {
return this.create("Generic", { idx });
}
static Instance(ty: Ty, args: Ty[]): Ty {
return this.create("Instance", { ty, args });
}
/** Only used in type checker. */ /** Only used in type checker. */
static Any = Ty.create("Any", {}); static Any = Ty.create("Any", {});
@ -67,6 +76,9 @@ export class Ty {
} }
private internHash(): string { private internHash(): string {
if (this.is("FnStmt")) {
return JSON.stringify({ ...this.kind, stmt: this.kind.stmt.id });
}
return JSON.stringify(this.kind); return JSON.stringify(this.kind);
} }
@ -244,8 +256,10 @@ export type TyKind =
| { tag: "Array"; ty: Ty; length: number } | { tag: "Array"; ty: Ty; length: number }
| { tag: "Slice"; ty: Ty } | { tag: "Slice"; ty: Ty }
| { tag: "Range" } | { tag: "Range" }
| { tag: "Fn"; params: Ty[]; retTy: Ty } | { tag: "Fn"; params: Ty[]; retTy: Ty; generics: Ty[] | null }
| { tag: "FnStmt"; ty: Ty; stmt: ast.NodeWithKind<"FnStmt"> } | { tag: "FnStmt"; ty: Ty; stmt: ast.NodeWithKind<"FnStmt"> }
| { tag: "Generic"; idx: number }
| { tag: "Instance"; ty: Ty; args: Ty[] }
| { tag: "Any" | "AnyInt" } | { tag: "Any" | "AnyInt" }
| { tag: "AnyIndexable"; ty: Ty } | { tag: "AnyIndexable"; ty: Ty }
| { tag: "AnyCallable"; ty: Ty } | { tag: "AnyCallable"; ty: Ty }