2026-05-21 00:31:25 +02:00

329 lines
7.9 KiB
C

#include "ir.h"
#include "arena.h"
#include "parse.h"
#include <assert.h>
#include <errno.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
/*
 * Prepares `block` for use: initializes its arena via arena_init(), resets
 * the virtual-register counter to 0, and starts with an empty instruction
 * list (no buffer allocated yet).
 *
 * Passing NULL is a no-op, matching the NULL handling of ir_block_free()
 * and ir_block_print().
 */
void ir_block_init(IrBlock* block)
{
    if (!block)
        return;
    arena_init(&block->arena);
    block->next_vreg = 0;
    block->insts = NULL;
    block->count = 0;
    block->capacity = 0;
}
/*
 * Releases what `block` owns: the arena (via arena_free()) and the
 * heap-allocated array of instruction pointers.
 *
 * Passing NULL is a no-op. After the call the instruction-list fields are
 * reset, so the block does not keep a dangling `insts` pointer or a stale
 * `count` that a later print/emit could act on.
 *
 * NOTE(review): a second call still invokes arena_free() again; whether
 * that is safe depends on arena_free(), which is not visible here.
 */
void ir_block_free(IrBlock* block)
{
    if (!block)
        return;
    arena_free(&block->arena);
    free(block->insts);
    /* Clear the list so the freed buffer can never be reused by accident. */
    block->insts = NULL;
    block->count = 0;
    block->capacity = 0;
}
/*
 * Finds the position of `inst` inside `block->insts` by pointer identity.
 *
 * Returns the zero-based position (converted to int) of the first match,
 * or -1 when the instruction is not present in the block.
 */
static int ir_inst_index(IrBlock* block, IrInst* inst)
{
    size_t pos = 0;

    /* Advance until we either run out of instructions or hit the match. */
    while (pos < block->count && block->insts[pos] != inst)
        pos++;

    return pos < block->count ? (int)pos : -1;
}
/*
 * Writes a textual listing of `block` to stdout, one instruction per line.
 *
 * Each line starts with "%N = " where N is the instruction's position in
 * block->insts. OP_INT prints its value; OP_ADD/OP_SUB/OP_MUL print the
 * mnemonic followed by their operands as "%N" positions, or "<?>" for an
 * operand that is not found in the block. Any other opcode prints
 * "UnknownOp". A NULL block prints nothing.
 */
void ir_block_print(IrBlock* block)
{
    if (block == NULL)
        return;

    for (size_t pos = 0; pos < block->count; pos++) {
        IrInst* cur = block->insts[pos];
        const char* mnemonic = NULL;

        printf("%%%zu = ", pos);

        if (cur->op == OP_INT) {
            printf("Int %llu\n", (unsigned long long)cur->value);
            continue;
        }

        if (cur->op == OP_ADD)
            mnemonic = "Add";
        else if (cur->op == OP_SUB)
            mnemonic = "Sub";
        else if (cur->op == OP_MUL)
            mnemonic = "Mul";

        if (mnemonic == NULL) {
            printf("UnknownOp\n");
            continue;
        }

        printf("%s ", mnemonic);
        for (size_t k = 0; k < cur->operand_count; k++) {
            /* Separator goes between operands, never after the last one. */
            if (k > 0)
                printf(", ");

            int where = ir_inst_index(block, cur->operands[k]);
            if (where >= 0)
                printf("%%%d", where);
            else
                printf("<?>");
        }
        printf("\n");
    }
}
/*
 * Appends `inst` to the end of `block`'s instruction list, growing the
 * pointer array when it is full (initial capacity 8, doubling afterwards).
 *
 * Does nothing when `block` or `inst` is NULL. If the array cannot be
 * grown (allocation failure, or the doubled size would overflow size_t)
 * the instruction is NOT appended and block->count is left unchanged;
 * the existing buffer stays valid.
 */
static void ir_block_emit(IrBlock* block, IrInst* inst)
{
    if (!block || !inst)
        return;

    if (block->count == block->capacity) {
        size_t new_cap = 8;
        if (block->capacity != 0) {
            /* Refuse to grow if capacity * 2 * sizeof(pointer) would wrap. */
            if (block->capacity > SIZE_MAX / 2 / sizeof *block->insts)
                return;
            new_cap = block->capacity * 2;
        }

        /* Keep the old pointer until realloc is known to have succeeded. */
        IrInst** new_buf = realloc(block->insts, new_cap * sizeof *new_buf);
        if (!new_buf)
            return;
        block->insts = new_buf;
        block->capacity = new_cap;
    }

    block->insts[block->count++] = inst;
}
/*
 * Obtains storage for one IrInst from the block's arena and zero-fills it.
 *
 * Returns the zeroed instruction, or NULL when arena_alloc() returns NULL.
 */
