631 lines
18 KiB
TypeScript
631 lines
18 KiB
TypeScript
import * as ast from "../ast.ts";
|
|
import { FileReporter, Loc } from "../diagnostics.ts";
|
|
import { Ty } from "../ty.ts";
|
|
import { Sym, Syms } from "./resolve.ts";
|
|
|
|
/**
 * Public entry point of the type checker for one file.
 *
 * Holds a single `CheckerCx` and forwards every query to it, so all
 * results computed through this object share one memoization cache.
 */
export class Tys {
  private cx: CheckerCx;

  constructor(
    private syms: Syms,
    private reporter: FileReporter,
  ) {
    this.cx = new CheckerCx(this.syms, this.reporter);
  }

  /** Type of a function declaration (an `FnStmt` type wrapping its signature). */
  fnStmt(node: ast.NodeWithKind<"FnStmt">): Ty {
    const checker = this.cx;
    return checker.fnStmt(node);
  }

  /** Type bound by a parameter or `let` binding node. */
  param(node: ast.NodeWithKind<"Param">): Ty {
    const checker = this.cx;
    return checker.param(node);
  }

  /** Type of an expression used as a place (see `PlaceChecker`). */
  place(node: ast.Node): Ty {
    const checker = this.cx;
    return checker.place(node);
  }

  /** Type of an expression. */
  expr(node: ast.Node): Ty {
    const checker = this.cx;
    return checker.expr(node);
  }

  /** Semantic type denoted by a type-annotation node. */
  ty(node: ast.Node): Ty {
    const checker = this.cx;
    return checker.ty(node);
  }
}
|
|
|
|
/**
 * Shared state for checking one file.
 *
 * Owns the per-node type cache and the specialised checkers. It also
 * exposes symbol lookup and the diagnostic helpers (`error`, `info`,
 * `fail`) that those checkers use.
 */
class CheckerCx {
  /**
   * Computed types keyed by `node.id`. There is a single slot per node,
   * shared by every query kind (fnStmt/param/place/expr/ty).
   */
  private nodeTys = new Map<number, Ty>();

  // Each checker keeps a back-reference so it can recurse through the
  // memoized entry points below.
  private stmtChecker = new StmtChecker(this);
  private paramChecker = new ParamChecker(this);
  private placeChecker = new PlaceChecker(this);
  private exprChecker = new ExprChecker(this);
  private tyChecker = new TyChecker(this);

  constructor(
    private syms: Syms,
    private reporter: FileReporter,
  ) {}

  /** Memoized type of a function declaration. */
  fnStmt(node: ast.NodeWithKind<"FnStmt">): Ty {
    const compute = () => this.stmtChecker.checkFnStmt(node);
    return this.cache(node, compute);
  }

  /** Memoized type of a parameter / `let` binding. */
  param(node: ast.NodeWithKind<"Param">): Ty {
    const compute = () => this.paramChecker.checkParam(node);
    return this.cache(node, compute);
  }

  /** Memoized type of an expression used as a place. */
  place(node: ast.Node): Ty {
    const compute = () => this.placeChecker.checkPlace(node);
    return this.cache(node, compute);
  }

  /** Memoized type of an expression. */
  expr(node: ast.Node): Ty {
    const compute = () => this.exprChecker.checkExpr(node);
    return this.cache(node, compute);
  }

  /** Memoized semantic type of a type-annotation node. */
  ty(node: ast.Node): Ty {
    const compute = () => this.tyChecker.checkTy(node);
    return this.cache(node, compute);
  }

  /**
   * Returns the stored type for `node`, or runs `compute` and stores its
   * result. The result is stored only after `compute` returns.
   */
  private cache(node: ast.Node, compute: () => Ty): Ty {
    const hit = this.nodeTys.get(node.id);
    if (hit !== undefined) {
      return hit;
    }
    const computed = compute();
    this.nodeTys.set(node.id, computed);
    return computed;
  }

  /** Reports an error diagnostic at `loc`. */
  error(loc: Loc, message: string) {
    this.reporter.error(loc, message);
  }

  /** Reports an informational note at `loc`. */
  info(loc: Loc, message: string) {
    this.reporter.info(loc, message);
  }

  /** Aborts checking through the reporter; never returns. */
  fail(): never {
    this.reporter.abort();
  }

