#include "ir.h"
#include "arena.h"
#include "parse.h"

#include <assert.h>
#include <errno.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Initialise an empty block: fresh arena, no instructions, vreg counter at 0. */
void ir_block_init(IrBlock* block) {
    arena_init(&block->arena);
    block->next_vreg = 0;
    block->insts = NULL;
    block->count = 0;
    block->capacity = 0;
}

/*
 * Release the block's arena (which backs every IrInst, see ir_alloc_inst) and
 * its instruction pointer array. The array bookkeeping is reset afterwards so
 * no dangling `insts` pointer or stale count is left behind. NULL is ignored.
 */
void ir_block_free(IrBlock* block) {
    if (!block) return;
    arena_free(&block->arena);
    free(block->insts);
    block->insts = NULL;
    block->count = 0;
    block->capacity = 0;
    block->next_vreg = 0;
}

/* Position of `inst` in the block's instruction list, or -1 if it is absent (or NULL). */
static int ir_inst_index(const IrBlock* block, const IrInst* inst) {
    for (size_t i = 0; i < block->count; i++) {
        if (block->insts[i] == inst) return (int)i;
    }
    return -1;
}

/*
 * Dump the block to stdout, one instruction per line, as "%<index> = <Op> ...".
 * Operands are printed by their index in this block; an operand that is not in
 * the block is printed as "<invalid>".
 */
void ir_block_print(IrBlock* block) {
    if (!block) return;
    for (size_t i = 0; i < block->count; i++) {
        IrInst* inst = block->insts[i];
        printf("%%%zu = ", i);
        switch (inst->op) {
        case OP_INT:
            printf("Int %llu\n", (unsigned long long)inst->value);
            break;
        case OP_ADD:
        case OP_SUB:
        case OP_MUL: {
            const char* op_str = inst->op == OP_ADD ? "Add"
                               : inst->op == OP_SUB ? "Sub"
                                                    : "Mul";
            printf("%s ", op_str);
            for (size_t j = 0; j < inst->operand_count; j++) {
                int idx = ir_inst_index(block, inst->operands[j]);
                if (idx < 0) {
                    /* Operand is NULL or was never emitted into this block. */
                    printf("<invalid>");
                } else {
                    printf("%%%d", idx);
                }
                if (j + 1 < inst->operand_count) printf(", ");
            }
            printf("\n");
            break;
        }
        default:
            printf("UnknownOp\n");
            break;
        }
    }
}

/*
 * Append `inst` to the block's instruction list, growing it geometrically.
 * Returns false (leaving the block unchanged) if the list cannot grow, so
 * callers never hand out an instruction that is not actually in the block.
 */
static bool ir_block_emit(IrBlock* block, IrInst* inst) {
    if (!block || !inst) return false;
    if (block->count == block->capacity) {
        size_t new_cap = block->capacity ? block->capacity * 2 : 8;
        /* Guard the byte-count multiplication against size_t overflow. */
        if (new_cap > SIZE_MAX / sizeof *block->insts) return false;
        IrInst** new_buf = realloc(block->insts, new_cap * sizeof *new_buf);
        if (!new_buf) return false; /* old buffer is still valid */
        block->insts = new_buf;
        block->capacity = new_cap;
    }
    block->insts[block->count++] = inst;
    return true;
}

/* Allocate a zero-initialised instruction from the block's arena (NULL on failure). */
static IrInst* ir_alloc_inst(IrBlock* block) {
    IrInst* inst = arena_alloc(&block->arena, sizeof *inst);
    if (!inst) return NULL;
    memset(inst, 0, sizeof *inst);
    return inst;
}

/*
 * Give `inst` the next virtual register and append it to the block.
 * The vreg counter only advances when the append succeeds, so vregs stay
 * in lockstep with instruction positions. Returns inst, or NULL on failure.
 */
static IrInst* ir_commit_inst(IrBlock* block, IrInst* inst) {
    inst->vreg = block->next_vreg;
    if (!ir_block_emit(block, inst)) return NULL;
    block->next_vreg++;
    return inst;
}