static IrInst* ir_alloc_inst(IrBlock* block)
{
    IrInst* fresh = arena_alloc(&block->arena, sizeof *fresh);

    if (fresh != NULL)
        memset(fresh, 0, sizeof *fresh);
    return fresh;
}
/*
 * Creates an OP_INT instruction holding `value`, appends it to `block`,
 * and assigns it the block's next virtual-register number.
 *
 * Returns the new instruction, or NULL if the instruction could not be
 * allocated or could not be appended to the block. On failure no
 * virtual-register number is consumed, so callers never receive an
 * instruction that is missing from block->insts.
 */
static IrInst* ir_new_int(IrBlock* block, uint64_t value)
{
    IrInst* inst = ir_alloc_inst(block);
    if (!inst)
        return NULL;
    inst->op = OP_INT;
    inst->value = value;

    /* ir_block_emit() returns void; an unchanged count means it failed. */
    size_t count_before = block->count;
    ir_block_emit(block, inst);
    if (block->count == count_before)
        return NULL;

    inst->vreg = block->next_vreg++;
    return inst;
}
/*
 * Creates a two-operand instruction with opcode `op` and operands `a`, `b`
 * (stored as given, in that order), appends it to `block`, and assigns it
 * the block's next virtual-register number.
 *
 * Returns the new instruction, or NULL if the instruction could not be
 * allocated or could not be appended to the block. On failure no
 * virtual-register number is consumed, so callers never receive an
 * instruction that is missing from block->insts.
 */
static IrInst* ir_new_binop(IrBlock* block, OpCode op, IrInst* a, IrInst* b)
{
    IrInst* inst = ir_alloc_inst(block);
    if (!inst)
        return NULL;
    inst->op = op;
    inst->operand_count = 2;
    inst->operands[0] = a;
    inst->operands[1] = b;

    /* ir_block_emit() returns void; an unchanged count means it failed. */
    size_t count_before = block->count;
    ir_block_emit(block, inst);
    if (block->count == count_before)
        return NULL;

    inst->vreg = block->next_vreg++;
    return inst;
}
/*
 * Maps an operator identifier to its opcode using exact, case-sensitive
 * comparison: "add" -> OP_ADD, "sub" -> OP_SUB, "mul" -> OP_MUL.
 *
 * Returns (OpCode)-1 for any other string, and also for a NULL `s`
 * (strcmp on NULL would be undefined behaviour).
 */
static OpCode op_from_ident_strict(const char* s)
{
    static const struct {
        const char* name;
        OpCode op;
    } known_ops[] = {
        { "add", OP_ADD },
        { "sub", OP_SUB },
        { "mul", OP_MUL },
    };

    if (!s)
        return (OpCode)-1;

    for (size_t i = 0; i < sizeof known_ops / sizeof known_ops[0]; i++) {
        if (strcmp(s, known_ops[i].name) == 0)
            return known_ops[i].op;
    }
    return (OpCode)-1;
}
/*
 * Parses `s` as a base-10 integer and stores the result in `*out`.
 *
 * Returns true only when the whole string was consumed and the value fits;
 * `*out` is written only on success. Returns false for a NULL or empty
 * string, a NULL `out`, trailing non-digit characters, or a value too
 * large for strtoull() (reported via ERANGE) - previously such a literal
 * was silently accepted as the maximum value.
 *
 * Parsing is delegated to strtoull(), so the leading whitespace and
 * optional sign that strtoull() accepts are still accepted here.
 */
static bool is_int_literal(const char* s, uint64_t* out)
{
    if (!s || !*s || !out)
        return false;

    char* end = NULL;
    errno = 0; /* strtoull only sets errno on error; clear it first */
    unsigned long long v = strtoull(s, &end, 10);
    if (errno == ERANGE)
        return false;
    if (end == s || *end != '\0')
        return false;

    *out = (uint64_t)v;
    return true;
}
/*
 * Lowers the AST node `expr` into instructions appended to `block` and
 * returns the instruction that produces the expression's value.
 *
 *  - EXPR_INT:   the text must parse via is_int_literal(); emits one OP_INT.
 *  - EXPR_IDENT: not supported by this IR; always fails.
 *  - EXPR_SEXPR: the first item must be an EXPR_IDENT naming a known
 *                operator, followed by 2..IR_MAX_ARITY operands. Operands
 *                are lowered left to right and folded left-associatively:
 *                (add a b c d) => (((a + b) + c) + d).
 *
 * Returns NULL when `block` or `expr` is NULL, when the expression is
 * malformed or unsupported, or when an instruction cannot be created.
 * Instructions already emitted before a failure remain in `block`.
 */