  /** Symbol the resolver associated with `node`. */
  sym(node: ast.Node): Sym {
    return this.syms.get(node);
  }

  /**
   * Reports an error at `loc` and aborts unless `left` is resolvable
   * with `right`.
   */
  assertCompatible(left: Ty, right: Ty, loc: Loc): void {
    if (left.resolvableWith(right)) {
      return;
    }
    this.error(
      loc,
      `type '${left.pretty()}' not compatible with type '${right.pretty()}'`,
    );
    this.fail();
  }
}
|
|
|
|
/**
 * Type-checks function declarations.
 */
class StmtChecker {
  constructor(
    private cx: CheckerCx,
  ) {}

  /**
   * Builds the type of a `fn` statement and checks its `return` statements.
   *
   * Parameter types come from `cx.param`. An omitted return type means
   * `void`. Every `ReturnStmt` reached by `body.visit` must produce a type
   * resolvable with the return type; a bare `return` counts as `void`. On a
   * mismatch, an error is reported at the `return`, plus a note at the
   * declared return type (or at the function when none is written). Then
   * checking aborts.
   *
   * @returns an `FnStmt` type carrying the statement and its `Fn` signature.
   */
  checkFnStmt(stmt: ast.NodeWithKind<"FnStmt">): Ty {
    const k = stmt.kind;

    const params = k.params
      .map((param) => this.cx.param(param.as("Param")));
    const retTy = k.retTy ? this.cx.ty(k.retTy) : Ty.Void;

    // NOTE(review): `cx` stores this function's type only after this method
    // returns. A `return` that refers to the function itself (recursion)
    // therefore appears to re-enter `checkFnStmt` without bound. Confirm
    // whether recursive functions are expected to work.
    k.body.visit({
      visit: (node) => {
        if (node.is("ReturnStmt")) {
          const ty = node.kind.expr
            ? this.cx.expr(node.kind.expr)
            : Ty.Void;
          if (!ty.resolvableWith(retTy)) {
            this.cx.error(
              node.loc,
              `type '${ty.pretty()}' not compatible with return type '${retTy.pretty()}'`,
            );
            this.cx.info(
              k.retTy?.loc ?? stmt.loc,
              // Use `pretty()` so the note renders the type the same way
              // as the error above.
              `return type '${retTy.pretty()}' defined here`,
            );
            this.cx.fail();
          }
        }
      },
    });

    const ty = Ty.create("Fn", { params, retTy });
    return Ty.create("FnStmt", { stmt, ty });
  }
}
|
|
|
|
/**
 * Determines the type bound by a `Param` node, based on the symbol the
 * resolver attached to it.
 */
class ParamChecker {
  constructor(
    private cx: CheckerCx,
  ) {}

  /**
   * - `Let` binding: the type of the initializer expression. If an
   *   annotation is present, the initializer must be compatible with it;
   *   the initializer's type is still what gets returned.
   * - Function parameter: the annotated type. A missing annotation is an
   *   error and aborts.
   *
   * @throws Error for any other symbol kind.
   */
  checkParam(node: ast.NodeWithKind<"Param">): Ty {
    const sym = this.cx.sym(node);

    switch (sym.tag) {
      case "Let": {
        const initializer = sym.stmt.kind.expr;
        const initTy = this.cx.expr(initializer);
        if (node.kind.ty) {
          this.cx.assertCompatible(
            initTy,
            this.cx.ty(node.kind.ty),
            initializer.loc,
          );
        }
        return initTy;
      }
      case "FnParam": {
        const annotation = node.kind.ty;
        if (!annotation) {
          this.cx.error(node.loc, `parameter must have a type`);
          this.cx.fail();
        }
        return this.cx.ty(annotation);
      }
      default:
        throw new Error(`'${sym.tag}' not handled`);
    }
  }
}
|
|
|
|
/**
 * Types expressions in place position.
 *
 * The only difference from `ExprChecker` is for dereferences. A `Deref` of
 * a pointer yields the pointee type directly. This skips the sized-ness
 * check that `checkUnaryExpr` applies, and the inner operand is itself
 * typed as a place.
 */
class PlaceChecker {
  constructor(
    private cx: CheckerCx,
  ) {}

