#include "asm.h"
#include "vm.h"
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

Line s_label(int label)
{
    return (Line) {
        .ty = LineTy_Label,
        .op1 = (Ex) { .label = label },
    };
}
Line s_data_i(uint16_t data)
{
    return (Line) {
        .ty = LineTy_DataImm,
        .op1 = (Ex) { .imm = data },
    };
}
Line s_data_l(int label)
{
    return (Line) {
        .ty = LineTy_DataLabel,
        .op1 = (Ex) { .label = label },

    };
}
Line s_nop(void)
{
    return (Line) { .ty = LineTy_Nop };
}
Line s_hlt(void)
{
    return (Line) { .ty = LineTy_Hlt };
}
Line s_jmp_l(int op1_label)
{
    return (Line) {
        .ty = LineTy_Jmp_Label,
        .op1 = (Ex) { .label = op1_label },
    };
}
/*
 * Builds a LineTy_Jnz_Label line: op1 is the register `op1_reg`, op2 refers
 * to the label `op2_label`.
 */
Line s_jnz_l(Reg op1_reg, int op2_label)
{
    const Ex subject = { .reg = (uint16_t)op1_reg };
    const Ex target = { .label = op2_label };
    const Line line = {
        .ty = LineTy_Jnz_Label,
        .op1 = subject,
        .op2 = target,
    };
    return line;
}
/*
 * Builds a LineTy_Cmp_Imm line: op1 is the register `op1_reg`, op2 is the
 * immediate `op2_imm`.
 */
Line s_cmp_i(Reg op1_reg, uint16_t op2_imm)
{
    const Ex lhs = { .reg = (uint16_t)op1_reg };
    const Ex rhs = { .imm = op2_imm };
    const Line line = {
        .ty = LineTy_Cmp_Imm,
        .op1 = lhs,
        .op2 = rhs,
    };
    return line;
}
/*
 * Builds a LineTy_Mov8_MemReg_Reg line: dst is the register `dst_reg`, op2
 * is the register `op2_reg`. No offset is supplied by this constructor.
 */
Line s_mov8_mr_r(Reg dst_reg, Reg op2_reg)
{
    const Ex dst = { .reg = (uint16_t)dst_reg };
    const Ex src = { .reg = (uint16_t)op2_reg };
    const Line line = {
        .ty = LineTy_Mov8_MemReg_Reg,
        .dst = dst,
        .op2 = src,
    };
    return line;
}
Line s_mov8_mi_i(uint16_t dst_imm, uint16_t op2_imm)
{
    return (Line) {
        .ty = LineTy_Mov8_MemImm_Imm,
        .dst = (Ex) { .imm = dst_imm },
        .op2 = (Ex) { .imm = op2_imm },
    };
}
/*
 * Builds a LineTy_Mov8_MemImm_Reg line: dst is the immediate `dst_imm`, op2
 * is the register `op2_reg`.
 */
Line s_mov8_mi_r(uint16_t dst_imm, Reg op2_reg)
{
    const Ex dst = { .imm = dst_imm };
    const Ex src = { .reg = (uint16_t)op2_reg };
    const Line line = {
        .ty = LineTy_Mov8_MemImm_Reg,
        .dst = dst,
        .op2 = src,
    };
    return line;
}
/*
 * Builds a LineTy_Mov16_Reg_Reg line: dst is the register `dst_reg`, op2 is
 * the register `op2_reg`.
 */
Line s_mov16_r_r(Reg dst_reg, Reg op2_reg)
{
    const Ex dst = { .reg = (uint16_t)dst_reg };
    const Ex src = { .reg = (uint16_t)op2_reg };
    const Line line = {
        .ty = LineTy_Mov16_Reg_Reg,
        .dst = dst,
        .op2 = src,
    };
    return line;
}
/*
 * Builds a LineTy_Mov16_Reg_Imm line: dst is the register `dst_reg`, op2 is
 * the immediate `op2_imm`.
 */
