karlkode generic type inferrence

This commit is contained in:
sfja 2024-12-26 03:56:59 +01:00
parent 5150090d2d
commit a4c1b60a61
6 changed files with 329 additions and 68 deletions

View File

@ -1,5 +1,5 @@
import { Pos } from "./token.ts"; import { Pos } from "./token.ts";
import { VType } from "./vtype.ts"; import { GenericArgsMap, VType } from "./vtype.ts";
export type Stmt = { export type Stmt = {
kind: StmtKind; kind: StmtKind;
@ -48,7 +48,12 @@ export type ExprKind =
| { type: "group"; expr: Expr } | { type: "group"; expr: Expr }
| { type: "field"; subject: Expr; ident: string } | { type: "field"; subject: Expr; ident: string }
| { type: "index"; subject: Expr; value: Expr } | { type: "index"; subject: Expr; value: Expr }
| { type: "call"; subject: Expr; args: Expr[] } | {
type: "call";
subject: Expr;
args: Expr[];
genericArgs?: GenericArgsMap;
}
| { type: "path"; subject: Expr; ident: string } | { type: "path"; subject: Expr; ident: string }
| { type: "etype_args"; subject: Expr; etypeArgs: EType[] } | { type: "etype_args"; subject: Expr; etypeArgs: EType[] }
| { type: "unary"; unaryType: UnaryType; subject: Expr } | { type: "unary"; unaryType: UnaryType; subject: Expr }

View File

@ -414,23 +414,111 @@ export class Checker {
const subject = this.checkExpr(expr.kind.subject); const subject = this.checkExpr(expr.kind.subject);
if (subject.type === "error") return subject; if (subject.type === "error") return subject;
if (subject.type === "fn") { if (subject.type === "fn") {
if (subject.genericParams !== undefined) { if (expr.kind.args.length !== subject.params.length) {
throw new Error("😭😭😭");
}
const args = expr.kind.args.map((arg) => this.checkExpr(arg));
if (args.length !== subject.params.length) {
this.report( this.report(
`incorrect number of arguments` + `incorrect number of arguments` +
`, expected ${subject.params.length}`, `, expected ${subject.params.length}`,
pos, pos,
); );
} }
const args = expr.kind.args.map((arg) => this.checkExpr(arg));
if (subject.genericParams === undefined) {
return this.checkCallExprNoGenericsTail(
expr,
subject,
args,
pos,
);
}
return this.checkCallExprInferredGenericsTail(
expr,
subject,
args,
pos,
);
}
if (subject.type === "generic_spec" && subject.subject.type === "fn") {
return this.checkCallExprExplicitGenericsTail(expr, subject);
}
this.report("cannot call non-fn", pos);
return { type: "error" };
}
private checkCallExprNoGenericsTail(
expr: Expr,
subject: VType,
args: VType[],
pos: Pos,
): VType {
if (
expr.kind.type !== "call" || subject.type !== "fn"
) {
throw new Error();
}
for (let i = 0; i < args.length; ++i) {
if (this.vtypeContainsGeneric(args[i])) {
this.report(
`amfibious generic parameter for argument ${i}, please specify generic types explicitly`,
pos,
);
return { type: "error" };
}
}
for (let i = 0; i < args.length; ++i) { for (let i = 0; i < args.length; ++i) {
if (!vtypesEqual(args[i], subject.params[i].vtype)) { if (!vtypesEqual(args[i], subject.params[i].vtype)) {
this.report(
`incorrect argument ${i} '${subject.params[i].ident}'` +
`, expected ${vtypeToString(subject.params[i].vtype)}` +
`, got ${vtypeToString(args[i])}`,
pos,
);
break;
}
}
return subject.returnType;
}
private checkCallExprInferredGenericsTail(
expr: Expr,
subject: VType,
args: VType[],
pos: Pos,
): VType {
if (
expr.kind.type !== "call" || subject.type !== "fn" ||
subject.genericParams === undefined
) {
throw new Error();
}
const genericArgsRes = this.inferGenericArgs(
subject.genericParams,
subject.params,
args,
pos,
);
if (!genericArgsRes.ok) {
return { type: "error" };
}
const genericArgs = genericArgsRes.value;
for (let i = 0; i < args.length; ++i) {
const vtypeCompatible = vtypesEqual(
args[i],
subject.params[i].vtype,
genericArgs,
);
if (!vtypeCompatible) {
this.report( this.report(
`incorrect argument ${i} '${subject.params[i].ident}'` + `incorrect argument ${i} '${subject.params[i].ident}'` +
`, expected ${ `, expected ${
vtypeToString(subject.params[i].vtype) vtypeToString(
extractGenericType(
subject.params[i].vtype,
genericArgs,
),
)
}` + }` +
`, got ${vtypeToString(args[i])}`, `, got ${vtypeToString(args[i])}`,
pos, pos,
@ -438,9 +526,99 @@ export class Checker {
break; break;
} }
} }
return subject.returnType;
expr.kind.genericArgs = genericArgs;
return this.concretizeVType(subject.returnType, genericArgs);
} }
if (subject.type === "generic_spec" && subject.subject.type === "fn") {
private inferGenericArgs(
genericParams: VTypeGenericParam[],
params: VTypeParam[],
args: VType[],
pos: Pos,
): { ok: true; value: GenericArgsMap } | { ok: false } {
const genericArgs: GenericArgsMap = {};
for (let i = 0; i < params.length; ++i) {
if (!this.vtypeContainsGeneric(params[i].vtype)) {
continue;
}
const {
a: generic,
b: concrete,
} = this.reduceToSignificant(params[i].vtype, args[i]);
if (generic.type !== "generic") {
throw new Error();
}
const paramId = generic.param.id;
if (
paramId in genericArgs &&
!vtypesEqual(genericArgs[paramId], concrete)
) {
this.report(
`according to inferrence, argument ${i} has a conflicting type`,
pos,
);
return { ok: false };
}
genericArgs[paramId] = concrete;
}
for (const param of genericParams) {
if (!(param.id in genericArgs)) {
this.report(`could not infer generic type ${param.ident}`, pos);
return { ok: false };
}
}
return { ok: true, value: genericArgs };
}
private reduceToSignificant(a: VType, b: VType): { a: VType; b: VType } {
if (a.type !== b.type) {
return { a, b };
}
if (a.type === "array" && b.type === "array") {
return this.reduceToSignificant(a.inner, b.inner);
}
throw new Error("idk what to do here");
}
private vtypeContainsGeneric(vtype: VType): boolean {
switch (vtype.type) {
case "error":
case "string":
case "unknown":
case "null":
case "int":
case "bool":
return false;
case "array":
return this.vtypeContainsGeneric(vtype.inner);
case "struct":
return vtype.fields.some((field) =>
this.vtypeContainsGeneric(field.vtype)
);
case "fn":
throw new Error("not implemented");
case "generic":
return true;
case "generic_spec":
throw new Error("listen kid, grrrrrrrr");
}
}
private checkCallExprExplicitGenericsTail(
expr: Expr,
subject: VType,
): VType {
if (
expr.kind.type !== "call" || subject.type !== "generic_spec" ||
subject.subject.type !== "fn"
) {
throw new Error();
}
const pos = expr.pos;
const inner = subject.subject; const inner = subject.subject;
const params = inner.params; const params = inner.params;
const args = expr.kind.args.map((arg) => this.checkExpr(arg)); const args = expr.kind.args.map((arg) => this.checkExpr(arg));
@ -474,13 +652,46 @@ export class Checker {
break; break;
} }
} }
return extractGenericType(
expr.kind.genericArgs = subject.genericArgs;
return this.concretizeVType(
subject.subject.returnType, subject.subject.returnType,
subject.genericArgs, subject.genericArgs,
); );
} }
this.report("cannot call non-fn", pos);
return { type: "error" }; private concretizeVType(
vtype: VType,
generics: GenericArgsMap,
): VType {
switch (vtype.type) {
case "error":
case "unknown":
case "string":
case "null":
case "int":
case "bool":
return vtype;
case "array":
return {
type: "array",
inner: this.concretizeVType(vtype.inner, generics),
};
case "struct":
return {
type: "struct",
fields: vtype.fields.map((field) => ({
...field,
vtype: this.concretizeVType(field.vtype, generics),
})),
};
case "fn":
throw new Error("not implemented");
case "generic":
return generics[vtype.param.id];
case "generic_spec":
throw new Error("not implemented");
}
} }
public checkPathExpr(expr: Expr): VType { public checkPathExpr(expr: Expr): VType {

View File

@ -50,7 +50,7 @@ export class Compiler {
const lowerer = new Lowerer(monoFns, callMap, lexer.currentPos()); const lowerer = new Lowerer(monoFns, callMap, lexer.currentPos());
const { program, fnNames } = lowerer.lower(); const { program, fnNames } = lowerer.lower();
//lowerer.printProgram(); lowerer.printProgram();
return { program, fnNames }; return { program, fnNames };
} }

View File

@ -3,6 +3,7 @@ import { AstVisitor, visitExpr, VisitRes, visitStmts } from "./ast_visitor.ts";
import { GenericArgsMap, VType } from "./vtype.ts"; import { GenericArgsMap, VType } from "./vtype.ts";
export class Monomorphizer { export class Monomorphizer {
private fnIdCounter = 0;
private fns: MonoFnsMap = {}; private fns: MonoFnsMap = {};
private callMap: MonoCallNameGenMap = {}; private callMap: MonoCallNameGenMap = {};
private allFns: Map<number, Stmt>; private allFns: Map<number, Stmt>;
@ -22,11 +23,13 @@ export class Monomorphizer {
stmt: Stmt, stmt: Stmt,
genericArgs?: GenericArgsMap, genericArgs?: GenericArgsMap,
): MonoFn { ): MonoFn {
const nameGen = monoFnNameGen(stmt, genericArgs); const id = this.fnIdCounter;
this.fnIdCounter += 1;
const nameGen = monoFnNameGen(id, stmt, genericArgs);
if (nameGen in this.fns) { if (nameGen in this.fns) {
return this.fns[nameGen]; return this.fns[nameGen];
} }
const monoFn = { nameGen, stmt, genericArgs }; const monoFn = { id, nameGen, stmt, genericArgs };
this.fns[nameGen] = monoFn; this.fns[nameGen] = monoFn;
const calls = new CallCollector().collect(stmt); const calls = new CallCollector().collect(stmt);
for (const call of calls) { for (const call of calls) {
@ -34,7 +37,10 @@ export class Monomorphizer {
if (call.kind.type !== "call") { if (call.kind.type !== "call") {
throw new Error(); throw new Error();
} }
if (call.kind.subject.vtype?.type === "fn") { if (
call.kind.subject.vtype?.type === "fn" &&
call.kind.subject.vtype.genericParams === undefined
) {
const fn = this.allFns.get(call.kind.subject.vtype.stmtId); const fn = this.allFns.get(call.kind.subject.vtype.stmtId);
if (fn === undefined) { if (fn === undefined) {
throw new Error(); throw new Error();
@ -43,6 +49,40 @@ export class Monomorphizer {
this.callMap[call.id] = monoFn.nameGen; this.callMap[call.id] = monoFn.nameGen;
continue; continue;
} }
if (
call.kind.subject.vtype?.type === "fn" &&
call.kind.subject.vtype.genericParams !== undefined
) {
if (call.kind.genericArgs === undefined) {
throw new Error();
}
const genericArgs = call.kind.genericArgs;
const monoArgs: GenericArgsMap = {};
for (const key in call.kind.genericArgs) {
const vtype = genericArgs[key];
if (vtype.type === "generic") {
if (genericArgs === undefined) {
throw new Error();
}
monoArgs[key] = genericArgs[vtype.param.id];
} else {
monoArgs[key] = vtype;
}
}
const fnType = call.kind.subject.vtype!;
if (fnType.type !== "fn") {
throw new Error();
}
const fn = this.allFns.get(fnType.stmtId);
if (fn === undefined) {
throw new Error();
}
const monoFn = this.monomorphizeFn(fn, monoArgs);
this.callMap[call.id] = monoFn.nameGen;
continue;
}
if (call.kind.subject.vtype?.type === "generic_spec") { if (call.kind.subject.vtype?.type === "generic_spec") {
const genericSpecType = call.kind.subject.vtype!; const genericSpecType = call.kind.subject.vtype!;
if (genericSpecType.subject.type !== "fn") { if (genericSpecType.subject.type !== "fn") {
@ -85,6 +125,7 @@ export type MonoResult = {
export type MonoFnsMap = { [nameGen: string]: MonoFn }; export type MonoFnsMap = { [nameGen: string]: MonoFn };
export type MonoFn = { export type MonoFn = {
id: number;
nameGen: string; nameGen: string;
stmt: Stmt; stmt: Stmt;
genericArgs?: GenericArgsMap; genericArgs?: GenericArgsMap;
@ -92,7 +133,11 @@ export type MonoFn = {
export type MonoCallNameGenMap = { [exprId: number]: string }; export type MonoCallNameGenMap = { [exprId: number]: string };
function monoFnNameGen(stmt: Stmt, genericArgs?: GenericArgsMap): string { function monoFnNameGen(
id: number,
stmt: Stmt,
genericArgs?: GenericArgsMap,
): string {
if (stmt.kind.type !== "fn") { if (stmt.kind.type !== "fn") {
throw new Error(); throw new Error();
} }
@ -100,12 +145,12 @@ function monoFnNameGen(stmt: Stmt, genericArgs?: GenericArgsMap): string {
return "main"; return "main";
} }
if (genericArgs === undefined) { if (genericArgs === undefined) {
return `${stmt.kind.ident}_${stmt.id}`; return `${stmt.kind.ident}_${id}`;
} }
const args = Object.values(genericArgs) const args = Object.values(genericArgs)
.map((arg) => vtypeNameGenPart(arg)) .map((arg) => vtypeNameGenPart(arg))
.join("_"); .join("_");
return `${stmt.kind.ident}_${stmt.id}_${args}`; return `${stmt.kind.ident}_${id}_${args}`;
} }
function vtypeNameGenPart(vtype: VType): string { function vtypeNameGenPart(vtype: VType): string {

View File

@ -1,4 +1,4 @@
//
fn array_new<T>() -> [T] #[builtin(ArrayNew)] {} fn array_new<T>() -> [T] #[builtin(ArrayNew)] {}
fn array_push<T>(array: [T], value: T) #[builtin(ArrayPush)] {} fn array_push<T>(array: [T], value: T) #[builtin(ArrayPush)] {}
@ -8,12 +8,12 @@ fn array_at<T>(array: [T], index: int) -> string #[builtin(ArrayAt)] {}
fn main() { fn main() {
let strings = array_new::<string>(); let strings = array_new::<string>();
array_push::<string>(strings, "hello"); array_push(strings, "hello");
array_push::<string>(strings, "world"); array_push(strings, "world");
let ints = array_new::<int>(); let ints = array_new::<int>();
array_push::<int>(ints, 1); array_push(ints, 1);
array_push::<int>(ints, 2); array_push(ints, 2);
} }

View File

@ -13,7 +13,7 @@ fi
echo Compiling $1... echo Compiling $1...
deno run --allow-read --allow-write compiler/main.ts $1 deno run --allow-read --allow-write --check compiler/main.ts $1
echo Running out.slgbc... echo Running out.slgbc...