  checkPlace(node: ast.Node): Ty {
    if (!node.is("UnaryExpr") || node.kind.op !== "Deref") {
      return this.cx.expr(node);
    }

    const pointerTy = this.checkPlace(node.kind.expr);
    if (pointerTy.is("Ptr") || pointerTy.is("PtrMut")) {
      return pointerTy.kind.ty;
    }

    // Not a pointer: let the expression checker produce the diagnostic.
    return this.cx.expr(node);
  }
}
|
|
|
|
/**
 * Computes the types of expression nodes.
 *
 * Sub-expressions are typed through `cx`, so their results are memoized.
 * User-facing type errors are reported through `cx.error`/`cx.info` and
 * then abort via `cx.fail()`. Unsupported node kinds throw a plain `Error`.
 */
class ExprChecker {
  constructor(
    private cx: CheckerCx,
  ) {}

  /**
   * Types an arbitrary expression by dispatching on its kind tag.
   *
   * @throws Error when the expression kind is not supported.
   */
  checkExpr(node: ast.Node): Ty {
    const tag = node.kind.tag;
    switch (tag) {
      case "IdentExpr":
        return this.checkIdentExpr(node.as(tag));
      case "IntExpr":
        return this.checkIntExpr(node.as(tag));
      case "StrExpr":
        // String literals are pointers to a `u8` slice.
        return Ty.create("Ptr", {
          ty: Ty.create("Slice", { ty: Ty.U8 }),
        });
      case "ArrayExpr":
        return this.checkArrayExpr(node.as(tag));
      case "IndexExpr":
        return this.checkIndexExpr(node.as(tag));
      case "CallExpr":
        return this.checkCallExpr(node.as(tag));
      case "UnaryExpr":
        return this.checkUnaryExpr(node.as(tag));
      case "BinaryExpr":
        return this.checkBinaryExpr(node.as(tag));
      case "RangeExpr":
        return this.checkRangeExpr(node.as(tag));
      default:
        throw new Error(`'${node.kind.tag}' not handled`);
    }
  }

  /**
   * Types an identifier from the symbol it resolves to.
   * Builtins may only appear as the callee of a call.
   */
  private checkIdentExpr(node: ast.NodeWithKind<"IdentExpr">): Ty {
    const sym = this.cx.sym(node);
    if (sym.tag === "Fn") {
      return this.cx.fnStmt(sym.stmt);
    }
    if (sym.tag === "Bool") {
      return Ty.Bool;
    }
    if (sym.tag === "Builtin") {
      this.cx.error(node.loc, `invalid use of builtin '${sym.id}'`);
      this.cx.fail();
    }
    // Parameters and `let` bindings point at their `Param` node. Type it
    // with the param checker: `checkExpr` has no case for `Param` nodes, so
    // `cx.expr` only succeeded when the param's type was already cached.
    if (sym.tag === "FnParam") {
      return this.cx.param(sym.param.as("Param"));
    }
    if (sym.tag === "Let") {
      return this.cx.param(sym.param.as("Param"));
    }
    throw new Error(`'${sym.tag}' not handled`);
  }

  /**
   * Types an integer literal from its `intTy`. Signed kinds map to the
   * signed types, matching `TyChecker`. The `"any"` kind (presumably an
   * unsuffixed literal) defaults to `i32`.
   */
  private checkIntExpr(node: ast.NodeWithKind<"IntExpr">): Ty {
    switch (node.kind.intTy) {
      case "u8":
        return Ty.U8;
      case "u16":
        return Ty.U16;
      case "u32":
        return Ty.U32;
      case "u64":
        return Ty.U64;
      case "usize":
        return Ty.USize;
      case "i8":
        return Ty.I8;
      case "i16":
        return Ty.I16;
      case "i32":
        return Ty.I32;
      case "i64":
        return Ty.I64;
      case "isize":
        return Ty.ISize;
      case "any":
        return Ty.I32;
      default:
        throw new Error(`intType '${node.kind.intTy}' not handled`);
    }
  }

  /**
   * Types an array literal as an `Array` of its first element's type.
   * Every later element must be compatible with that type. Empty literals
   * are rejected because no element type can be inferred.
   */
  private checkArrayExpr(node: ast.NodeWithKind<"ArrayExpr">): Ty {
    let ty: Ty | null = null;
    for (const value of node.kind.values) {
      const valueTy = this.cx.expr(value);
      if (ty) {
        this.cx.assertCompatible(ty, valueTy, value.loc);
      } else {
        ty = valueTy;
      }
    }
    if (!ty) {
      this.cx.error(node.loc, `could not infer type of empty array`);
      this.cx.fail();
    }
    const length = node.kind.values.length;
    return Ty.create("Array", { ty, length });
  }