Line s_mov16_r_i(Reg dst_reg, uint16_t op2_imm)
{
    const Ex dst = { .reg = (uint16_t)dst_reg };
    const Ex src = { .imm = op2_imm };
    const Line line = {
        .ty = LineTy_Mov16_Reg_Imm,
        .dst = dst,
        .op2 = src,
    };
    return line;
}
/*
 * Builds a LineTy_Mov16_Reg_MemReg line: dst is the register `dst_reg`, op2
 * is the register `op2_reg`, and `op2_offset` is stored in the line's
 * offset field.
 */
Line s_mov16_r_mr(Reg dst_reg, Reg op2_reg, uint16_t op2_offset)
{
    const Ex dst = { .reg = (uint16_t)dst_reg };
    const Ex base = { .reg = (uint16_t)op2_reg };
    const Line line = {
        .ty = LineTy_Mov16_Reg_MemReg,
        .dst = dst,
        .op2 = base,
        .offset = op2_offset,
    };
    return line;
}
/*
 * Builds a LineTy_Mov16_Reg_MemImm line: dst is the register `dst_reg`, op2
 * is the immediate `op2_imm`.
 */
Line s_mov16_r_mi(Reg dst_reg, uint16_t op2_imm)
{
    const Ex dst = { .reg = (uint16_t)dst_reg };
    const Ex addr = { .imm = op2_imm };
    const Line line = {
        .ty = LineTy_Mov16_Reg_MemImm,
        .dst = dst,
        .op2 = addr,
    };
    return line;
}
/*
 * Builds a LineTy_Mov16_Reg_MemLabel line: dst is the register `dst_reg`,
 * op2 refers to the label `op2_label`.
 */
Line s_mov16_r_ml(Reg dst_reg, int op2_label)
{
    const Ex dst = { .reg = (uint16_t)dst_reg };
    const Ex addr = { .label = op2_label };
    const Line line = {
        .ty = LineTy_Mov16_Reg_MemLabel,
        .dst = dst,
        .op2 = addr,
    };
    return line;
}
/*
 * Builds a LineTy_Mov16_MemReg_Reg line: dst is the register `dst_reg`, op2
 * is the register `op2_reg`, and `dst_offset` is stored in the line's
 * offset field.
 */
Line s_mov16_mr_r(Reg dst_reg, uint16_t dst_offset, Reg op2_reg)
{
    const Ex base = { .reg = (uint16_t)dst_reg };
    const Ex src = { .reg = (uint16_t)op2_reg };
    const Line line = {
        .ty = LineTy_Mov16_MemReg_Reg,
        .dst = base,
        .op2 = src,
        .offset = dst_offset,
    };
    return line;
}
/*
 * Builds a LineTy_Mov16_MemLabel_Reg line: dst refers to the label
 * `dst_label`, op2 is the register `op2_reg`.
 */
Line s_mov16_ml_r(int dst_label, Reg op2_reg)
{
    const Ex addr = { .label = dst_label };
    const Ex src = { .reg = (uint16_t)op2_reg };
    const Line line = {
        .ty = LineTy_Mov16_MemLabel_Reg,
        .dst = addr,
        .op2 = src,
    };
    return line;
}
/*
 * Builds a LineTy_In_Imm line: dst is the register `dst_reg`, op1 is the
 * immediate `op1_imm`.
 */
Line s_in_i(Reg dst_reg, uint16_t op1_imm)
{
    const Ex dst = { .reg = (uint16_t)dst_reg };
    const Ex arg = { .imm = op1_imm };
    const Line line = {
        .ty = LineTy_In_Imm,
        .dst = dst,
        .op1 = arg,
    };
    return line;
}
Line s_lit_i(uint16_t op1_imm)
{
    return (Line) {
        .ty = LineTy_Lit_Imm,
        .op1 = (Ex) { .imm = op1_imm },
    };
}
Line s_lit_l(int op1_label)
{
    return (Line) {
        .ty = LineTy_Lit_Label,
        .op1 = (Ex) { .label = op1_label },
    };
}
Line s_iret(void)
{
    return (Line) { .ty = LineTy_IRet };
}