/* Emit an integer constant instruction. Returns NULL on allocation failure. */
static IrInst* ir_new_int(IrBlock* block, uint64_t value) {
    IrInst* inst = ir_alloc_inst(block);
    if (!inst) return NULL;
    inst->op = OP_INT;
    inst->value = value;
    return ir_commit_inst(block, inst);
}

/* Emit a two-operand instruction `op a, b`. Returns NULL on allocation failure. */
static IrInst* ir_new_binop(IrBlock* block, OpCode op, IrInst* a, IrInst* b) {
    IrInst* inst = ir_alloc_inst(block);
    if (!inst) return NULL;
    inst->op = op;
    inst->operand_count = 2;
    inst->operands[0] = a;
    inst->operands[1] = b;
    return ir_commit_inst(block, inst);
}

/* Map an operator identifier to its opcode; (OpCode)-1 for unknown names or NULL. */
static OpCode op_from_ident_strict(const char* s) {
    if (!s) return (OpCode)-1;
    if (strcmp(s, "add") == 0) return OP_ADD;
    if (strcmp(s, "sub") == 0) return OP_SUB;
    if (strcmp(s, "mul") == 0) return OP_MUL;
    return (OpCode)-1;
}

/*
 * Strictly parse a non-negative decimal literal into *out.
 * Accepts only an unsigned run of decimal digits that fits the range of strtoull.
 * Rejected: "", leading whitespace, a sign (strtoull would silently wrap "-1"),
 * trailing junk, and out-of-range values (strtoull saturates and sets ERANGE).
 */
static bool is_int_literal(const char* s, uint64_t* out) {
    if (!s || *s < '0' || *s > '9') return false;
    errno = 0;
    char* end = NULL;
    unsigned long long v = strtoull(s, &end, 10);
    if (errno == ERANGE || *end != '\0') return false;
    *out = (uint64_t)v;
    return true;
}

/*
 * Recursive worker for ir_lower_expr. Emits instructions for `expr` in
 * evaluation order and returns the instruction holding its value, or NULL if
 * the expression is rejected (instructions emitted so far are left for the
 * caller to discard).
 */
static IrInst* ir_lower_node(IrBlock* block, const Expr* expr) {
    if (!expr) return NULL;
    switch (expr->type) {
    case EXPR_INT: {
        uint64_t value;
        if (!is_int_literal(expr->text, &value)) {
            return NULL; /* strict: invalid integer literal */
        }
        return ir_new_int(block, value);
    }
    case EXPR_IDENT:
        return NULL; /* identifiers are not supported in this IR */
    case EXPR_SEXPR: {
        size_t n = expr->sexpr.count;
        if (n == 0) {
            return NULL; /* malformed: empty s-expression */
        }
        const Expr* head = expr->sexpr.items[0];
        if (!head || head->type != EXPR_IDENT) {
            return NULL; /* strict: first element must be an operator */
        }
        OpCode op = op_from_ident_strict(head->text);
        if (op == (OpCode)-1) {
            return NULL; /* unknown operator */
        }
        /*
         * Strict arity: at least two operands, e.g. (add x) is invalid.
         * There is no upper bound: every step below emits a plain two-operand
         * binop, so source arity never touches the operand array size.
         */
        if (n - 1 < 2) {
            return NULL;
        }
        /* Left-associative lowering: (add a b c d) => (((a + b) + c) + d) */
        IrInst* acc = ir_lower_node(block, expr->sexpr.items[1]);
        if (!acc) return NULL;
        for (size_t i = 2; i < n; i++) {
            IrInst* rhs = ir_lower_node(block, expr->sexpr.items[i]);
            if (!rhs) return NULL;
            acc = ir_new_binop(block, op, acc, rhs);
            if (!acc) return NULL;
        }
        return acc;
    }
    default:
        return NULL;
    }
}