  /**
   * Types `value[arg]` on arrays and slices. An `i32`-compatible index
   * yields the element type; a range index yields a slice of the element
   * type.
   */
  private checkIndexExpr(node: ast.NodeWithKind<"IndexExpr">): Ty {
    // Typed as a place: `PlaceChecker` looks through `Deref` without the
    // sized-ness check that `checkUnaryExpr` applies.
    const exprTy = this.cx.place(node.kind.value);
    const argTy = this.cx.expr(node.kind.arg);
    if (exprTy.is("Array") || exprTy.is("Slice")) {
      if (argTy.resolvableWith(Ty.I32)) {
        return exprTy.kind.ty;
      }
      if (argTy.resolvableWith(Ty.create("Range", {}))) {
        return Ty.create("Slice", { ty: exprTy.kind.ty });
      }
    }
    this.cx.error(
      node.loc,
      `cannot use index operator on '${exprTy.pretty()}' with '${argTy.pretty()}'`,
    );
    this.cx.fail();
  }

  /**
   * Types a call.
   *
   * Builtins are checked separately. Otherwise the callee must be an `Fn`
   * or `FnStmt` type, the argument count must match the parameter count,
   * and each argument must be resolvable with its parameter type.
   *
   * @returns the callee's return type.
   */
  private checkCallExpr(node: ast.NodeWithKind<"CallExpr">): Ty {
    if (node.kind.value.is("IdentExpr")) {
      const sym = this.cx.sym(node.kind.value);
      if (sym.tag === "Builtin") {
        return this.checkCallExprBuiltin(node, sym);
      }
    }

    const calleeTy = this.cx.expr(node.kind.value);

    // A declared function's type wraps its `Fn` signature.
    const callableTy = calleeTy.is("Fn")
      ? calleeTy
      : calleeTy.is("FnStmt")
      ? calleeTy.kind.ty as Ty & { kind: { tag: "Fn" } }
      : null;

    if (!callableTy) {
      this.cx.error(
        node.loc,
        `type '${calleeTy.pretty()}' not callable`,
      );
      this.cx.fail();
    }

    const args = node.kind.args
      .map((arg) => this.cx.expr(arg));
    const params = callableTy.kind.params;
    if (args.length !== params.length) {
      this.reportArgsIncorrectAmount(
        node,
        args.length,
        params.length,
        calleeTy,
      );
    }
    for (const i of args.keys()) {
      if (!args[i].resolvableWith(params[i])) {
        this.reportArgTypeNotCompatible(
          node,
          args,
          params,
          calleeTy,
          i,
        );
      }
    }
    return callableTy.kind.retTy;
  }

  /**
   * Types a call to a builtin.
   *
   * - `len(x)` takes exactly one argument: an array, or a pointer to an
   *   array or slice. It returns `i32`.
   * - `print(...)` types every argument and returns `void`.
   *
   * @throws Error if called with a non-identifier callee or a non-builtin
   *   symbol, or for an unknown builtin.
   */
  private checkCallExprBuiltin(
    node: ast.NodeWithKind<"CallExpr">,
    sym: Sym,
  ): Ty {
    if (!node.kind.value.is("IdentExpr")) {
      throw new Error();
    }
    if (sym.tag !== "Builtin") {
      throw new Error();
    }

    if (sym.id === "len") {
      if (node.kind.args.length !== 1) {
        // `len` takes exactly one argument.
        this.reportArgsIncorrectAmount(
          node,
          node.kind.args.length,
          1,
          null,
        );
      }
      const argTy = this.cx.expr(node.kind.args[0]);
      const hasLength = argTy.is("Array") ||
        (argTy.is("Ptr") &&
          (argTy.kind.ty.is("Array") || argTy.kind.ty.is("Slice")));
      if (!hasLength) {
        this.reportArgTypeNotCompatible(
          node,
          [argTy],
          [Ty.Error],
          null,
          0,
        );
      }
      return Ty.I32;
    }

    if (sym.id === "print") {
      void node.kind.args
        .map((arg) => this.cx.expr(arg));
      return Ty.Void;
    }

    throw new Error(`builtin '${sym.id}' not handled`);
  }