/*
 * Generates s_<FN>_i(dst_reg, op1_reg, op2_imm): a constructor for a line
 * of type LineTy_<LINETY>_Imm with a register destination, a register
 * first operand and an immediate second operand.
 */
#define DEFINE_BINARY_I(FN, LINETY)                                            \
    Line s_##FN##_i(Reg dst_reg, Reg op1_reg, uint16_t op2_imm)                \
    {                                                                          \
        const Ex dst = { .reg = (uint16_t)dst_reg };                           \
        const Ex lhs = { .reg = (uint16_t)op1_reg };                           \
        const Ex rhs = { .imm = op2_imm };                                     \
        const Line line = {                                                    \
            .ty = LineTy_##LINETY##_Imm,                                       \
            .dst = dst,                                                        \
            .op1 = lhs,                                                        \
            .op2 = rhs,                                                        \
        };                                                                     \
        return line;                                                           \
    }

/* Defines s_or_i, s_and_i, s_add_i and s_sub_i. */
DEFINE_BINARY_I(or, Or)
DEFINE_BINARY_I(and, And)
DEFINE_BINARY_I(add, Add)
DEFINE_BINARY_I(sub, Sub)

/*
 * Forward declarations of the instruction-word encoding helpers, which are
 * defined after assemble_to_binary(). Each `add_*` / `set_*` helper ORs
 * bits into the instruction word pointed to by `ins`.
 */

/* ORs the low 3 bits of `reg` into bits 13-15 of `*ins`. */
static inline void add_dst_reg(uint32_t* ins, uint16_t reg);
/* ORs the low 3 bits of `reg` into bits 10-12 of `*ins`. */
static inline void add_op1_reg(uint32_t* ins, uint16_t reg);
/* ORs the low 3 bits of `reg` into bits 7-9 of `*ins`. */
static inline void add_op2_reg(uint32_t* ins, uint16_t reg);
/* Sets bit 6 of `*ins`. */
static inline void set_is_imm(uint32_t* ins);
/* Sets bit 10 of `*ins`. */
static inline void set_mov_is_memory(uint32_t* ins);
/* Sets bit 11 of `*ins`. */
static inline void set_mov_addr_is_reg(uint32_t* ins);
/* Sets bit 12 of `*ins`. */
static inline void set_mov_is_store(uint32_t* ins);
/*
 * Maps LineTy_{Or,And,Add,Sub}_Imm to Op_{Or,And,Add,Sub}; prints an error
 * and exits with status 1 for any other line type.
 */
static inline uint16_t linety_arithm_ins(LineTy ty);