IrInst* ir_lower_expr(IrBlock* block, const Expr* expr)
{
    if (!expr || !block)
        return NULL;

    switch (expr->type) {
    case EXPR_INT: {
        uint64_t value = 0;
        if (!is_int_literal(expr->text, &value))
            return NULL; /* strict: invalid integer literal */
        return ir_new_int(block, value);
    }
    case EXPR_IDENT:
        /* identifiers are not supported in this IR */
        return NULL;
    case EXPR_SEXPR: {
        size_t n = expr->sexpr.count;
        /* Empty list, or a non-zero count with no item array behind it. */
        if (n == 0 || !expr->sexpr.items)
            return NULL;

        const Expr* head = expr->sexpr.items[0];
        if (!head || head->type != EXPR_IDENT)
            return NULL; /* strict: first element must be the operator */

        OpCode op = op_from_ident_strict(head->text);
        if (op == (OpCode)-1)
            return NULL; /* unknown operator */

        size_t argc = n - 1;
        if (argc < 2)
            return NULL; /* e.g. (add x) is invalid */
        if (argc > IR_MAX_ARITY)
            return NULL; /* too many operands */

        /* The first operand seeds the accumulator. */
        IrInst* acc = ir_lower_expr(block, expr->sexpr.items[1]);
        if (!acc)
            return NULL;

        /* Fold the remaining operands in from the left. */
        for (size_t i = 2; i < n; i++) {
            IrInst* rhs = ir_lower_expr(block, expr->sexpr.items[i]);
            if (!rhs)
                return NULL;
            acc = ir_new_binop(block, op, acc, rhs);
            if (!acc)
                return NULL;
        }
        return acc;
    }
    default:
        /* Unrecognized node type: nothing to lower. */
        break;
    }
    return NULL;
}
/*
 * Lowers the flat expression (add 2 3) and checks that exactly three
 * instructions come out - Int 2, Int 3, Add - with the Add wired to the
 * two integer instructions in order.
 */
static void test_ir_lower_simple_add(void)
{
    /* Hand-built AST for (add 2 3). */
    Expr lhs_lit = { .type = EXPR_INT, .text = "2" };
    Expr rhs_lit = { .type = EXPR_INT, .text = "3" };
    Expr op_name = { .type = EXPR_IDENT, .text = "add" };
    Expr* children[] = { &op_name, &lhs_lit, &rhs_lit };
    Expr root = { .type = EXPR_SEXPR,
                  .sexpr = { .items = children, .count = 3 } };

    IrBlock block;
    ir_block_init(&block);

    IrInst* result = ir_lower_expr(&block, &root);
    assert(result != NULL);

    /* Two literals plus one binary op. */
    assert(block.count == 3);

    /* Literals are emitted in source order. */
    assert(block.insts[0]->op == OP_INT);
    assert(block.insts[0]->value == 2);
    assert(block.insts[1]->op == OP_INT);
    assert(block.insts[1]->value == 3);

    /* The Add comes last and points back at both literals. */
    assert(block.insts[2]->op == OP_ADD);
    assert(block.insts[2]->operands[0] == block.insts[0]);
    assert(block.insts[2]->operands[1] == block.insts[1]);

    ir_block_free(&block);
}
/*
 * Lowers the nested expression (add 2 (mul 3 4)) and checks the emission
 * order - Int 2, Int 3, Int 4, Mul, Add - and that each binary op refers
 * to the right earlier instructions.
 */
static void test_ir_lower_nested_expr(void)
{
    /* Leaves and operator names. */
    Expr lit_two = { .type = EXPR_INT, .text = "2" };
    Expr lit_three = { .type = EXPR_INT, .text = "3" };
    Expr lit_four = { .type = EXPR_INT, .text = "4" };
    Expr name_mul = { .type = EXPR_IDENT, .text = "mul" };
    Expr name_add = { .type = EXPR_IDENT, .text = "add" };

    /* Inner node: (mul 3 4). */
    Expr* product_items[] = { &name_mul, &lit_three, &lit_four };
    Expr product = { .type = EXPR_SEXPR,
                     .sexpr = { .items = product_items, .count = 3 } };

    /* Outer node: (add 2 <product>). */
    Expr* sum_items[] = { &name_add, &lit_two, &product };
    Expr sum = { .type = EXPR_SEXPR,
                 .sexpr = { .items = sum_items, .count = 3 } };

    IrBlock block;
    ir_block_init(&block);

    IrInst* result = ir_lower_expr(&block, &sum);
    assert(result != NULL);
    assert(block.count == 5);

    /* Operands are lowered left to right, inner expression before its use. */
    assert(block.insts[0]->op == OP_INT && block.insts[0]->value == 2);
    assert(block.insts[1]->op == OP_INT && block.insts[1]->value == 3);
    assert(block.insts[2]->op == OP_INT && block.insts[2]->value == 4);
    assert(block.insts[3]->op == OP_MUL);
    assert(block.insts[4]->op == OP_ADD);

    /* Mul consumes Int 3 and Int 4. */
    assert(block.insts[3]->operands[0] == block.insts[1]);
    assert(block.insts[3]->operands[1] == block.insts[2]);

    /* Add consumes Int 2 and the Mul result. */
    assert(block.insts[4]->operands[0] == block.insts[0]);
    assert(block.insts[4]->operands[1] == block.insts[3]);

    ir_block_free(&block);
}
void test_ast_lower(void)
{
test_ir_lower_simple_add();
test_ir_lower_nested_expr();
}