diff --git a/src/ast.ts b/src/ast.ts index 0158656..da6cd8c 100644 --- a/src/ast.ts +++ b/src/ast.ts @@ -101,7 +101,7 @@ export class Node { case "IndexExpr": return visit(k.value, k.arg); case "CallExpr": - return visit(k.value, ...k.args); + return visit(k.value, ...k.generics ?? [], ...k.args); case "UnaryExpr": return visit(k.expr); case "BinaryExpr": @@ -117,6 +117,8 @@ export class Node { return visit(k.ty, k.length); case "SliceTy": return visit(k.ty); + case "Generic": + return visit(); } k satisfies never; } @@ -131,6 +133,7 @@ export type NodeKind = | { tag: "FnStmt"; ident: string; + genericParams: Node[] | null; params: Node[]; retTy: Node | null; body: Node; @@ -146,7 +149,7 @@ export type NodeKind = | { tag: "StrExpr"; value: string } | { tag: "ArrayExpr"; values: Node[] } | { tag: "IndexExpr"; value: Node; arg: Node } - | { tag: "CallExpr"; value: Node; args: Node[] } + | { tag: "CallExpr"; value: Node; generics: Node[] | null; args: Node[] } | { tag: "UnaryExpr"; op: UnaryOp; expr: Node; tok: string } | { tag: "BinaryExpr"; op: BinaryOp; left: Node; right: Node; tok: string } | { @@ -158,7 +161,8 @@ export type NodeKind = | { tag: "IdentTy"; ident: string } | { tag: "PtrTy" | "PtrMutTy"; ty: Node } | { tag: "ArrayTy"; ty: Node; length: Node } - | { tag: "SliceTy"; ty: Node }; + | { tag: "SliceTy"; ty: Node } + | { tag: "Generic"; ident: string }; export type IntTy = | "i8" diff --git a/src/front/check.ts b/src/front/check.ts index 45f5d92..0831f2b 100644 --- a/src/front/check.ts +++ b/src/front/check.ts @@ -61,6 +61,7 @@ export class CheckedFn { class TypeChecker { private nodeTys = new Map(); + private generics?: Ty[]; private params!: Ty[]; constructor( @@ -71,6 +72,12 @@ class TypeChecker { ) {} check(): CheckedFn { + const generics = this.fn.kind.genericParams + ?.map((_node, idx) => { + return Ty.Generic(idx); + }); + this.generics = generics; + const params = this.fn.kind.params .map((node) => { const param = node.as("Param"); @@ -194,6 +201,7 @@ class TypeChecker { case "PtrMutTy": case "ArrayTy": case "SliceTy": + case "Generic": break; default: k satisfies never; @@ -205,7 +213,7 @@ class TypeChecker { const ty = Ty.create("FnStmt", { stmt: this.fn, - ty: Ty.create("Fn", { params, retTy }), + ty: Ty.Fn(params, retTy, generics ?? null), }); return new CheckedFn(ty, this.nodeTys); @@ -250,24 +258,6 @@ class TypeChecker { private checkExpr(expr: ast.Node, expected: Ty): Ty { const k = expr.kind; switch (k.tag) { - case "Error": - case "File": - case "Block": - case "ExprStmt": - case "AssignStmt": - case "FnStmt": - case "ReturnStmt": - case "LetStmt": - case "IfStmt": - case "WhileStmt": - case "BreakStmt": - case "Param": - case "IdentTy": - case "PtrTy": - case "PtrMutTy": - case "ArrayTy": - case "SliceTy": - throw new Error(`node '${k.tag}' not an expression`); case "IdentExpr": { const sym = this.syms.get(expr); if (sym.tag === "Fn") { @@ -453,12 +443,7 @@ class TypeChecker { : null; if (sym?.tag === "Builtin") { if (sym.id === "len") { - checkArgs( - Ty.create("Fn", { - params: [Ty.Any], - retTy: Ty.USize, - }).as("Fn"), - ); + checkArgs(Ty.Fn([Ty.Any], Ty.USize, null).as("Fn")); return Ty.USize; } if (sym.id === "print") { @@ -483,8 +468,60 @@ class TypeChecker { } const callableTy = calleeTy.callableTy(); - checkArgs(callableTy); - return callableTy.kind.retTy; + + if (callableTy.kind.generics !== null) { + let generics: Ty[]; + if (k.generics) { + generics = k.generics.map((ty) => this.ty(ty)); + if ( + generics.length !== callableTy.kind.generics.length + ) { + this.reporter.error( + expr.loc, + `expected ${callableTy.kind.generics.length} generic type arguments, got ${generics.length}`, + ); + this.reporter.abort(); + } + } else { + generics = callableTy.kind.generics + .map((_ty) => Ty.Any); + } + + const newCalleeTy = Ty.Fn( + callableTy.kind.params + .map((ty) => + ty.is("Generic") ? generics[ty.kind.idx] : ty + ), + callableTy.kind.retTy.is("Generic") + ? generics[callableTy.kind.retTy.kind.idx] + : callableTy.kind.retTy, + null, + ).as("Fn"); + this.rewriteTree(k.value, newCalleeTy); + + checkArgs( + Ty.Fn( + callableTy.kind.params + .map((ty) => + ty.is("Generic") + ? generics[ty.kind.idx] + : ty + ), + callableTy.kind.retTy.is("Generic") + ? generics[callableTy.kind.retTy.kind.idx] + : callableTy.kind.retTy, + null, + ).as("Fn"), + ); + return callableTy.kind.retTy; + } else { + if (k.generics) { + this.reporter.error(expr.loc, "no generics expected"); + this.reporter.abort(); + } + checkArgs(callableTy); + return callableTy.kind.retTy; + } } case "UnaryExpr": { switch (k.op) { @@ -578,6 +615,25 @@ class TypeChecker { } return Ty.create("Range", {}); } + case "Error": + case "File": + case "Block": + case "ExprStmt": + case "AssignStmt": + case "FnStmt": + case "ReturnStmt": + case "LetStmt": + case "IfStmt": + case "WhileStmt": + case "BreakStmt": + case "Param": + case "IdentTy": + case "PtrTy": + case "PtrMutTy": + case "ArrayTy": + case "SliceTy": + case "Generic": + throw new Error(`node '${k.tag}' not an expression`); default: k satisfies never; throw new Error(); @@ -591,28 +647,6 @@ class TypeChecker { private checkTy(ty: ast.Node): Ty { const k = ty.kind; switch (k.tag) { - case "Error": - case "File": - case "Block": - case "ExprStmt": - case "AssignStmt": - case "FnStmt": - case "ReturnStmt": - case "LetStmt": - case "IfStmt": - case "WhileStmt": - case "BreakStmt": - case "Param": - case "IdentExpr": - case "IntExpr": - case "StrExpr": - case "ArrayExpr": - case "IndexExpr": - case "CallExpr": - case "UnaryExpr": - case "BinaryExpr": - case "RangeExpr": - throw new Error(`node '${k.tag}' not a type`); case "IdentTy": { const sym = this.syms.get(ty); if (sym.tag === "BuiltinTy") { @@ -647,6 +681,12 @@ class TypeChecker { ); } } + if (sym.tag === "Generic") { + if (!this.generics) { + throw new Error(); + } + return this.generics[sym.idx]; + } this.reporter.error(ty.loc, `symbol is not a type`); return this.reporter.abort(); } @@ -682,6 +722,29 @@ class TypeChecker { const ty = this.ty(k.ty); return Ty.create("Slice", { ty }); } + case "Error": + case "File": + case "Block": + case "ExprStmt": + case "AssignStmt": + case "FnStmt": + case "ReturnStmt": + case "LetStmt": + case "IfStmt": + case "WhileStmt": + case "BreakStmt": + case "Param": + case "IdentExpr": + case "IntExpr": + case "StrExpr": + case "ArrayExpr": + case "IndexExpr": + case "CallExpr": + case "UnaryExpr": + case "BinaryExpr": + case "RangeExpr": + case "Generic": + throw new Error(`node '${k.tag}' not a type`); } } diff --git a/src/front/parse.ts b/src/front/parse.ts index bc7eacd..95f266c 100644 --- a/src/front/parse.ts +++ b/src/front/parse.ts @@ -71,6 +71,7 @@ export class Parser { const loc = this.loc(); this.step(); const ident = this.mustEat("ident").value; + const genericParams = this.parseGenericParams(); this.mustEat("("); const params: ast.Node[] = []; if (!this.test(")")) { @@ -88,7 +89,13 @@ export class Parser { retTy = this.parseTy(); } const body = this.parseBlock(); - return ast.Node.create(loc, "FnStmt", { ident, params, retTy, body }); + return ast.Node.create(loc, "FnStmt", { + ident, + genericParams, + params, + retTy, + body, + }); } parseReturnStmt(): ast.Node { @@ -259,19 +266,12 @@ export class Parser { const arg = this.parseExpr(); this.mustEat("]"); expr = ast.Node.create(loc, "IndexExpr", { value: expr, arg }); + } else if (this.test("::<")) { + const generics = this.parseGenericArgs(); + this.mustEat("("); + expr = this.parseCallExprTail(expr, loc, generics); } else if (this.eat("(")) { - const args: ast.Node[] = []; - if (!this.test(")")) { - args.push(this.parseExpr()); - while (this.eat(",")) { - if (this.done || this.test(")")) { - break; - } - args.push(this.parseExpr()); - } - } - this.mustEat(")"); - expr = ast.Node.create(loc, "CallExpr", { value: expr, args }); + expr = this.parseCallExprTail(expr, loc, null); } else { break; } @@ -279,6 +279,29 @@ export class Parser { return expr; } + parseCallExprTail( + expr: ast.Node, + loc: Loc, + generics: ast.Node[] | null, + ): ast.Node { + const args: ast.Node[] = []; + if (!this.test(")")) { + args.push(this.parseExpr()); + while (this.eat(",")) { + if (this.done || this.test(")")) { + break; + } + args.push(this.parseExpr()); + } + } + this.mustEat(")"); + return ast.Node.create(loc, "CallExpr", { + value: expr, + generics, + args, + }); + } + parseOperand(): ast.Node { const loc = this.loc(); if (this.test("ident")) { @@ -361,6 +384,38 @@ export class Parser { } } + parseGenericArgs(): ast.Node[] | null { + if (!this.eat("::<")) { + return null; + } + const args: ast.Node[] = []; + while (!this.done && !this.test("<")) { + args.push(this.parseTy()); + if (!this.eat(",")) { + break; + } + } + this.mustEat(">"); + return args; + } + + parseGenericParams(): ast.Node[] | null { + if (!this.eat("<")) { + return null; + } + const params: ast.Node[] = []; + while (!this.done && !this.test("<")) { + const loc = this.loc(); + const identTok = this.mustEat("ident"); + params.push(ast.create(loc, "Generic", { ident: identTok.value })); + if (!this.eat(",")) { + break; + } + } + this.mustEat(">"); + return params; + } + private mustEat(type: string, loc = this.loc()): Tok { const tok = this.current; if (tok.type !== type) { @@ -423,7 +478,7 @@ const keywordPattern = /^(?:(?:fn)|(?:return)|(?:let)|(?:if)|(?:else)|(?:while)|(?:break)|(?:or)|(?:and)|(?:not)|(?:mut))/; const operatorPattern2 = - /((?:\->)|(?:==)|(?:!=)|(?:<=)|(?:>=)|(?:<<)|(?:>>)|(?:\.\*)|(?:\.\.)|(?:\.\.=)|[\n\(\)\{\}\[\]\,\.\;\:\!\=\<\>\&\^\|\+\-\*\/\%])/g; + /((?:\->)|(?:==)|(?:!=)|(?:<=)|(?:>=)|(?:\:\:<)|(?:<<)|(?:>>)|(?:\.\*)|(?:\.\.)|(?:\.\.=)|[\n\(\)\{\}\[\]\,\.\;\:\!\=\<\>\&\^\|\+\-\*\/\%])/g; export function tokenize(text: string, reporter: FileReporter): Tok[] { return new Lexer() diff --git a/src/front/resolve.ts b/src/front/resolve.ts index f9cc4bd..50916b3 100644 --- a/src/front/resolve.ts +++ b/src/front/resolve.ts @@ -19,6 +19,12 @@ export type Sym = | { tag: "Bool"; value: boolean } | { tag: "Builtin"; id: string } | { tag: "Fn"; stmt: ast.NodeWithKind<"FnStmt"> } + | { + tag: "Generic"; + stmt: ast.Node; + generic: ast.NodeWithKind<"Generic">; + idx: number; + } | { tag: "FnParam"; stmt: ast.NodeWithKind<"FnStmt">; @@ -61,6 +67,19 @@ export function resolve( if (k.tag === "FnStmt") { ast.assertNodeWithKind(node, "FnStmt"); syms = ResolverSyms.forkFrom(syms); + if (k.genericParams) { + for (const [idx, param] of k.genericParams?.entries()) { + ast.assertNodeWithKind(param, "Generic"); + const sym: Sym = { + tag: "Generic", + stmt: node, + generic: param, + idx, + }; + syms.define(param.kind.ident, sym); + resols.set(param.id, sym); + } + } for (const [idx, param] of k.params.entries()) { ast.assertNodeWithKind(param, "Param"); const sym: Sym = { tag: "FnParam", stmt: node, param, idx }; diff --git a/src/ty.ts b/src/ty.ts index cb2e143..11e2420 100644 --- a/src/ty.ts +++ b/src/ty.ts @@ -48,6 +48,15 @@ export class Ty { static Array(ty: Ty, length: number): Ty { return this.create("Array", { ty, length }); } + static Fn(params: Ty[], retTy: Ty, generics: Ty[] | null): Ty { + return this.create("Fn", { params, retTy, generics }); + } + static Generic(idx: number): Ty { + return this.create("Generic", { idx }); + } + static Instance(ty: Ty, args: Ty[]): Ty { + return this.create("Instance", { ty, args }); + } /** Only used in type checker. */ static Any = Ty.create("Any", {}); @@ -67,6 +76,9 @@ export class Ty { } private internHash(): string { + if (this.is("FnStmt")) { + return JSON.stringify({ ...this.kind, stmt: this.kind.stmt.id }); + } return JSON.stringify(this.kind); } @@ -244,8 +256,10 @@ export type TyKind = | { tag: "Array"; ty: Ty; length: number } | { tag: "Slice"; ty: Ty } | { tag: "Range" } - | { tag: "Fn"; params: Ty[]; retTy: Ty } + | { tag: "Fn"; params: Ty[]; retTy: Ty; generics: Ty[] | null } | { tag: "FnStmt"; ty: Ty; stmt: ast.NodeWithKind<"FnStmt"> } + | { tag: "Generic"; idx: number } + | { tag: "Instance"; ty: Ty; args: Ty[] } | { tag: "Any" | "AnyInt" } | { tag: "AnyIndexable"; ty: Ty } | { tag: "AnyCallable"; ty: Ty }