/*
 * Lower an AST expression into `block`, returning the instruction that holds
 * its value, or NULL if the expression is rejected or an allocation fails.
 *
 * All-or-nothing: on failure the block is rolled back to its state on entry
 * (instruction count and vreg counter), so a rejected expression leaves no
 * orphaned instructions. Their arena memory is only reclaimed by ir_block_free.
 */
IrInst* ir_lower_expr(IrBlock* block, const Expr* expr) {
    if (!expr || !block) return NULL;
    size_t saved_count = block->count;
    IrInst* result = ir_lower_node(block, expr);
    if (!result && block->count > saved_count) {
        /* vregs advance in lockstep with emission, so the first discarded
           instruction's vreg is exactly the counter value on entry. */
        block->next_vreg = block->insts[saved_count]->vreg;
        block->count = saved_count;
    }
    return result;
}

static void test_ir_lower_simple_add(void) {
    IrBlock block;
    ir_block_init(&block);

    /* Build AST: (add 2 3) */
    Expr two = { .type = EXPR_INT, .text = "2" };
    Expr three = { .type = EXPR_INT, .text = "3" };
    Expr add_ident = { .type = EXPR_IDENT, .text = "add" };
    Expr sexpr = { .type = EXPR_SEXPR,
                   .sexpr = { .items = (Expr*[]) { &add_ident, &two, &three }, .count = 3 } };

    IrInst* result = ir_lower_expr(&block, &sexpr);
    assert(result != NULL);
    assert(result == block.insts[2]);

    assert(block.count == 3);
    assert(block.insts[0]->op == OP_INT);
    assert(block.insts[0]->value == 2);
    assert(block.insts[1]->op == OP_INT);
    assert(block.insts[1]->value == 3);
    assert(block.insts[2]->op == OP_ADD);

    /* Operand wiring */
    assert(block.insts[2]->operands[0] == block.insts[0]);
    assert(block.insts[2]->operands[1] == block.insts[1]);

    /* vregs follow emission order */
    for (size_t i = 0; i < block.count; i++) {
        assert((size_t)block.insts[i]->vreg == i);
    }

    ir_block_free(&block);
}

static void test_ir_lower_nested_expr(void) {
    IrBlock block;
    ir_block_init(&block);

    Expr two = { .type = EXPR_INT, .text = "2" };
    Expr three = { .type = EXPR_INT, .text = "3" };
    Expr four = { .type = EXPR_INT, .text = "4" };
    Expr mul_ident = { .type = EXPR_IDENT, .text = "mul" };
    Expr add_ident = { .type = EXPR_IDENT, .text = "add" };
    Expr mul_expr = { .type = EXPR_SEXPR,
                      .sexpr = { .items = (Expr*[]) { &mul_ident, &three, &four }, .count = 3 } };
    Expr add_expr = { .type = EXPR_SEXPR,
                      .sexpr = { .items = (Expr*[]) { &add_ident, &two, &mul_expr }, .count = 3 } };

    IrInst* result = ir_lower_expr(&block, &add_expr);
    assert(result != NULL);
    assert(block.count == 5);
    assert(block.insts[0]->op == OP_INT && block.insts[0]->value == 2);
    assert(block.insts[1]->op == OP_INT && block.insts[1]->value == 3);
    assert(block.insts[2]->op == OP_INT && block.insts[2]->value == 4);
    assert(block.insts[3]->op == OP_MUL);
    assert(block.insts[4]->op == OP_ADD);

    /* Verify MUL operands */
    assert(block.insts[3]->operands[0] == block.insts[1]);
    assert(block.insts[3]->operands[1] == block.insts[2]);

    /* Verify ADD operands */
    assert(block.insts[4]->operands[0] == block.insts[0]);
    assert(block.insts[4]->operands[1] == block.insts[3]);

    ir_block_free(&block);
}

