compiler: check match expressions

This commit is contained in:
sfja 2025-03-03 01:57:53 +01:00 committed by SimonFJ20
parent 6d1693c318
commit 6e5ac26c84
5 changed files with 197 additions and 22 deletions

View File

@ -72,27 +72,137 @@ export class Checker {
const res = this.assignPatTy(kind.pat, ty.val); const res = this.assignPatTy(kind.pat, ty.val);
if (!res.ok) { if (!res.ok) {
this.report(res.val, stmt.span); this.report(res.val.msg, res.val.span);
return Ty({ tag: "error" }); return Ty({ tag: "error" });
} }
} }
private assignPatTy(pat: ast.Pat, ty: Ty): Res<void, string> { private assignPatTy(
pat: ast.Pat,
ty: Ty,
): Res<void, { msg: string; span: Span }> {
this.patTys.set(pat.id, ty);
const k = pat.kind; const k = pat.kind;
switch (k.tag) { switch (k.tag) {
case "error": case "error":
// don't report, already reported // don't report, already reported
return Res.Ok(undefined); return Res.Ok(undefined);
case "bind": case "bind":
this.patTys.set(pat.id, ty);
return Ok(undefined); return Ok(undefined);
case "path": case "path":
return todo(); return todo();
case "tuple": case "tuple": {
if (k.path) {
const re = this.re.pathRes(k.path.id);
if (re.kind.tag === "variant") {
if (ty.kind.tag !== "enum") {
return Res.Err({
msg: "type/pattern mismatch",
span: pat.span,
});
}
const variantTy =
ty.kind.variants[re.kind.variantIdx].data;
if (variantTy.tag !== "tuple") {
return Res.Err({
msg: "type is not a tuple variant",
span: pat.span,
});
}
const datak = re.kind.variant.data.kind;
if (datak.tag !== "tuple") {
return Res.Err({
msg: "variant is not a tuple",
span: pat.span,
});
}
if (k.elems.length !== datak.elems.length) {
return Res.Err({
msg: `incorrect amount of elements, expected ${datak.elems.length} got ${k.elems.length}`,
span: pat.span,
});
}
for (const [i, pat] of k.elems.entries()) {
const res = this.assignPatTy(
pat,
variantTy.elems[i].ty,
);
if (!res.ok) {
return res;
}
}
return Res.Ok(undefined);
}
}
return todo(); return todo();
case "struct": }
case "struct": {
if (k.path) {
const re = this.re.pathRes(k.path.id);
if (re.kind.tag === "variant") {
if (ty.kind.tag !== "enum") {
return Res.Err({
msg: "type/pattern mismatch",
span: pat.span,
});
}
const variantTy =
ty.kind.variants[re.kind.variantIdx].data;
if (variantTy.tag !== "struct") {
return Res.Err({
msg: "type is not a struct variant",
span: pat.span,
});
}
const datak = re.kind.variant.data.kind;
if (datak.tag !== "struct") {
return Res.Err({
msg: "variant is not a struct",
span: pat.span,
});
}
const fieldIdents = datak.fields
.reduce(
(
map,
field,
) => (map.set(field.ident!.id, field), map),
new Map(),
);
for (const field of k.fields) {
if (!fieldIdents.has(field.ident.id)) {
return Res.Err({
msg: `no field '${field.ident.text}' on type`,
span: pat.span,
});
}
const res = this.assignPatTy(
field.pat,
fieldIdents.get(field.ident.id)!.ty,
);
if (!res.ok) {
return res;
}
fieldIdents.delete(field.ident.id);
}
if (fieldIdents.size !== 0) {
return Res.Err({
msg: `fields ${
fieldIdents
.values()
.map((field) => `'${field.ident.text}'`)
.toArray()
.join(", ")
} not covered`,
span: pat.span,
});
}
return Res.Ok(undefined);
}
}
return todo(); return todo();
} }
}
exhausted(k); exhausted(k);
} }
@ -776,7 +886,7 @@ export class Checker {
for (const arm of kind.arms) { for (const arm of kind.arms) {
const res = this.assignPatTy(arm.pat, ty); const res = this.assignPatTy(arm.pat, ty);
if (!res.ok) { if (!res.ok) {
this.report(res.val, arm.pat.span); this.report(res.val.msg, res.val.span);
continue; continue;
} }
} }
@ -785,16 +895,16 @@ export class Checker {
if (!earlier.ok) { if (!earlier.ok) {
return earlier; return earlier;
} }
const exprTy = this.exprTy(arm.expr); const exprTy = this.exprTy(arm.expr, earlier.val);
return this.resolveTys(exprTy, earlier.val); return this.resolveTys(exprTy, earlier.val);
}, Res.Ok(Ty({ tag: "null" }))); }, Res.Ok(Ty({ tag: "unknown" })));
if (!tyRes.ok) { if (!tyRes.ok) {
this.report(tyRes.val, expr.span); this.report(tyRes.val, expr.span);
this.exprTys.set(expr.id, Ty({ tag: "error" })); this.exprTys.set(expr.id, Ty({ tag: "error" }));
return Ty({ tag: "error" }); return Ty({ tag: "error" });
} }
this.exprTys.set(expr.id, tyRes.val); this.exprTys.set(expr.id, tyRes.val);
return todo(); return tyRes.val;
} }
private checkLoopExpr( private checkLoopExpr(

View File

@ -515,8 +515,70 @@ export class FnLowerer {
} }
private lowerMatchExpr(expr: ast.Expr, kind: ast.MatchExpr): RVal { private lowerMatchExpr(expr: ast.Expr, kind: ast.MatchExpr): RVal {
if (kind.arms.length === 0) {
return todo(); return todo();
} }
const ty = this.ch.exprTy(expr);
const dest = this.local(ty);
const discr = this.lowerExpr(kind.expr);
const exitBlock = this.createBlock();
for (const arm of kind.arms) {
const exprBlock = this.createBlock();
const nextArmBlock = this.createBlock();
this.lowerMatchArmPattern(
discr,
arm.pat,
exprBlock,
nextArmBlock,
);
this.pushCreatedBlock(exprBlock);
const rval = this.lowerExpr(arm.expr);
this.addStmt({
tag: "assign",
place: { local: dest, proj: [] },
rval,
});
this.setTer({ tag: "goto", target: exitBlock.id });
this.pushCreatedBlock(nextArmBlock);
}
this.setTer({ tag: "goto", target: exitBlock.id });
this.pushCreatedBlock(exitBlock);
return { tag: "use", operand: this.copyOrMoveLocal(dest, ty) };
}
private lowerMatchArmPattern(
discr: RVal,
pat: ast.Pat,
truthyBlock: Block,
falsyBlock: Block,
) {
const k = pat.kind;
switch (k.tag) {
case "error":
return;
case "bind": {
const ty = this.ch.patTy(pat);
const local = this.local(ty);
this.addStmt({
tag: "assign",
place: { local, proj: [] },
rval: discr,
});
this.setTer({ tag: "goto", target: truthyBlock.id });
return;
}
case "path":
return todo();
case "tuple": {
return todo();
}
case "struct":
return todo();
}
exhausted(k);
}
private lowerLoopExpr(expr: ast.Expr, kind: ast.LoopExpr): RVal { private lowerLoopExpr(expr: ast.Expr, kind: ast.LoopExpr): RVal {
const entryBlock = this.currentBlock!; const entryBlock = this.currentBlock!;

View File

@ -1144,6 +1144,9 @@ export class Parser {
if (this.test("if")) { if (this.test("if")) {
return this.parseIf(); return this.parseIf();
} }
if (this.test("match")) {
return this.parseMatch();
}
if (this.test("loop")) { if (this.test("loop")) {
return this.parseLoop(); return this.parseLoop();
} }

View File

@ -6,10 +6,10 @@ enum S {
fn main() { fn main() {
let s = S::A(123); let s = S::A(123);
match s { let r = match s {
S::A(v) => {}, S::A(v) => { 3 + 2 },
S::B { v: v } => {}, S::B { v: v } => { 4 },
} };
} }

View File

@ -137,13 +137,13 @@ export class HirStringifyer {
}`; }`;
case "match": case "match":
return `match ${this.expr(k.expr, d)} ${ return `match ${this.expr(k.expr, d)} ${
k.arms.length === 0 k.arms.length === 0 ? "{}" : `{${
? "{}" k.arms
: `{${ .map((arm) => this.matchArm(arm, d + 1))
k.arms.map((arm) => this.matchArm(arm, d + 1)).map( .map((s) =>
(s) => `\n${indent(d + 1)}${s},`
`\n${s},`,
) )
.join("")
}\n${indent(d)}}` }\n${indent(d)}}`
}`; }`;
case "loop": case "loop":
@ -163,7 +163,7 @@ export class HirStringifyer {
} }
public matchArm(arm: ast.MatchArm, d: number): string { public matchArm(arm: ast.MatchArm, d: number): string {
return `${this.pat(arm.pat, d)} => ${this.expr(arm.expr, d + 1)}`; return `${this.pat(arm.pat, d)} => ${this.expr(arm.expr, d)}`;
} }
public pat(pat: ast.Pat, d: number): string { public pat(pat: ast.Pat, d: number): string {