From cf36151115e40575afafb55ac43a34090d2bee71 Mon Sep 17 00:00:00 2001 From: sfja Date: Wed, 11 Mar 2026 23:33:09 +0100 Subject: [PATCH] add pointers --- src/ast.ts | 11 +++++- src/front/check.ts | 28 ++++++++++++-- src/front/parse.ts | 13 ++++++- src/middle.ts | 87 ++++++++++++++++++++++++++++++++++-------- src/mir_interpreter.ts | 42 +++++++++++--------- src/ty.ts | 8 ++++ tests/pointer.ethlang | 23 +++++++++++ tests/test.sh | 2 +- 8 files changed, 173 insertions(+), 41 deletions(-) create mode 100644 tests/pointer.ethlang diff --git a/src/ast.ts b/src/ast.ts index f344592..57a84be 100644 --- a/src/ast.ts +++ b/src/ast.ts @@ -88,6 +88,9 @@ export class Node { return visit(k.left, k.right); case "IdentTy": return visit(); + case "PtrTy": + case "PtrMutTy": + return visit(k.ty); } k satisfies never; } @@ -115,11 +118,15 @@ export type NodeKind = | { tag: "CallExpr"; expr: Node; args: Node[] } | { tag: "UnaryExpr"; op: UnaryOp; expr: Node; tok: string } | { tag: "BinaryExpr"; op: BinaryOp; left: Node; right: Node; tok: string } - | { tag: "IdentTy"; ident: string }; + | { tag: "IdentTy"; ident: string } + | { tag: "PtrTy" | "PtrMutTy"; ty: Node }; export type UnaryOp = | "Not" - | "Negate"; + | "Negate" + | "Ref" + | "RefMut" + | "Deref"; export type BinaryOp = | "Or" diff --git a/src/front/check.ts b/src/front/check.ts index 9963fc0..9323b70 100644 --- a/src/front/check.ts +++ b/src/front/check.ts @@ -85,16 +85,27 @@ export class Checker { } if (node.is("UnaryExpr")) { - const expr = this.check(node.kind.expr); - if (node.kind.op === "Negate" && expr.compatibleWith(Ty.Int)) { + const exprTy = this.check(node.kind.expr); + if (node.kind.op === "Negate" && exprTy.compatibleWith(Ty.Int)) { return Ty.Int; } - if (node.kind.op === "Not" && expr.compatibleWith(Ty.Bool)) { + if (node.kind.op === "Not" && exprTy.compatibleWith(Ty.Bool)) { return Ty.Bool; } + if (node.kind.op === "Ref") { + return Ty.create("Ptr", { ty: exprTy }); + } + if (node.kind.op === "RefMut") { + return Ty.create("PtrMut", { ty: exprTy }); + } + if (node.kind.op === "Deref") { + if (exprTy.is("Ptr") || exprTy.is("PtrMut")) { + return exprTy.kind.ty; + } + } this.error( node.line, - `operator '${node.kind.tok}' cannot be applied to type '${expr.pretty()}'`, + `operator '${node.kind.tok}' cannot be applied to type '${exprTy.pretty()}'`, ); this.fail(); } @@ -131,6 +142,15 @@ export class Checker { } } + if (node.is("PtrTy")) { + const ty = this.check(node.kind.ty); + return Ty.create("Ptr", { ty }); + } + if (node.is("PtrMutTy")) { + const ty = this.check(node.kind.ty); + return Ty.create("PtrMut", { ty }); + } + throw new Error(`'${k.tag}' not unhandled`); } diff --git a/src/front/parse.ts b/src/front/parse.ts index eff2636..7a83ab1 100644 --- a/src/front/parse.ts +++ b/src/front/parse.ts @@ -185,6 +185,7 @@ export class Parser { const ops: [Tok["type"], ast.UnaryOp][] = [ ["not", "Not"], ["-", "Negate"], + ["*", "Deref"], ]; for (const [tok, op] of ops) { if (this.eat(tok)) { @@ -192,6 +193,12 @@ export class Parser { return ast.Node.create(loc, "UnaryExpr", { op, expr, tok }); } } + if (this.eat("&")) { + const op: ast.UnaryOp = this.eat("mut") ? "RefMut" : "Ref"; + const expr = this.parsePrefix(); + const tok = op === "Ref" ? "&" : "&mut"; + return ast.Node.create(loc, "UnaryExpr", { op, expr, tok }); + } return this.parsePostfix(); } @@ -245,6 +252,10 @@ export class Parser { const ident = this.current.value; this.step(); return ast.Node.create(loc, "IdentTy", { ident }); + } else if (this.eat("*")) { + const mutable = this.eat("mut"); + const ty = this.parseTy(); + return ast.Node.create(loc, mutable ? "PtrMutTy" : "PtrTy", { ty }); } else { this.mustEat(""); throw new Error(); @@ -305,7 +316,7 @@ export class Parser { export type Tok = { type: string; value: string; line: number }; const keywordPattern = - /^(?:fn)|(?:return)|(?:let)|(?:if)|(?:else)|(?:while)|(?:break)|(?:or)|(?:and)|(?:not)$/; + /^(?:fn)|(?:return)|(?:let)|(?:if)|(?:else)|(?:while)|(?:break)|(?:or)|(?:and)|(?:not)|(?:mut)$/; const operatorPattern = /((?:\->)|(?:==)|(?:!=)|(?:<=)|(?:>=)|(?:<<)|(?:>>)|[\n\(\)\{\}\,\.\;\:\!\=\<\>\&\^\|\+\-\*\/\%])/g; diff --git a/src/middle.ts b/src/middle.ts index e9251e1..4aa10e6 100644 --- a/src/middle.ts +++ b/src/middle.ts @@ -21,6 +21,7 @@ export class MiddleLowerer { } class FnLowerer { + private allocs: Inst[] = []; private bbs: BasicBlock[] = [new BasicBlock()]; private localMap = new Map(); @@ -35,6 +36,7 @@ class FnLowerer { const ty = this.checker.check(this.stmt); this.lowerBlock(this.stmt.kind.body.as("Block")); this.pushInst(Ty.Void, "Return", { source: this.makeVoid() }); + this.bbs[0].insts.unshift(...this.allocs); return new Fn(this.stmt, ty, this.bbs); } @@ -48,8 +50,9 @@ class FnLowerer { if (stmt.is("LetStmt")) { const ty = this.checker.check(stmt.kind.param); const expr = this.lowerExpr(stmt.kind.expr); - const local = this.pushInst(ty, "AllocLocal", {}); - this.pushInst(Ty.Void, "LocalStore", { + const local = new Inst(ty, { tag: "Alloca" }); + this.allocs.push(local); + this.pushInst(Ty.Void, "Store", { target: local, source: expr, }); @@ -103,8 +106,8 @@ class FnLowerer { } if (stmt.is("AssignStmt")) { const source = this.lowerExpr(stmt.kind.expr); - const target = this.lowerAssignPlace(stmt.kind.place); - this.pushInst(Ty.Void, "LocalStore", { target, source }); + const target = this.lowerPlace(stmt.kind.place); + this.pushInst(Ty.Void, "Store", { target, source }); return; } if (stmt.is("ExprStmt")) { @@ -114,7 +117,7 @@ class FnLowerer { throw new Error(`'${stmt.kind.tag}' not handled`); } - private lowerAssignPlace(place: ast.Node): Inst { + private lowerPlace(place: ast.Node): Inst { if (place.is("IdentExpr")) { const sym = this.resols.get(place); if (sym.tag === "Let") { @@ -126,6 +129,11 @@ class FnLowerer { } throw new Error(`'${sym.tag}' not handled`); } + + if (place.is("UnaryExpr") && place.kind.op === "Deref") { + return this.lowerExpr(place.kind.expr); + } + throw new Error(`'${place.kind.tag}' not handled`); } @@ -148,7 +156,7 @@ class FnLowerer { if (!local) { throw new Error(); } - return this.pushInst(local.ty, "LocalLoad", { source: local }); + return this.pushInst(local.ty, "Load", { source: local }); } if (sym.tag === "Bool") { return this.pushInst(Ty.Bool, "Bool", { value: sym.value }); @@ -199,6 +207,29 @@ class FnLowerer { const operand = this.lowerExpr(expr.kind.expr); return this.pushInst(Ty.Bool, "Not", { source: operand }); } + if (expr.kind.op === "Ref" || expr.kind.op === "RefMut") { + const place = expr.kind.expr; + if (place.is("IdentExpr")) { + const sym = this.resols.get(place); + if (sym.tag === "Let") { + const local = this.localMap.get(sym.param.id); + if (!local) { + throw new Error(); + } + return local; + } + throw new Error( + `${expr.kind.op} with sym ${sym.tag} not handled`, + ); + } + throw new Error( + `${expr.kind.op} with place ${place.kind.tag} not handled`, + ); + } + if (expr.kind.op === "Deref") { + const source = this.lowerExpr(expr.kind.expr); + return this.pushInst(resultTy, "Load", { source }); + } throw new Error( `'${expr.kind.op}' with '${resultTy.pretty()}' not handled`, ); @@ -266,6 +297,12 @@ const binaryOpPatterns: BinaryOpPattern[] = [ { op: "Gte", tag: "Gte", result: Ty.Bool, left: Ty.Int }, ]; +export interface Visitor { + visitFn?(fn: Fn): void; + visitBasicBlock?(bb: BasicBlock): void; + visitInst?(inst: Inst): void; +} + export class Fn { constructor( public stmt: ast.FnStmt, @@ -273,6 +310,13 @@ export class Fn { public bbs: BasicBlock[], ) {} + visit(v: Visitor) { + v.visitFn?.(this); + for (const bb of this.bbs) { + bb.visit(v); + } + } + pretty(): string { const fnTy = this.ty.is("FnStmt") && this.ty.kind.ty.is("Fn") ? this.ty.kind.ty @@ -320,6 +364,13 @@ class PrettyCx { export class BasicBlock { public insts: Inst[] = []; + visit(v: Visitor) { + v.visitBasicBlock?.(this); + for (const inst of this.insts) { + inst.visit(v); + } + } + pretty(cx: PrettyCx): string { return `bb${cx.bbId(this)}:\n${ this.insts @@ -336,8 +387,12 @@ export class Inst { public kind: InstKind, ) {} + visit(v: Visitor) { + v.visitInst?.(this); + } + pretty(cx: PrettyCx): string { - const r = (v: Inst) => `_${cx.regId(v)}`; + const r = (v: Inst) => `%${cx.regId(v)}`; return `${`${r(this)}:`.padEnd(4, " ")} ${ this.ty.pretty().padEnd(4, " ") @@ -358,18 +413,18 @@ export class Inst { return ` ${k.idx}`; case "Call": return ` ${r(k.callee)} (${k.args.map(r).join(", ")})`; - case "AllocLocal": + case "Alloca": return ""; - case "LocalLoad": + case "Load": return ` ${r(k.source)}`; - case "LocalStore": + case "Store": return ` ${r(k.target)} = ${r(k.source)}`; case "Jump": return ` bb${cx.bbId(k.target)}`; case "Branch": - return ` ${r(k.cond)} ? bb${cx.bbId(k.truthy)} : bb${ - cx.bbId(k.falsy) - }`; + return ` if ${r(k.cond)} then bb${ + cx.bbId(k.truthy) + } else bb${cx.bbId(k.falsy)}`; case "Return": return ` ${r(k.source)}`; case "Not": @@ -409,9 +464,9 @@ export type InstKind = | { tag: "Fn"; fn: Fn } | { tag: "Param"; idx: number } | { tag: "Call"; callee: Inst; args: Inst[] } - | { tag: "AllocLocal" } - | { tag: "LocalLoad"; source: Inst } - | { tag: "LocalStore"; target: Inst; source: Inst } + | { tag: "Alloca" } + | { tag: "Load"; source: Inst } + | { tag: "Store"; target: Inst; source: Inst } | { tag: "Jump"; target: BasicBlock } | { tag: "Branch"; cond: Inst; truthy: BasicBlock; falsy: BasicBlock } | { tag: "Return"; source: Inst } diff --git a/src/mir_interpreter.ts b/src/mir_interpreter.ts index d189c37..18097e3 100644 --- a/src/mir_interpreter.ts +++ b/src/mir_interpreter.ts @@ -3,7 +3,6 @@ import * as mir from "./middle.ts"; export class FnInterpreter { private regs = new Map(); private locals: (Val | null)[] = []; - private localMap = new Map(); private bb: mir.BasicBlock; private instIdx = 0; @@ -42,30 +41,36 @@ export class FnInterpreter { this.regs.set(inst, val); break; } - case "AllocLocal": - this.localMap.set(inst, this.locals.length); + case "Alloca": { + const localIdx = this.locals.length; this.locals.push(null); - break; - case "LocalLoad": - if (!this.localMap.has(k.source)) { - throw new Error(); - } - if (this.locals[this.localMap.get(k.source)!] === null) { - throw new Error(); - } this.regs.set( inst, - this.locals[this.localMap.get(k.source)!]!, + new Val({ tag: "LocalPtr", localIdx, mutable: true }), ); break; - case "LocalStore": - if (!this.localMap.has(k.target)) { + } + case "Load": { + const source = this.regs.get(k.source); + if (!source || source.kind.tag !== "LocalPtr") { throw new Error(); } - this.locals[this.localMap.get(k.target)!] = this.regs.get( - k.source, - )!; + const value = this.locals[source.kind.localIdx]; + if (!value) { + throw new Error(); + } + this.regs.set(inst, value); break; + } + case "Store": { + const target = this.regs.get(k.target)!; + if (target.kind.tag !== "LocalPtr") { + throw new Error(); + } + const source = this.regs.get(k.source)!; + this.locals[target.kind.localIdx] = source; + break; + } case "Jump": { this.bb = k.target; this.instIdx = 0; @@ -207,6 +212,8 @@ class Val { case "Int": case "Bool": return `${k.value}`; + case "LocalPtr": + return ``; case "Fn": return `<${k.fn.ty.pretty()}>`; default: @@ -219,4 +226,5 @@ type ValKind = | { tag: "Void" } | { tag: "Int"; value: number } | { tag: "Bool"; value: boolean } + | { tag: "LocalPtr"; mutable: boolean; localIdx: number } | { tag: "Fn"; fn: mir.Fn }; diff --git a/src/ty.ts b/src/ty.ts index de3de7d..831b71e 100644 --- a/src/ty.ts +++ b/src/ty.ts @@ -102,6 +102,12 @@ export class Ty { if (this.is("Bool")) { return "bool"; } + if (this.is("Ptr")) { + return `*${this.kind.ty.pretty()}`; + } + if (this.is("PtrMut")) { + return `*mut ${this.kind.ty.pretty()}`; + } if (this.is("Fn")) { return `fn (${ this.kind.params.map((param) => param.pretty()).join(", ") @@ -124,5 +130,7 @@ export type TyKind = | { tag: "Void" } | { tag: "Int" } | { tag: "Bool" } + | { tag: "Ptr"; ty: Ty } + | { tag: "PtrMut"; ty: Ty } | { tag: "Fn"; params: Ty[]; retTy: Ty } | { tag: "FnStmt"; ty: Ty; stmt: ast.NodeWithKind<"FnStmt"> }; diff --git a/tests/pointer.ethlang b/tests/pointer.ethlang new file mode 100644 index 0000000..b78af91 --- /dev/null +++ b/tests/pointer.ethlang @@ -0,0 +1,23 @@ + +fn main() +{ + let a = 1; + let b: *int = &a; + + // expect: 1 + print_int(*b); + + a = 2; + + // expect: 2 + print_int(*b); + + let c: *mut int = &mut a; + *c = 3; + + // expect: 3 + print_int(a); + // expect: 3 + print_int(*c); +} + diff --git a/tests/test.sh b/tests/test.sh index 3ede3ed..7432ac7 100755 --- a/tests/test.sh +++ b/tests/test.sh @@ -29,7 +29,7 @@ run_test_file() { then if grep -q '// expect:' $file then - expected=$(grep '// expect:' $file | sed -E 's/\/\/ expect: (.*?)/\1/g') + expected=$(grep '// expect:' $file | sed -E 's/\s*\/\/\s+expect: (.*?)/\1/g') if [[ $output != $expected ]] then echo "-- failed: incorrect output --"