void assemble_to_binary(uint16_t* out, const Line* lines, size_t lines_size)
{
    uint16_t ip = 0;

    typedef struct {
        int label;
        uint16_t ptr;
    } UnresolvedLabel;
    typedef struct {
        int label;
        uint16_t ip;
    } ResolvedLabel;

    UnresolvedLabel* unres_labels = malloc(sizeof(UnresolvedLabel) * 64);
    size_t unres_labels_size = 0;

    ResolvedLabel* res_labels = malloc(sizeof(ResolvedLabel) * 64);
    size_t res_labels_size = 0;

#define ADD_LABEL(LABEL)                                                       \
    {                                                                          \
        unres_labels[unres_labels_size++] = (UnresolvedLabel) { LABEL, ip };   \
        out[ip++] = 0;                                                         \
    }

    printf("assembling...\n");
    printf("ip op    n data...\n");
    for (size_t i = 0; i < lines_size; ++i) {
        bool is_label = false;
        bool is_data = false;

        const Line* line = &lines[i];
        uint16_t ins_ip = ip;
        switch (line->ty) {
            case LineTy_Label: {
                res_labels[res_labels_size++]
                    = (ResolvedLabel) { line->op1.label, ip * 2 };

                is_label = true;
                break;
            }
            case LineTy_DataImm: {
                out[ip++] = line->op1.imm;

                is_data = true;
                break;
            }
            case LineTy_DataLabel: {
                ADD_LABEL(line->op1.label);

                is_data = true;
                break;
            }
            case LineTy_Nop: {
                out[ip++] = Op_Nop;
                break;
            }
            case LineTy_Hlt: {
                out[ip++] = Op_Hlt;
                break;
            }
            case LineTy_Jmp_Label: {
                int op1 = line->op1.label;

                uint32_t ins = Op_Jmp;
                set_is_imm(&ins);
                out[ip++] = (uint16_t)ins;
                ADD_LABEL(op1);
                break;
            }
            case LineTy_Jnz_Label: {
                uint16_t op1 = line->op1.reg;
                int op2 = line->op2.label;

                uint32_t ins = Op_Jnz;
                set_is_imm(&ins);
                add_op1_reg(&ins, op1);

                out[ip++] = (uint16_t)ins;
                ADD_LABEL(op2);
                break;
            }
            case LineTy_Cmp_Imm: {
                uint16_t op1 = line->op1.reg;
                uint16_t op2 = line->op2.imm;

                uint32_t ins = Op_Cmp;
                set_is_imm(&ins);
                add_op1_reg(&ins, op1);

                out[ip++] = (uint16_t)ins;
                out[ip++] = op2;
                break;
                break;
            }
            case LineTy_Mov8_MemReg_Reg: {
                uint16_t dst = line->dst.reg;
                uint16_t op2 = line->op2.reg;

                uint32_t ins = Op_Mov8;
                add_op2_reg(&ins, op2);
                set_mov_is_memory(&ins);
                set_mov_addr_is_reg(&ins);
                set_mov_is_store(&ins);
                add_dst_reg(&ins, dst);

                out[ip++] = (uint16_t)ins;
                out[ip++] = line->offset;
                break;
            }
            case LineTy_Mov8_MemImm_Imm: {
                uint16_t dst = line->dst.imm;
                uint16_t op2 = line->op2.imm;

                uint32_t ins = Op_Mov8;
                set_is_imm(&ins);
                set_mov_is_memory(&ins);
                set_mov_is_store(&ins);

                out[ip++] = (uint16_t)ins;
                out[ip++] = dst;
                out[ip++] = op2;
                break;
            }
            case LineTy_Mov8_MemImm_Reg: {
                uint16_t dst = line->dst.imm;
                uint16_t op2 = line->op2.reg;

                uint32_t ins = Op_Mov8;
                add_op2_reg(&ins, op2);
                set_mov_is_memory(&ins);
                set_mov_is_store(&ins);

                out[ip++] = (uint16_t)ins;
                out[ip++] = dst;
                break;
            }
            case LineTy_Mov16_Reg_Reg: {
                uint16_t dst = line->dst.reg;
                uint16_t op2 = line->op2.reg;

                uint32_t ins = Op_Mov16;
                ins |= (op2 & 0xfu) << 7;
                ins |= (dst & 0xfu) << 12;

                out[ip++] = (uint16_t)ins;
                break;
            }
            case LineTy_Mov16_Reg_Imm: {
                uint16_t dst = line->dst.reg;
                uint16_t op2 = line->op2.imm;

                uint32_t ins = Op_Mov16;
                set_is_imm(&ins);
                ins |= (dst & 0xfu) << 12;

                out[ip++] = (uint16_t)ins;
                out[ip++] = op2;
                break;
            }
            case LineTy_Mov16_Reg_MemReg: {
                uint16_t dst = line->dst.reg;
                uint16_t op2 = line->op2.reg;

                uint32_t ins = Op_Mov16;
                add_op2_reg(&ins, op2);
                set_mov_is_memory(&ins);
                set_mov_addr_is_reg(&ins);
                add_dst_reg(&ins, dst);

                out[ip++] = (uint16_t)ins;
                out[ip++] = line->offset;
                break;
            }
            case LineTy_Mov16_Reg_MemImm: {
                uint16_t dst = line->dst.reg;
                uint16_t op2 = line->op2.imm;

                uint32_t ins = Op_Mov16;
                set_is_imm(&ins);
                set_mov_is_memory(&ins);
                add_dst_reg(&ins, dst);

                out[ip++] = (uint16_t)ins;
                out[ip++] = op2;
                break;
            }
            case LineTy_Mov16_Reg_MemLabel: {
                uint16_t dst = line->dst.reg;

                uint32_t ins = Op_Mov16;
                set_is_imm(&ins);
                set_mov_is_memory(&ins);
                add_dst_reg(&ins, dst);

                out[ip++] = (uint16_t)ins;
                ADD_LABEL(line->op2.label);
                break;
            }
            case LineTy_Mov16_MemReg_Reg: {
                uint16_t dst = line->dst.reg;
                uint16_t op2 = line->op2.reg;

                uint32_t ins = Op_Mov16;
                add_op2_reg(&ins, op2);
                set_mov_is_memory(&ins);
                set_mov_addr_is_reg(&ins);
                set_mov_is_store(&ins);
                add_dst_reg(&ins, dst);

                out[ip++] = (uint16_t)ins;
                out[ip++] = line->offset;
                break;
            }
            case LineTy_Mov16_MemLabel_Reg: {
                uint16_t op2 = line->op2.reg;

                uint32_t ins = Op_Mov16;
                add_op2_reg(&ins, op2);
                set_mov_is_memory(&ins);
                set_mov_is_store(&ins);

                out[ip++] = (uint16_t)ins;
                ADD_LABEL(line->dst.label);
                break;
            }
            case LineTy_In_Imm: {
                uint16_t dst = line->dst.reg;
                uint16_t op1 = line->op1.imm;

                uint32_t ins = Op_In;
                set_is_imm(&ins);
                add_dst_reg(&ins, dst);

                out[ip++] = (uint16_t)ins;
                out[ip++] = op1;
                break;
            }
            case LineTy_Lit_Imm: {
                uint16_t op1 = line->op1.imm;

                uint32_t ins = Op_Lit;
                set_is_imm(&ins);

                out[ip++] = (uint16_t)ins;
                out[ip++] = op1;
                break;
            }
            case LineTy_Lit_Label: {
                int op1 = line->op1.label;

                uint32_t ins = Op_Lit;
                set_is_imm(&ins);

                out[ip++] = (uint16_t)ins;
                ADD_LABEL(op1);
                break;
            }
            case LineTy_IRet: {
                out[ip++] = Op_IRet;
                break;
            }
            case LineTy_Or_Imm:
            case LineTy_And_Imm:
            case LineTy_Add_Imm:
            case LineTy_Sub_Imm: {
                uint16_t dst = line->dst.reg;
                uint16_t op1 = line->op1.reg;
                uint16_t op2 = line->op2.imm;

                uint32_t ins = linety_arithm_ins(line->ty);
                set_is_imm(&ins);
                add_op1_reg(&ins, op1);
                add_dst_reg(&ins, dst);

                out[ip++] = (uint16_t)ins;
                out[ip++] = op2;
                break;
            }
        }

        if (!is_label) {
            printf("%02x %-5s %d",
                ins_ip * 2,
                is_data ? "data" : op_str(out[ins_ip] & 0x3f),
                ip - ins_ip);
            for (uint16_t i = 0; i < ip - ins_ip; ++i) {
                printf(" %02x %c%c%c%c %c%c%c%c  %02x %c%c%c%c %c%c%c%c ",
                    out[ins_ip + i] & 0xff,
                    fmt_binary(out[ins_ip + i] & 0xff),
                    out[ins_ip + i] >> 8,
                    fmt_binary(out[ins_ip + i] >> 8));
            }
            printf("\n");
        }
    }

    printf("resolving...\n");
    printf(" l ip  v  data\n");
    for (size_t i = 0; i < unres_labels_size; ++i) {
        bool found = false;
        for (size_t j = 0; j < res_labels_size; ++j) {
            if (res_labels[j].label == unres_labels[i].label) {
                out[unres_labels[i].ptr] = res_labels[j].ip;
                found = true;

                printf("%2d %02x %02x  %02x %c%c%c%c %c%c%c%c  %02x %c%c%c%c "
                       "%c%c%c%c\n",
                    res_labels[j].label,
                    unres_labels[i].ptr * 2,
                    res_labels[j].ip,
                    out[unres_labels[i].ptr] & 0xff,
                    fmt_binary(out[unres_labels[i].ptr] & 0xff),
                    out[unres_labels[i].ptr] >> 8,
                    fmt_binary(out[unres_labels[i].ptr] >> 8));
                break;
            }
        }
        if (!found) {
            fprintf(stderr,
                "warning: label '%d' could not be resolved\n",
                unres_labels[i].label);
        }
    }
    printf("done!\n");
}