  /**
   * Types a unary operation:
   * - `Neg` on an `i32`-compatible operand yields `i32`.
   * - `Not` on a `bool` yields `bool`.
   * - `Ref`/`RefMut` wrap the operand type in `Ptr`/`PtrMut`.
   * - `Deref` of a pointer to a sized type yields the pointee type.
   * Anything else is reported as an error.
   */
  private checkUnaryExpr(node: ast.NodeWithKind<"UnaryExpr">): Ty {
    const exprTy = this.cx.expr(node.kind.expr);
    if (node.kind.op === "Neg" && exprTy.resolvableWith(Ty.I32)) {
      return Ty.I32;
    }
    if (node.kind.op === "Not" && exprTy.resolvableWith(Ty.Bool)) {
      return Ty.Bool;
    }
    if (node.kind.op === "Ref") {
      return Ty.create("Ptr", { ty: exprTy });
    }
    if (node.kind.op === "RefMut") {
      return Ty.create("PtrMut", { ty: exprTy });
    }
    if (node.kind.op === "Deref") {
      if (exprTy.is("Ptr") || exprTy.is("PtrMut")) {
        if (!exprTy.kind.ty.isSized()) {
          this.cx.error(
            node.loc,
            `cannot dereference unsized type '${exprTy.kind.ty.pretty()}' in an expression`,
          );
          this.cx.fail();
        }
        return exprTy.kind.ty;
      }
    }
    this.cx.error(
      node.loc,
      `operator '${node.kind.tok}' cannot be applied to type '${exprTy.pretty()}'`,
    );
    this.cx.fail();
  }

  /**
   * Types a binary operation using the first rule in `binaryOpTests` that
   * accepts the operator and operand types.
   */
  private checkBinaryExpr(node: ast.NodeWithKind<"BinaryExpr">): Ty {
    const left = this.cx.expr(node.kind.left);
    const right = this.cx.expr(node.kind.right);
    const result = binaryOpTests
      .map((test) => test(node.kind.op, left, right))
      .filter((result) => result)
      .at(0);
    if (!result) {
      this.cx.error(
        node.loc,
        `operator '${node.kind.tok}' cannot be applied to types '${left.pretty()}' and '${right.pretty()}'`,
      );
      this.cx.fail();
    }
    return result;
  }

  /**
   * Types a range expression. Each present bound must be
   * `i32`-compatible; omitted bounds are skipped.
   */
  private checkRangeExpr(node: ast.NodeWithKind<"RangeExpr">): Ty {
    for (const operandExpr of [node.kind.begin, node.kind.end]) {
      if (!operandExpr) {
        continue; // bound omitted
      }
      const operandTy = this.cx.expr(operandExpr);
      if (!operandTy.resolvableWith(Ty.I32)) {
        this.cx.error(
          operandExpr.loc,
          `range operand must be '${Ty.I32.pretty()}', not '${operandTy.pretty()}'`,
        );
        this.cx.fail();
      }
    }
    return Ty.create("Range", {});
  }

  /**
   * Reports an argument-count mismatch and aborts. When the callee is a
   * declared function, a note points at its declaration.
   */
  private reportArgsIncorrectAmount(
    node: ast.NodeWithKind<"CallExpr">,
    argsLength: number,
    paramsLength: number,
    calleeTy: Ty | null,
  ): never {
    this.cx.error(
      node.loc,
      `incorrect amount of arguments. got ${argsLength} expected ${paramsLength}`,
    );
    if (calleeTy?.is("FnStmt")) {
      this.cx.info(
        calleeTy.kind.stmt.loc,
        "function defined here",
      );
    }
    this.cx.fail();
  }

  /**
   * Reports that argument `i` does not match parameter `i`, then aborts.
   * When the callee is a declared function, a note points at that
   * parameter.
   */
  private reportArgTypeNotCompatible(
    node: ast.NodeWithKind<"CallExpr">,
    args: Ty[],
    params: Ty[],
    calleeTy: Ty | null,
    i: number,
  ): never {
    this.cx.error(
      node.kind.args[i].loc,
      `type '${args[i].pretty()}' not compatible with type '${
        params[i].pretty()
      }', for argument ${i}`,
    );
    if (calleeTy?.is("FnStmt")) {
      this.cx.info(
        calleeTy.kind.stmt.kind.params[i].loc,
        `parameter '${
          calleeTy.kind.stmt.kind.params[i]
            .as("Param").kind.ident
        }' defined here`,
      );
    }
    this.cx.fail();
  }
}
|
|
|
|
/**
 * Converts type-annotation nodes into semantic `Ty` values.
 */
