diff --git a/src/ast.ts b/src/ast.ts index d652445..0158656 100644 --- a/src/ast.ts +++ b/src/ast.ts @@ -170,7 +170,8 @@ export type IntTy = | "u16" | "u32" | "u64" - | "usize"; + | "usize" + | "any"; export type UnaryOp = | "Not" diff --git a/src/front/check.ts b/src/front/check.ts index 5c1c151..c9d5629 100644 --- a/src/front/check.ts +++ b/src/front/check.ts @@ -256,6 +256,8 @@ class ExprChecker { return Ty.U32; case "u64": return Ty.U64; + case "usize": + return Ty.USize; case "i8": return Ty.U8; case "i16": @@ -264,6 +266,10 @@ class ExprChecker { return Ty.I32; case "i64": return Ty.U64; + case "isize": + return Ty.ISize; + case "any": + return Ty.I32; default: throw new Error(`intType '${node.kind.intTy}' not handled`); } diff --git a/src/front/check2.ts b/src/front/check2.ts index 4699d3c..1042184 100644 --- a/src/front/check2.ts +++ b/src/front/check2.ts @@ -2,6 +2,7 @@ import { Syms } from "./resolve.ts"; import * as ast from "../ast.ts"; import { Ty } from "../ty.ts"; import { FileReporter, Loc } from "../diagnostics.ts"; +import { exit } from "node:process"; export function checkFn( fn: ast.FnStmt, @@ -57,22 +58,107 @@ class TypeChecker { this.fn.kind.body.visit({ visit: (node) => { - if (!node.is("ReturnStmt")) { - return; - } - const ty = node.kind.expr - ? this.expr(node.kind.expr, retTy) - : Ty.Void; - if (!ty.resolvableWith(retTy)) { - this.reporter.error( - node.loc, - `type '${ty.pretty()}' not compatible with return type '${retTy.pretty()}'`, - ); - this.reporter.info( - this.fn.kind.retTy?.loc ?? this.fn.loc, - `return type '${retTy}' defined here`, - ); - this.reporter.abort(); + const k = node.kind; + switch (k.tag) { + case "ReturnStmt": { + const ty = k.expr ? this.expr(k.expr, retTy) : Ty.Void; + if (!ty.resolvableWith(retTy)) { + this.reporter.error( + node.loc, + `type '${ty.pretty()}' not compatible with return type '${retTy.pretty()}'`, + ); + this.reporter.info( + this.fn.kind.retTy?.loc ?? this.fn.loc, + `return type '${retTy}' defined here`, + ); + this.reporter.abort(); + } + break; + } + case "LetStmt": { + let ty: Ty; + const paramTy = k.param.as("Param").kind.ty; + if (paramTy) { + const explicitTy = this.ty(paramTy); + const exprTy = this.expr(k.expr, explicitTy); + const res = this.resolve( + exprTy, + explicitTy, + k.expr.loc, + ); + if (res.rewriteSubtree) { + this.rewriteTree(k.expr, res.ty); + } + ty = res.ty; + } else { + ty = this.expr(k.expr, Ty.Any); + } + this.nodeTys.set(node, ty); + break; + } + case "ExprStmt": { + this.expr(k.expr, Ty.Any); + break; + } + case "AssignStmt": { + const placeTy = this.expr(k.place, Ty.Any); + const exprTy = this.expr(k.expr, Ty.Any); + if (!placeTy.resolvableWith(exprTy)) { + this.reporter.error( + k.expr.loc, + `type '${exprTy.pretty()}' not assignable to type '${placeTy.pretty()}'`, + ); + this.reporter.abort(); + } + break; + } + case "IfStmt": { + const condTy = this.expr(k.cond, Ty.Bool); + if (!condTy.is("Bool")) { + this.reporter.error( + k.cond.loc, + `expected condition to be 'bool', got '${condTy.pretty()}'`, + ); + this.reporter.abort(); + } + break; + } + case "WhileStmt": { + const condTy = this.expr(k.cond, Ty.Bool); + if (!condTy.is("Bool")) { + this.reporter.error( + k.cond.loc, + `expected condition to be 'bool', got '${condTy.pretty()}'`, + ); + this.reporter.abort(); + } + break; + } + case "BreakStmt": { + break; + } + case "FnStmt": + case "Error": + case "File": + case "Block": + case "Param": + case "IdentExpr": + case "IntExpr": + case "StrExpr": + case "ArrayExpr": + case "IndexExpr": + case "CallExpr": + case "UnaryExpr": + case "BinaryExpr": + case "RangeExpr": + case "IdentTy": + case "PtrTy": + case "PtrMutTy": + case "ArrayTy": + case "SliceTy": + break; + default: + k satisfies never; } }, }); @@ -111,10 +197,64 @@ class TypeChecker { case "SliceTy": throw new Error(`node '${k.tag}' not an expression`); case "IdentExpr": { - throw new Error("not implemented"); + const sym = this.syms.get(expr); + if (sym.tag === "Fn") { + throw new Error("todo"); + } + if (sym.tag === "Bool") { + return Ty.Bool; + } + if (sym.tag === "Builtin") { + this.reporter.error( + expr.loc, + `invalid use of builtin '${sym.id}'`, + ); + this.reporter.abort(); + } + if (sym.tag === "FnParam") { + return this.expr(sym.param, Ty.Any); + } + if (sym.tag === "Let") { + const ty = this.nodeTys.get(sym.stmt); + if (!ty) { + throw new Error(); + } + return ty; + } + throw new Error(`'${sym.tag}' not handled`); } case "IntExpr": { - return this.resolve(Ty.AnyInt, expected, expr.loc).ty; + const intTy = (() => { + switch (k.intTy) { + case "u8": + return Ty.U8; + case "u16": + return Ty.U16; + case "u32": + return Ty.U32; + case "u64": + return Ty.U64; + case "usize": + return Ty.ISize; + case "i8": + return Ty.U8; + case "i16": + return Ty.U16; + case "i32": + return Ty.I32; + case "i64": + return Ty.U64; + case "isize": + return Ty.ISize; + case "any": + return Ty.AnyInt; + default: + throw new Error( + `intType '${k.intTy}' not handled`, + ); + } + })(); + return this.resolve(intTy, expected, expr.loc).ty; } case "StrExpr": { return this.resolve( @@ -168,19 +308,191 @@ class TypeChecker { return res.ty; } case "IndexExpr": { - throw new Error("not implemented"); + const innerTy = this.expr( + k.value, + Ty.Indexable(expected), + ); + if (!innerTy.isIndexable()) { + this.reporter.error( + expr.loc, + `expected indexable type, got '${expected.pretty()}'`, + ); + this.reporter.abort(); + } + + let argTy = this.expr(k.arg, Ty.Any); + + if (argTy.is("AnyInt")) { + argTy = Ty.USize; + this.rewriteTree(k.arg, argTy); + return innerTy.indexableTy()!; + } else if (argTy.is("Range")) { + return Ty.Slice(innerTy.indexableTy()!); + } else { + throw new Error(); + } } case "CallExpr": { - throw new Error("not implemented"); + const checkArgs = ( + callableTy: Ty & { kind: { tag: "Fn" } }, + ) => { + const params = callableTy.kind.params; + if (k.args.length !== params.length) { + this.reporter.error( + expr.loc, + `incorrect amount of arguments. got ${k.args.length} expected ${params.length}`, + ); + if (calleeTy?.is("FnStmt")) { + this.reporter.info( + calleeTy.kind.stmt.loc, + "function defined here", + ); + } + this.reporter.abort(); + } + const args = k.args + .map((arg, i) => this.expr(arg, params[i])); + for (const i of args.keys()) { + this.resolve(args[i], params[i], k.args[i].loc, () => { + if (calleeTy?.is("FnStmt")) { + this.reporter.info( + calleeTy.kind.stmt.kind.params[i].loc, + `parameter '${ + calleeTy.kind.stmt.kind.params[i] + .as("Param").kind.ident + }' defined here`, + ); + } + this.reporter.error( + k.args[i].loc, + `type '${ + args[i].pretty() + }' not compatible with type '${ + params[i].pretty() + }', for argument ${i}`, + ); + }); + } + }; + + const sym = k.value.is("IdentExpr") + ? this.syms.get(k.value) + : null; + if (sym?.tag === "Builtin") { + if (sym.id === "len") { + checkArgs( + Ty.create("Fn", { + params: [Ty.Any], + retTy: Ty.USize, + }).as("Fn"), + ); + return Ty.USize; + } + if (sym.id === "print") { + void k.args + .map((arg) => this.expr(arg, Ty.Any)); + return Ty.Void; + } + throw new Error(`builtin '${sym.id}' not handled`); + } + + const calleeTy = this.expr(k.value, Ty.Callable(expected)); + if (!calleeTy.isCallable()) { + this.reporter.error( + expr.loc, + `expected callable type, got '${expected.pretty()}'`, + ); + this.reporter.abort(); + } + + const callableTy = calleeTy.callableTy(); + checkArgs(callableTy); + return callableTy.kind.retTy; } case "UnaryExpr": { - throw new Error("not implemented"); + switch (k.op) { + case "Not": { + const ty = this.expr(k.expr, Ty.Bool); + if (!ty.is("Bool")) { + this.reporter.error( + expr.loc, + `expected 'bool', got '${expected.pretty()}'`, + ); + this.reporter.abort(); + } + return ty; + } + case "Neg": { + const ty = this.expr(k.expr, expected); + if (!ty.is("Int")) { + this.reporter.error( + expr.loc, + `cannot apply ! to '${expected.pretty()}'`, + ); + this.reporter.abort(); + } + return ty; + } + case "Ref": { + const ty = this.expr(k.expr, expected); + return Ty.Ptr(ty); + } + case "RefMut": { + const ty = this.expr(k.expr, expected); + return Ty.Ptr(ty); + } + case "Deref": { + const ty = this.expr(k.expr, expected); + if (!ty.is("Ptr") && !ty.is("PtrMut")) { + this.reporter.error( + expr.loc, + `cannot dereference type '${expected.pretty()}'`, + ); + this.reporter.abort(); + } + if (!ty.kind.ty.isSized()) { + this.reporter.error( + expr.loc, + `cannot dereference unsized type '${ty.kind.ty.pretty()}' in an expression`, + ); + this.reporter.abort(); + } + return ty.kind.ty; + } + default: + k satisfies never; + throw new Error(); + } } case "BinaryExpr": { - throw new Error("not implemented"); + const left = this.expr(k.left, expected); + const right = this.expr(k.right, expected); + const result = binaryOpTests + .map((test) => test(k.op, left, right)) + .filter((result) => result) + .at(0); + if (!result) { + this.reporter.error( + expr.loc, + `operator '${k.tok}' cannot be applied to types '${left.pretty()}' and '${right.pretty()}'`, + ); + this.reporter.abort(); + } + return result; } case "RangeExpr": { - throw new Error("not implemented"); + for (const operandExpr of [k.begin, k.end]) { + const operandTy = operandExpr && + this.expr(operandExpr, Ty.USize); + if (operandTy && !operandTy.resolvableWith(Ty.I32)) { + this.reporter.error( + operandExpr.loc, + `range operand must be '${Ty.I32.pretty()}', not '${operandTy.pretty()}'`, + ); + this.reporter.abort(); + } + } + return Ty.create("Range", {}); } default: k satisfies never; @@ -193,10 +505,103 @@ class TypeChecker { } private checkTy(ty: ast.Node): Ty { - throw new Error("not implemented"); + const k = ty.kind; + switch (k.tag) { + case "Error": + case "File": + case "Block": + case "ExprStmt": + case "AssignStmt": + case "FnStmt": + case "ReturnStmt": + case "LetStmt": + case "IfStmt": + case "WhileStmt": + case "BreakStmt": + case "Param": + case "IdentExpr": + case "IntExpr": + case "StrExpr": + case "ArrayExpr": + case "IndexExpr": + case "CallExpr": + case "UnaryExpr": + case "BinaryExpr": + case "RangeExpr": + throw new Error(`node '${k.tag}' not a type`); + case "IdentTy": { + const sym = this.syms.get(ty); + if (sym.tag === "BuiltinTy") { + switch (sym.ident) { + case "void": + return Ty.Void; + case "bool": + return Ty.Bool; + case "i8": + return Ty.I8; + case "i16": + return Ty.I16; + case "i32": + return Ty.I32; + case "i64": + return Ty.I64; + case "isize": + return Ty.ISize; + case "u8": + return Ty.U8; + case "u16": + return Ty.U16; + case "u32": + return Ty.U32; + case "u64": + return Ty.U64; + case "usize": + return Ty.USize; + default: + throw new Error( + `unknown type '${k.ident}'`, + ); + } + } + this.reporter.error(ty.loc, `symbol is not a type`); + return this.reporter.abort(); + } + case "PtrTy": { + const ty = this.ty(k.ty); + return Ty.create("Ptr", { ty }); + } + case "PtrMutTy": { + const ty = this.ty(k.ty); + return Ty.create("PtrMut", { ty }); + } + case "ArrayTy": { + const ty = this.ty(k.ty); + const lengthTy = this.expr(k.length, Ty.USize); + if (!lengthTy.resolvableWith(Ty.I32)) { + this.reporter.error( + k.length.loc, + `for array length, expected 'int', got '${lengthTy.pretty()}'`, + ); + this.reporter.abort(); + } + if (!k.length.is("IntExpr")) { + this.reporter.error( + k.length.loc, + `array length must be an 'int' expression`, + ); + this.reporter.abort(); + } + const length = k.length.kind.value; + return Ty.create("Array", { ty, length }); + } + case "SliceTy": { + const ty = this.ty(k.ty); + return Ty.create("Slice", { ty }); + } + } } - private cachedCheck(node: ast.Node, func: () => Ty) { + private cachedCheck(node: ast.Node, func: () => Ty): Ty { const existing = this.nodeTys.get(node); if (existing !== undefined) { return existing; @@ -206,15 +611,32 @@ class TypeChecker { return ty; } - private resolve(ty: Ty, expected: Ty, loc: Loc): TyRes { + private resolve( + ty: Ty, + expected: Ty, + loc: Loc, + inCaseOfError?: () => void, + ): TyRes { + if (ty == expected) { + return { ty, rewriteSubtree: false }; + } + if (expected.is("Any")) { + return { ty, rewriteSubtree: false }; + } if (!ty.resolvableWith(expected)) { - this.reporter.error( - loc, - `expected type '${expected.pretty()}', got '${ty.pretty()}'`, - ); + if (inCaseOfError) { + inCaseOfError(); + } else { + this.reporter.error( + loc, + `expected type '${expected.pretty()}', got '${ty.pretty()}'`, + ); + } this.reporter.abort(); } - throw new Error("not implemented"); + throw new Error( + `resolving '${ty.pretty()}' with '${expected.pretty()}' not implemented`, + ); } private rewriteTree(node: ast.Node, ty: Ty) { @@ -230,3 +652,32 @@ type TyRes = { ty: Ty; rewriteSubtree: boolean; }; + +type BinaryOpTest = (op: ast.BinaryOp, left: Ty, right: Ty) => Ty | null; + +const binaryOpTests: BinaryOpTest[] = [ + (op, left, right) => { + const ops: ast.BinaryOp[] = [ + "Add", + "Sub", + "Mul", + "Div", + "Rem", + ]; + if ( + ops.includes(op) && left.is("Int") && left.resolvableWith(right) + ) { + return left; + } + return null; + }, + (op, left, right) => { + const ops: ast.BinaryOp[] = ["Eq", "Ne", "Lt", "Gt", "Lte", "Gte"]; + if ( + ops.includes(op) && left.is("Int") && left.resolvableWith(right) + ) { + return Ty.Bool; + } + return null; + }, +]; diff --git a/src/front/parse.ts b/src/front/parse.ts index 2965b0c..bc7eacd 100644 --- a/src/front/parse.ts +++ b/src/front/parse.ts @@ -292,7 +292,7 @@ export class Parser { throw new Error(); } const value = Number(match[1]); - const intTy = match[2] ?? "i32"; + const intTy = match[2]; if ( intTy && !["8", "16", "32", "64", "size"].includes(intTy.slice(1)) @@ -306,7 +306,7 @@ export class Parser { this.step(); return ast.Node.create(loc, "IntExpr", { value, - intTy: intTy as ast.IntTy ?? "i32", + intTy: intTy as ast.IntTy ?? "any", }); } else if (this.test("str")) { const value = this.current.value; diff --git a/src/main.ts b/src/main.ts index ac17e42..91acd9d 100644 --- a/src/main.ts +++ b/src/main.ts @@ -15,7 +15,6 @@ const fileRep = reporter.ofFile({ filename, text }); const fileAst = front.parse(text, fileRep); const syms = front.resolve(fileAst, fileRep); -const tys = new front.Tys(syms, fileRep); let mainFn: ast.NodeWithKind<"FnStmt"> | null = null; @@ -38,6 +37,37 @@ if (!mainFn) { Deno.exit(1); } +const tys = new front.Tys(syms, fileRep); +// const fnTys = front.checkFn(mainFn, syms, fileRep); +// +// fileAst.visit({ +// visit(node) { +// switch (node.kind.tag) { +// // case "IdentExpr": +// case "IntExpr": // case "StrExpr": +// // case "ArrayExpr": +// // case "IndexExpr": +// // case "CallExpr": +// // case "UnaryExpr": +// // case "BinaryExpr": +// // case "RangeExpr": +// { +// const oldTy = tys.expr(node); +// const newTy = fnTys.exprTy(node); +// if (oldTy !== newTy) { +// if (newTy.is("AnyInt") && oldTy.is("Int")) { +// break; +// } +// throw new Error( +// `'${newTy.pretty()}' != '${oldTy.pretty()}'`, +// ); +// } +// break; +// } +// } +// }, +// }); + const m = new middle.MiddleLowerer(syms, tys); const mainMiddleFn = m.lowerFn(mainFn); diff --git a/src/mir_interpreter.ts b/src/mir_interpreter.ts index 7e53d73..8029bd3 100644 --- a/src/mir_interpreter.ts +++ b/src/mir_interpreter.ts @@ -342,6 +342,8 @@ export class FnInterpreter { case "u64": case "usize": return value; + case "any": + throw new Error(); } })(value), }); diff --git a/src/stringify.ts b/src/stringify.ts index 36e0449..d98920e 100644 --- a/src/stringify.ts +++ b/src/stringify.ts @@ -104,7 +104,7 @@ export function tyPretty(ty: ty.Ty, colors = noColors): string { case "AnyInt": return `${c.typeIdent}{integer}`; } - return ""; + return `{${ty.kind.tag}}`; } export function mirFnPretty(fn: mir.Fn, colors = noColors): string { @@ -222,7 +222,11 @@ class MirFnPrettyStringifier { case "Void": return expr(``); case "Int": - return expr(` ${c.literal}${k.value}${k.intTy}`); + return expr( + ` ${c.literal}${k.value}${ + k.intTy !== "any" ? k.intTy : "" + }`, + ); case "Bool": return expr(` ${c.literal}${k.value}`); case "Str": diff --git a/src/ty.ts b/src/ty.ts index f406c28..562a0bb 100644 --- a/src/ty.ts +++ b/src/ty.ts @@ -53,6 +53,14 @@ export class Ty { static Any = Ty.create("Any", {}); /** Only used in type checker. */ static AnyInt = Ty.create("AnyInt", {}); + /** Only used in type checker. */ + static Indexable(ty: Ty): Ty { + return this.create("Indexable", { ty }); + } + /** Only used in type checker. */ + static Callable(ty: Ty): Ty { + return this.create("Callable", { ty }); + } private internHash(): string { return JSON.stringify(this.kind); @@ -177,13 +185,32 @@ export class Ty { return true; } - innerFn(): Ty & { kind: { tag: "Fn" } } { + isIndexable(): boolean { + return this.is("Array") || this.is("Slice") || this.is("Indexable"); + } + + indexableTy(): Ty | null { + if (this.is("Array") || this.is("Slice") || this.is("Indexable")) { + return this.kind.ty; + } + return null; + } + + isCallable(): boolean { + return this.is("FnStmt") || this.is("Fn"); + } + + callableTy(): Ty & { kind: { tag: "Fn" } } { + if (this.is("Fn")) { + return this; + } this.assertIs("FnStmt"); return this.kind.ty.as("Fn"); } isCheckerInternal(): boolean { - return this.is("Any") || this.is("AnyInt"); + return this.is("Any") || this.is("AnyInt") || this.is("Indexable") || + this.is("Callable"); } pretty(colors?: stringify.PrettyColors): string { @@ -203,4 +230,6 @@ export type TyKind = | { tag: "Range" } | { tag: "Fn"; params: Ty[]; retTy: Ty } | { tag: "FnStmt"; ty: Ty; stmt: ast.NodeWithKind<"FnStmt"> } - | { tag: "Any" | "AnyInt" }; + | { tag: "Any" | "AnyInt" } + | { tag: "Indexable"; ty: Ty } + | { tag: "Callable"; ty: Ty };