static void test_ir_lower_left_assoc(void) {
    IrBlock block;
    ir_block_init(&block);

    /* (sub 5 6 7) => ((5 - 6) - 7) */
    Expr five = { .type = EXPR_INT, .text = "5" };
    Expr six = { .type = EXPR_INT, .text = "6" };
    Expr seven = { .type = EXPR_INT, .text = "7" };
    Expr sub_ident = { .type = EXPR_IDENT, .text = "sub" };
    Expr sub_expr = { .type = EXPR_SEXPR,
                      .sexpr = { .items = (Expr*[]) { &sub_ident, &five, &six, &seven }, .count = 4 } };

    IrInst* result = ir_lower_expr(&block, &sub_expr);
    assert(result != NULL);
    assert(block.count == 5);
    assert(block.insts[0]->op == OP_INT && block.insts[0]->value == 5);
    assert(block.insts[1]->op == OP_INT && block.insts[1]->value == 6);
    assert(block.insts[2]->op == OP_SUB);
    assert(block.insts[2]->operands[0] == block.insts[0]);
    assert(block.insts[2]->operands[1] == block.insts[1]);
    assert(block.insts[3]->op == OP_INT && block.insts[3]->value == 7);
    assert(block.insts[4]->op == OP_SUB);
    assert(block.insts[4]->operands[0] == block.insts[2]);
    assert(block.insts[4]->operands[1] == block.insts[3]);
    assert(result == block.insts[4]);

    ir_block_free(&block);
}

static void test_ir_lower_rejects_invalid(void) {
    IrBlock block;
    ir_block_init(&block);

    Expr one = { .type = EXPR_INT, .text = "1" };
    Expr foo = { .type = EXPR_IDENT, .text = "foo" };
    Expr add_ident = { .type = EXPR_IDENT, .text = "add" };
    Expr div_ident = { .type = EXPR_IDENT, .text = "div" };

    /* (add 1): too few operands */
    Expr unary = { .type = EXPR_SEXPR,
                   .sexpr = { .items = (Expr*[]) { &add_ident, &one }, .count = 2 } };
    assert(ir_lower_expr(&block, &unary) == NULL);

    /* (div 1 1): unknown operator */
    Expr unknown = { .type = EXPR_SEXPR,
                     .sexpr = { .items = (Expr*[]) { &div_ident, &one, &one }, .count = 3 } };
    assert(ir_lower_expr(&block, &unknown) == NULL);

    /* (): empty s-expression */
    Expr empty = { .type = EXPR_SEXPR, .sexpr = { .items = NULL, .count = 0 } };
    assert(ir_lower_expr(&block, &empty) == NULL);

    /* (add 1 foo): `1` is emitted before `foo` fails; the block must roll back */
    Expr partial = { .type = EXPR_SEXPR,
                     .sexpr = { .items = (Expr*[]) { &add_ident, &one, &foo }, .count = 3 } };
    assert(ir_lower_expr(&block, &partial) == NULL);
    assert(block.count == 0);

    /* Non-strict integer literals are rejected */
    Expr bad_literals[] = {
        { .type = EXPR_INT, .text = "" },
        { .type = EXPR_INT, .text = "-1" },
        { .type = EXPR_INT, .text = "+1" },
        { .type = EXPR_INT, .text = " 1" },
        { .type = EXPR_INT, .text = "1x" },
        { .type = EXPR_INT, .text = "18446744073709551616" }, /* 2^64 */
    };
    for (size_t k = 0; k < sizeof bad_literals / sizeof bad_literals[0]; k++) {
        assert(ir_lower_expr(&block, &bad_literals[k]) == NULL);
    }
    assert(block.count == 0);

    /* Largest 64-bit value is accepted, and the rollback restored vreg 0 */
    Expr max_lit = { .type = EXPR_INT, .text = "18446744073709551615" };
    IrInst* max_inst = ir_lower_expr(&block, &max_lit);
    assert(max_inst != NULL);
    assert(max_inst->value == UINT64_MAX);
    assert(block.count == 1);
    assert(max_inst->vreg == 0);

    ir_block_free(&block);
}

void test_ast_lower(void) {
    test_ir_lower_simple_add();
    test_ir_lower_nested_expr();
    test_ir_lower_left_assoc();
    test_ir_lower_rejects_invalid();
}