class TyChecker {
  constructor(
    private cx: CheckerCx,
  ) {}

  /**
   * Resolves a type annotation node.
   *
   * - A named type must resolve to a `BuiltinTy` symbol; any other symbol is
   *   reported as "not a type" and aborts.
   * - Pointer, mutable-pointer and slice types wrap their element type.
   * - An array type's length must be an integer literal whose type is
   *   compatible with `i32`.
   *
   * @throws Error for unsupported node kinds or unknown builtin type names.
   */
  checkTy(node: ast.Node): Ty {
    if (node.is("IdentTy")) {
      const sym = this.cx.sym(node);
      if (sym.tag === "BuiltinTy") {
        switch (sym.ident) {
          case "void":
            return Ty.Void;
          case "bool":
            return Ty.Bool;
          case "i8":
            return Ty.I8;
          case "i16":
            return Ty.I16;
          case "i32":
            return Ty.I32;
          case "i64":
            return Ty.I64;
          case "isize":
            return Ty.ISize;
          case "u8":
            return Ty.U8;
          case "u16":
            return Ty.U16;
          case "u32":
            return Ty.U32;
          case "u64":
            return Ty.U64;
          case "usize":
            return Ty.USize;
          default:
            throw new Error(
              `unknown type '${node.kind.ident}'`,
            );
        }
      }
      this.cx.error(node.loc, `symbol is not a type`);
      this.cx.fail();
    }

    if (node.is("PtrTy")) {
      const ty = this.cx.ty(node.kind.ty);
      return Ty.create("Ptr", { ty });
    }

    if (node.is("PtrMutTy")) {
      const ty = this.cx.ty(node.kind.ty);
      return Ty.create("PtrMut", { ty });
    }

    if (node.is("ArrayTy")) {
      const ty = this.cx.ty(node.kind.ty);
      const lengthTy = this.cx.expr(node.kind.length);
      if (!lengthTy.resolvableWith(Ty.I32)) {
        this.cx.error(
          node.kind.length.loc,
          `for array length, expected 'int', got '${lengthTy.pretty()}'`,
        );
        this.cx.fail();
      }
      // The length must be known statically, so only a literal is accepted.
      if (!node.kind.length.is("IntExpr")) {
        this.cx.error(
          node.kind.length.loc,
          `array length must be an 'int' expression`,
        );
        this.cx.fail();
      }
      const length = node.kind.length.kind.value;
      return Ty.create("Array", { ty, length });
    }

    if (node.is("SliceTy")) {
      const ty = this.cx.ty(node.kind.ty);
      return Ty.create("Slice", { ty });
    }

    throw new Error(`'${node.kind.tag}' not handled`);
  }
}
|
|
|
|
/**
 * A typing rule for binary operators. Returns the result type when the rule
 * accepts `op` with the given operand types, otherwise `null`.
 */
type BinaryOpTest = (op: ast.BinaryOp, left: Ty, right: Ty) => Ty | null;

/** Operators covered by the integer-arithmetic rule. */
const arithmeticOps: ast.BinaryOp[] = ["Add", "Sub", "Mul", "Div", "Rem"];

/** Operators covered by the integer-comparison rule. */
const comparisonOps: ast.BinaryOp[] = ["Eq", "Ne", "Lt", "Gt", "Lte", "Gte"];

/** True when `left` is an `Int` type that is resolvable with `right`. */
const intOperands = (left: Ty, right: Ty): boolean =>
  left.is("Int") && left.resolvableWith(right);

/** Rules tried in order by `ExprChecker.checkBinaryExpr`; the first match wins. */
const binaryOpTests: BinaryOpTest[] = [
  // Integer arithmetic yields the left operand's type.
  (op, left, right) =>
    arithmeticOps.includes(op) && intOperands(left, right) ? left : null,
  // Integer comparison yields `bool`.
  (op, left, right) =>
    comparisonOps.includes(op) && intOperands(left, right) ? Ty.Bool : null,
];
|