From 6e5ac26c8494614330ee91a788290436904ef560 Mon Sep 17 00:00:00 2001 From: sfja Date: Mon, 3 Mar 2025 01:57:53 +0100 Subject: [PATCH] compiler: check match expressions --- slige/compiler/check/checker.ts | 128 +++++++++++++++++++++++++++-- slige/compiler/middle/ast_lower.ts | 64 ++++++++++++++- slige/compiler/parse/parser.ts | 3 + slige/compiler/program.slg | 8 +- slige/compiler/stringify/hir.ts | 16 ++-- 5 files changed, 197 insertions(+), 22 deletions(-) diff --git a/slige/compiler/check/checker.ts b/slige/compiler/check/checker.ts index 5b45fed..303c98d 100644 --- a/slige/compiler/check/checker.ts +++ b/slige/compiler/check/checker.ts @@ -72,26 +72,136 @@ export class Checker { const res = this.assignPatTy(kind.pat, ty.val); if (!res.ok) { - this.report(res.val, stmt.span); + this.report(res.val.msg, res.val.span); return Ty({ tag: "error" }); } } - private assignPatTy(pat: ast.Pat, ty: Ty): Res { + private assignPatTy( + pat: ast.Pat, + ty: Ty, + ): Res { + this.patTys.set(pat.id, ty); const k = pat.kind; switch (k.tag) { case "error": // don't report, already reported return Res.Ok(undefined); case "bind": - this.patTys.set(pat.id, ty); return Ok(undefined); case "path": return todo(); - case "tuple": + case "tuple": { + if (k.path) { + const re = this.re.pathRes(k.path.id); + if (re.kind.tag === "variant") { + if (ty.kind.tag !== "enum") { + return Res.Err({ + msg: "type/pattern mismatch", + span: pat.span, + }); + } + const variantTy = + ty.kind.variants[re.kind.variantIdx].data; + if (variantTy.tag !== "tuple") { + return Res.Err({ + msg: "type is not a tuple variant", + span: pat.span, + }); + } + const datak = re.kind.variant.data.kind; + if (datak.tag !== "tuple") { + return Res.Err({ + msg: "variant is not a tuple", + span: pat.span, + }); + } + if (k.elems.length !== datak.elems.length) { + return Res.Err({ + msg: `incorrect amount of elements, expected ${datak.elems.length} got ${k.elems.length}`, + span: pat.span, + }); + } + for (const [i, pat] of k.elems.entries()) { + const res = this.assignPatTy( + pat, + variantTy.elems[i].ty, + ); + if (!res.ok) { + return res; + } + } + return Res.Ok(undefined); + } + } return todo(); - case "struct": + } + case "struct": { + if (k.path) { + const re = this.re.pathRes(k.path.id); + if (re.kind.tag === "variant") { + if (ty.kind.tag !== "enum") { + return Res.Err({ + msg: "type/pattern mismatch", + span: pat.span, + }); + } + const variantTy = + ty.kind.variants[re.kind.variantIdx].data; + if (variantTy.tag !== "struct") { + return Res.Err({ + msg: "type is not a struct variant", + span: pat.span, + }); + } + const datak = re.kind.variant.data.kind; + if (datak.tag !== "struct") { + return Res.Err({ + msg: "variant is not a struct", + span: pat.span, + }); + } + const fieldIdents = datak.fields + .reduce( + ( + map, + field, + ) => (map.set(field.ident!.id, field), map), + new Map(), + ); + for (const field of k.fields) { + if (!fieldIdents.has(field.ident.id)) { + return Res.Err({ + msg: `no field '${field.ident.text}' on type`, + span: pat.span, + }); + } + const res = this.assignPatTy( + field.pat, + fieldIdents.get(field.ident.id)!.ty, + ); + if (!res.ok) { + return res; + } + fieldIdents.delete(field.ident.id); + } + if (fieldIdents.size !== 0) { + return Res.Err({ + msg: `fields ${ + fieldIdents + .values() + .map((field) => `'${field.ident.text}'`) + .toArray() + .join(", ") + } not covered`, + span: pat.span, + }); + } + return Res.Ok(undefined); + } + } return todo(); + } } exhausted(k); } @@ -776,7 +886,7 @@ export class Checker { for (const arm of kind.arms) { const res = this.assignPatTy(arm.pat, ty); if (!res.ok) { - this.report(res.val, arm.pat.span); + this.report(res.val.msg, res.val.span); continue; } } @@ -785,16 +895,16 @@ export class Checker { if (!earlier.ok) { return earlier; } - const exprTy = this.exprTy(arm.expr); + const exprTy = this.exprTy(arm.expr, earlier.val); return this.resolveTys(exprTy, earlier.val); - }, Res.Ok(Ty({ tag: "null" }))); + }, Res.Ok(Ty({ tag: "unknown" }))); if (!tyRes.ok) { this.report(tyRes.val, expr.span); this.exprTys.set(expr.id, Ty({ tag: "error" })); return Ty({ tag: "error" }); } this.exprTys.set(expr.id, tyRes.val); - return todo(); + return tyRes.val; } private checkLoopExpr( diff --git a/slige/compiler/middle/ast_lower.ts b/slige/compiler/middle/ast_lower.ts index 1426a17..5a51d6d 100644 --- a/slige/compiler/middle/ast_lower.ts +++ b/slige/compiler/middle/ast_lower.ts @@ -515,7 +515,69 @@ export class FnLowerer { } private lowerMatchExpr(expr: ast.Expr, kind: ast.MatchExpr): RVal { - return todo(); + if (kind.arms.length === 0) { + return todo(); + } + const ty = this.ch.exprTy(expr); + const dest = this.local(ty); + + const discr = this.lowerExpr(kind.expr); + const exitBlock = this.createBlock(); + + for (const arm of kind.arms) { + const exprBlock = this.createBlock(); + const nextArmBlock = this.createBlock(); + this.lowerMatchArmPattern( + discr, + arm.pat, + exprBlock, + nextArmBlock, + ); + this.pushCreatedBlock(exprBlock); + const rval = this.lowerExpr(arm.expr); + this.addStmt({ + tag: "assign", + place: { local: dest, proj: [] }, + rval, + }); + this.setTer({ tag: "goto", target: exitBlock.id }); + this.pushCreatedBlock(nextArmBlock); + } + this.setTer({ tag: "goto", target: exitBlock.id }); + this.pushCreatedBlock(exitBlock); + return { tag: "use", operand: this.copyOrMoveLocal(dest, ty) }; + } + + private lowerMatchArmPattern( + discr: RVal, + pat: ast.Pat, + truthyBlock: Block, + falsyBlock: Block, + ) { + const k = pat.kind; + switch (k.tag) { + case "error": + return; + case "bind": { + const ty = this.ch.patTy(pat); + const local = this.local(ty); + this.addStmt({ + tag: "assign", + place: { local, proj: [] }, + rval: discr, + }); + this.setTer({ tag: "goto", target: truthyBlock.id }); + return; + } + case "path": + return todo(); + case "tuple": { + return todo(); + } + case "struct": + return todo(); + } + exhausted(k); } private lowerLoopExpr(expr: ast.Expr, kind: ast.LoopExpr): RVal { diff --git a/slige/compiler/parse/parser.ts b/slige/compiler/parse/parser.ts index 126d28a..13b5b46 100644 --- a/slige/compiler/parse/parser.ts +++ b/slige/compiler/parse/parser.ts @@ -1144,6 +1144,9 @@ export class Parser { if (this.test("if")) { return this.parseIf(); } + if (this.test("match")) { + return this.parseMatch(); + } if (this.test("loop")) { return this.parseLoop(); } diff --git a/slige/compiler/program.slg b/slige/compiler/program.slg index 0e0e374..8c677c6 100644 --- a/slige/compiler/program.slg +++ b/slige/compiler/program.slg @@ -6,10 +6,10 @@ enum S { fn main() { let s = S::A(123); - match s { - S::A(v) => {}, - S::B { v: v } => {}, - } + let r = match s { + S::A(v) => { 3 + 2 }, + S::B { v: v } => { 4 }, + }; } diff --git a/slige/compiler/stringify/hir.ts b/slige/compiler/stringify/hir.ts index c969d9f..0fa1e04 100644 --- a/slige/compiler/stringify/hir.ts +++ b/slige/compiler/stringify/hir.ts @@ -137,14 +137,14 @@ export class HirStringifyer { }`; case "match": return `match ${this.expr(k.expr, d)} ${ - k.arms.length === 0 - ? "{}" - : `{${ - k.arms.map((arm) => this.matchArm(arm, d + 1)).map( - (s) => - `\n${s},`, + k.arms.length === 0 ? "{}" : `{${ + k.arms + .map((arm) => this.matchArm(arm, d + 1)) + .map((s) => + `\n${indent(d + 1)}${s},` ) - }\n${indent(d)}}` + .join("") + }\n${indent(d)}}` }`; case "loop": return `loop ${this.expr(k.body, d)}`; @@ -163,7 +163,7 @@ export class HirStringifyer { } public matchArm(arm: ast.MatchArm, d: number): string { - return `${this.pat(arm.pat, d)} => ${this.expr(arm.expr, d + 1)}`; + return `${this.pat(arm.pat, d)} => ${this.expr(arm.expr, d)}`; } public pat(pat: ast.Pat, d: number): string {