diff --git a/src/front/check.ts b/src/front/check.ts index e3a743d..b677a4c 100644 --- a/src/front/check.ts +++ b/src/front/check.ts @@ -61,6 +61,9 @@ export class Checker { if (sym.tag === "Fn") { return this.check(sym.stmt); } + if (sym.tag === "Bool") { + return Ty.Bool; + } if (sym.tag === "Builtin") { return builtins.find((s) => s.id === sym.id)!.ty; } @@ -106,6 +109,8 @@ export class Checker { return Ty.Void; case "int": return Ty.Int; + case "bool": + return Ty.Bool; default: this.error(node.line, `unknown type '${node.kind.ident}'`); } @@ -148,9 +153,9 @@ export class Checker { private checkCall(node: ast.NodeWithKind<"CallExpr">): Ty { const calleeTy = this.check(node.kind.expr); - const callableTy = calleeTy.isKind("Fn") + const callableTy = calleeTy.is("Fn") ? calleeTy - : calleeTy.isKind("FnStmt") + : calleeTy.is("FnStmt") ? calleeTy.kind.ty as Ty & { kind: { tag: "Fn" } } : null; @@ -170,7 +175,7 @@ export class Checker { node.line, `incorrect amount of arguments. got ${args.length} expected ${params.length}`, ); - if (calleeTy.isKind("FnStmt")) { + if (calleeTy.is("FnStmt")) { this.info( calleeTy.kind.stmt.line, "function defined here", @@ -186,7 +191,7 @@ export class Checker { params[i] }', for argument ${i}`, ); - if (calleeTy.isKind("FnStmt")) { + if (calleeTy.is("FnStmt")) { this.info( calleeTy.kind.stmt.kind.params[i].line, `parameter '${ @@ -246,4 +251,14 @@ type BinaryOpPattern = { const binaryOpPatterns: BinaryOpPattern[] = [ { op: "Add", left: Ty.Int, right: Ty.Int, result: Ty.Int }, { op: "Subtract", left: Ty.Int, right: Ty.Int, result: Ty.Int }, + { op: "Multiply", left: Ty.Int, right: Ty.Int, result: Ty.Int }, + { op: "Divide", left: Ty.Int, right: Ty.Int, result: Ty.Int }, + { op: "Remainder", left: Ty.Int, right: Ty.Int, result: Ty.Int }, + + { op: "Eq", left: Ty.Int, right: Ty.Int, result: Ty.Bool }, + { op: "Ne", left: Ty.Int, right: Ty.Int, result: Ty.Bool }, + { op: "Lt", left: Ty.Int, right: Ty.Int, result: Ty.Bool }, + { op: "Gt", left: Ty.Int, right: Ty.Int, result: Ty.Bool }, + { op: "Lte", left: Ty.Int, right: Ty.Int, result: Ty.Bool }, + { op: "Gte", left: Ty.Int, right: Ty.Int, result: Ty.Bool }, ]; diff --git a/src/front/resolve.ts b/src/front/resolve.ts index b81ea80..6ec65e4 100644 --- a/src/front/resolve.ts +++ b/src/front/resolve.ts @@ -1,6 +1,5 @@ import * as ast from "../ast.ts"; import { printDiagnostics } from "../diagnostics.ts"; -import { Ty } from "../ty.ts"; import { builtins } from "./builtins.ts"; export class ResolveMap { @@ -18,6 +17,7 @@ export class ResolveMap { export type Sym = | { tag: "Error" } + | { tag: "Bool"; value: boolean } | { tag: "Builtin"; id: string } | { tag: "Fn"; stmt: ast.NodeWithKind<"FnStmt"> } | { @@ -79,7 +79,7 @@ export function resolve( } if (k.tag === "IdentExpr") { - const sym = syms.resolve(k.ident); + const sym = syms.resolveExpr(k.ident); if (sym === null) { printDiagnostics( filename, @@ -126,12 +126,16 @@ class ResolverSyms { this.syms.set(ident, sym); } - resolve(ident: string): Sym | null { + resolveExpr(ident: string): Sym | null { + if (ident === "false" || ident === "true") { + return { tag: "Bool", value: ident === "true" }; + } + if (this.syms.has(ident)) { return this.syms.get(ident)!; } if (this.parent) { - return this.parent.resolve(ident); + return this.parent.resolveExpr(ident); } return null; } diff --git a/src/middle.ts b/src/middle.ts index c15ebd3..7c0a972 100644 --- a/src/middle.ts +++ b/src/middle.ts @@ -112,6 +112,9 @@ class FnLowerer { } return this.pushInst(local.ty, "LocalLoad", { source: local }); } + if (sym.tag === "Bool") { + return this.pushInst(Ty.Bool, "Bool", { value: sym.value }); + } throw new Error(`'${sym.tag}' not handled`); } if (expr.is("IntExpr")) { @@ -140,17 +143,24 @@ class FnLowerer { return this.pushInst(ty, "Call", { callee, args }); } if (expr.is("BinaryExpr")) { - const ty = this.checker.check(expr); + const resultTy = this.checker.check(expr); + const leftTy = this.checker.check(expr.kind.left); + const rightTy = this.checker.check(expr.kind.right); const binaryOp = binaryOpPatterns .find((pat) => - expr.kind.op === pat.op && ty.compatibleWith(pat.ty) + expr.kind.op === pat.op && + resultTy.compatibleWith(pat.result) && + leftTy.compatibleWith(pat.left ?? pat.result) && + rightTy.compatibleWith(pat.right ?? pat.left ?? pat.result) ); if (!binaryOp) { - throw new Error(); + throw new Error( + `'${expr.kind.op}' with '${resultTy.pretty()}' not handled`, + ); } const left = this.lowerExpr(expr.kind.left); const right = this.lowerExpr(expr.kind.right); - return this.pushInst(ty, binaryOp.tag, { left, right }); + return this.pushInst(resultTy, binaryOp.tag, { left, right }); } throw new Error(`'${expr.kind.tag}' not handled`); } @@ -174,13 +184,25 @@ class FnLowerer { type BinaryOpPattern = { op: ast.BinaryOp; - ty: Ty; tag: BinaryOp; + result: Ty; + left?: Ty; + right?: Ty; }; const binaryOpPatterns: BinaryOpPattern[] = [ - { op: "Add", ty: Ty.Int, tag: "Add" }, - { op: "Subtract", ty: Ty.Int, tag: "Sub" }, + { op: "Add", tag: "Add", result: Ty.Int, left: Ty.Int }, + { op: "Subtract", tag: "Sub", result: Ty.Int, left: Ty.Int }, + { op: "Multiply", tag: "Mul", result: Ty.Int, left: Ty.Int }, + { op: "Divide", tag: "Div", result: Ty.Int, left: Ty.Int }, + { op: "Remainder", tag: "Rem", result: Ty.Int }, + + { op: "Eq", tag: "Eq", result: Ty.Bool, left: Ty.Int }, + { op: "Ne", tag: "Ne", result: Ty.Bool, left: Ty.Int }, + { op: "Lt", tag: "Lt", result: Ty.Bool, left: Ty.Int }, + { op: "Gt", tag: "Gt", result: Ty.Bool, left: Ty.Int }, + { op: "Lte", tag: "Lte", result: Ty.Bool, left: Ty.Int }, + { op: "Gte", tag: "Gte", result: Ty.Bool, left: Ty.Int }, ]; export class Fn { @@ -191,7 +213,7 @@ export class Fn { ) {} pretty(): string { - const fnTy = this.ty.isKind("FnStmt") && this.ty.kind.ty.isKind("Fn") + const fnTy = this.ty.is("FnStmt") && this.ty.kind.ty.is("Fn") ? this.ty.kind.ty : null; if (!fnTy) { @@ -267,6 +289,7 @@ export class Inst { case "Void": return ""; case "Int": + case "Bool": return ` ${k.value}`; case "Fn": return ` ${k.fn.stmt.kind.ident}`; @@ -312,6 +335,7 @@ export type InstKind = | { tag: "Error" } | { tag: "Void" } | { tag: "Int"; value: number } + | { tag: "Bool"; value: boolean } | { tag: "Fn"; fn: Fn } | { tag: "Param"; idx: number } | { tag: "Call"; callee: Inst; args: Inst[] } diff --git a/src/mir_interpreter.ts b/src/mir_interpreter.ts index 792ba8a..95c8471 100644 --- a/src/mir_interpreter.ts +++ b/src/mir_interpreter.ts @@ -20,6 +20,7 @@ export class MirInterpreter { throw new Error(); case "Void": case "Int": + case "Bool": case "Fn": regs.set(inst, new Val(k)); continue; @@ -79,14 +80,28 @@ export class MirInterpreter { const rk = right.kind; if (lk.tag === "Int" && rk.tag === "Int") { + const left = lk.value; + const right = lk.value; + const value = (() => { + const Int = (value: number) => + new Val({ tag: "Int", value }); + const Bool = (value: boolean) => + new Val({ tag: "Bool", value }); + switch (k.tag) { case "Eq": + return Bool(left === right); case "Ne": + return Bool(left !== right); case "Lt": + return Bool(left < right); case "Gt": + return Bool(left > right); case "Lte": + return Bool(left <= right); case "Gte": + return Bool(left >= right); case "BitOr": case "BitXor": case "BitAnd": @@ -94,17 +109,19 @@ export class MirInterpreter { case "Shr": break; case "Add": - return lk.value + rk.value; + return Int(left + right); case "Sub": - return lk.value - rk.value; + return Int(left - right); case "Mul": + return Int(left * right); case "Div": + return Int(Math.floor(left / right)); case "Rem": - break; + return Int(left % right); } throw new Error(`'${k.tag}' not handled`); })(); - regs.set(inst, new Val({ tag: "Int", value })); + regs.set(inst, value); continue; } throw new Error(`'${k.tag}' not handled`); @@ -136,6 +153,7 @@ class Val { case "Void": return "void"; case "Int": + case "Bool": return `${k.value}`; case "Fn": return `<${k.fn.ty.pretty()}>`; @@ -147,4 +165,5 @@ class Val { type ValKind = | { tag: "Void" } | { tag: "Int"; value: number } + | { tag: "Bool"; value: boolean } | { tag: "Fn"; fn: mir.Fn }; diff --git a/src/ty.ts b/src/ty.ts index 05cd2f1..82a99d3 100644 --- a/src/ty.ts +++ b/src/ty.ts @@ -24,6 +24,7 @@ export class Ty { static Error = Ty.create("Error", {}); static Void = Ty.create("Void", {}); static Int = Ty.create("Int", {}); + static Bool = Ty.create("Bool", {}); private internHash(): string { return JSON.stringify(this.kind); @@ -34,24 +35,27 @@ export class Ty { public kind: TyKind, ) {} - isKind< + is< Tag extends TyKind["tag"], >(tag: Tag): this is Ty & { kind: { tag: Tag } } { return this.kind.tag === tag; } compatibleWith(other: Ty): boolean { - if (this.isKind("Error")) { + if (this.is("Error")) { return false; } - if (this.isKind("Void")) { - return other.isKind("Void"); + if (this.is("Void")) { + return other.is("Void"); } - if (this.isKind("Int")) { - return other.isKind("Int"); + if (this.is("Int")) { + return other.is("Int"); } - if (this.isKind("Fn")) { - if (!other.isKind("Fn")) { + if (this.is("Bool")) { + return other.is("Bool"); + } + if (this.is("Fn")) { + if (!other.is("Fn")) { return false; } for (const i of this.kind.params.keys()) { @@ -64,8 +68,8 @@ export class Ty { } return true; } - if (this.isKind("FnStmt")) { - if (!other.isKind("FnStmt")) { + if (this.is("FnStmt")) { + if (!other.is("FnStmt")) { return false; } if (!this.kind.ty.compatibleWith(other.kind.ty)) { @@ -81,22 +85,25 @@ export class Ty { } pretty(): string { - if (this.isKind("Error")) { + if (this.is("Error")) { return ""; } - if (this.isKind("Void")) { + if (this.is("Void")) { return "void"; } - if (this.isKind("Int")) { + if (this.is("Int")) { return "int"; } - if (this.isKind("Fn")) { + if (this.is("Bool")) { + return "bool"; + } + if (this.is("Fn")) { return `fn (${ this.kind.params.map((param) => param.pretty()).join(", ") }) -> ${this.kind.retTy.pretty()}`; } - if (this.isKind("FnStmt")) { - if (!this.kind.ty.isKind("Fn")) throw new Error(); + if (this.is("FnStmt")) { + if (!this.kind.ty.is("Fn")) throw new Error(); return `fn ${this.kind.stmt.kind.ident}(${ this.kind.ty.kind.params.map((param) => param.pretty()).join( ", ", @@ -111,5 +118,6 @@ export type TyKind = | { tag: "Error" } | { tag: "Void" } | { tag: "Int" } + | { tag: "Bool" } | { tag: "Fn"; params: Ty[]; retTy: Ty } | { tag: "FnStmt"; ty: Ty; stmt: ast.NodeWithKind<"FnStmt"> }; diff --git a/tests/bool.ethlang b/tests/bool.ethlang new file mode 100644 index 0000000..a8f2180 --- /dev/null +++ b/tests/bool.ethlang @@ -0,0 +1,8 @@ + +fn main() +{ + let my_bool = false; + let my_other_bool: bool = true; + + let cond: bool = 1 == 2; +} diff --git a/tests/operators.ethlang b/tests/operators.ethlang index ed623bd..e35964d 100644 --- a/tests/operators.ethlang +++ b/tests/operators.ethlang @@ -1,5 +1,8 @@ // expect: 8 // expect: 2 +// expect: 15 +// expect: 7 +// expect: 2 fn main() { @@ -8,5 +11,8 @@ fn main() print_int(a + b); print_int(a - b); + print_int(a * b); + print_int(a * b / 2); + print_int(a % b); } diff --git a/tests/test.sh b/tests/test.sh index 24a0aa3..da67a72 100755 --- a/tests/test.sh +++ b/tests/test.sh @@ -10,11 +10,12 @@ TEST_SRC=$(fd '\.ethlang' $TEST_DIR) count_total=0 count_succeeded=0 -for test_file in $TEST_SRC -do - echo "- $(basename $test_file)" +run_test_file() { + local file=$1 + + echo "- $(basename $file)" set +e - output=$(deno run -A $SRC_DIR/main.ts $test_file --test) + output=$(deno run -A $SRC_DIR/main.ts $file --test) status=$? set -e @@ -26,9 +27,9 @@ do if [[ status -eq 0 ]] then - if grep -q '// expect:' $test_file + if grep -q '// expect:' $file then - expected=$(grep '// expect:' $test_file | sed -E 's/\/\/ expect: (.*?)/\1/g') + expected=$(grep '// expect:' $file | sed -E 's/\/\/ expect: (.*?)/\1/g') if [[ $output != $expected ]] then echo "-- failed: incorrect output --" @@ -45,17 +46,25 @@ do if [[ status -eq 0 ]] then count_succeeded=$(($count_succeeded + 1)) - else - echo "failed" fi +} -done +if [[ $1 == "" ]] +then + for file in $TEST_SRC + do + run_test_file $file + done +else + run_test_file $1 +fi if [[ $count_succeeded -eq $count_total ]] then - echo "=== all tests passed ($count_succeeded/$count_total passed) ===" + echo "== all tests passed ($count_succeeded/$count_total passed) ==" else - echo "=== tests failed ($count_succeeded/$count_total passed) ===" + echo "== tests failed ($count_succeeded/$count_total passed) ==" + exit 1 fi