/* ORs the low 3 bits of `reg` into bits 13-15 of the word at `ins`. */
static inline void add_dst_reg(uint32_t* ins, uint16_t reg)
{
    const uint32_t field = (uint32_t)(reg & 0x7u);
    *ins = *ins | (field << 13);
}
/* ORs the low 3 bits of `reg` into bits 10-12 of the word at `ins`. */
static inline void add_op1_reg(uint32_t* ins, uint16_t reg)
{
    const uint32_t field = (uint32_t)(reg & 0x7u);
    *ins = *ins | (field << 10);
}
/* ORs the low 3 bits of `reg` into bits 7-9 of the word at `ins`. */
static inline void add_op2_reg(uint32_t* ins, uint16_t reg)
{
    const uint32_t field = (uint32_t)(reg & 0x7u);
    *ins = *ins | (field << 7);
}
/* Sets bit 6 of the word at `ins`. */
static inline void set_is_imm(uint32_t* ins)
{
    const uint32_t flag = UINT32_C(1) << 6;
    *ins = *ins | flag;
}
/* Sets bit 10 of the word at `ins`. */
static inline void set_mov_is_memory(uint32_t* ins)
{
    const uint32_t flag = UINT32_C(1) << 10;
    *ins = *ins | flag;
}
/* Sets bit 11 of the word at `ins`. */
static inline void set_mov_addr_is_reg(uint32_t* ins)
{
    const uint32_t flag = UINT32_C(1) << 11;
    *ins = *ins | flag;
}
/* Sets bit 12 of the word at `ins`. */
static inline void set_mov_is_store(uint32_t* ins)
{
    const uint32_t flag = UINT32_C(1) << 12;
    *ins = *ins | flag;
}

/*
 * Returns the opcode for an arithmetic/logic line type:
 * LineTy_{Or,And,Add,Sub}_Imm map to Op_{Or,And,Add,Sub}.
 * Any other line type prints an error to stderr and exits with status 1.
 */
static inline uint16_t linety_arithm_ins(LineTy ty)
{
    if (ty == LineTy_Or_Imm) {
        return Op_Or;
    }
    if (ty == LineTy_And_Imm) {
        return Op_And;
    }
    if (ty == LineTy_Add_Imm) {
        return Op_Add;
    }
    if (ty == LineTy_Sub_Imm) {
        return Op_Sub;
    }

    fprintf(stderr, "error: line type '%d' not handled\n", ty);
    exit(1);
}