#include "vm.h"
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Complete state of one running virtual machine instance. */
typedef struct {
    /* Register file, indexed by Reg values (e.g. Rip, Rsp, Rfl, Rcs). */
    uint16_t regs[10];
    /* Flat guest memory; vm_start allocates seg_count * seg_size bytes. */
    uint8_t* mem;
    /* Drive whose first blocks are copied into memory when the VM starts. */
    Drive* boot_drive;
    /* Number of memory segments. */
    uint16_t seg_count;
    /* Size of one segment in bytes; the code segment register is scaled by
     * this value when fetching instructions. */
    uint16_t seg_size;
    /* Device polled for interrupts and notified of display-buffer writes. */
    IODevice* io_device;
    /* Guest address of the interrupt table (set by the `lit` instruction).
     * The first 16-bit word is the entry count, followed by one 16-bit
     * handler address per entry. */
    uint16_t int_table;
    /* Counted down once per executed instruction; key events are only
     * dispatched while it is <= 0, and it is reset after a dispatch. */
    int interrupt_timeout;
    /* Keycode of the most recently dispatched key event; returned by the
     * `in` instruction for device 0. */
    uint16_t keyboard_port_input;
} VM;

/* Forward declarations of the file-local helpers defined below vm_start. */

/* Execution helpers. */
static inline void run_arithm_ins(VM* vm, uint16_t ins);
static inline int jump_to_interrupt(VM* vm, uint16_t int_id);
static inline void maybe_update_vcd(VM* vm, uint16_t addr);
static inline uint16_t eat_uint16(VM* vm);
static inline uint16_t read_seg_uint16(VM* vm, uint16_t ptr);
static inline Op ins_op(uint16_t ins);

/* Instruction-field decoders: register indices, register values, and
 * "register value or trailing immediate word" operands. */
static inline Reg ins_dst_reg(uint16_t ins);
static inline Reg ins_op1_reg(uint16_t ins);
static inline Reg ins_op2_reg(uint16_t ins);
static inline uint16_t ins_op1(VM* vm, uint16_t ins);
static inline uint16_t ins_op2(VM* vm, uint16_t ins);
static inline uint16_t ins_reg_val_or_imm(
    VM* vm, uint16_t ins, uint16_t is_imm_bit, uint16_t reg_bit, uint16_t mask);
static inline uint16_t ins_op1_or_imm(VM* vm, uint16_t ins);
static inline uint16_t ins_op2_or_imm(VM* vm, uint16_t ins);

/*
 * Boots a virtual machine from `boot_drive` and runs it until it halts.
 *
 * Guest memory (16 segments of 4096 bytes) is allocated here, the first 512
 * bytes of the boot drive are copied to guest address 0, and then the
 * fetch/decode/execute loop runs. After every instruction the io device is
 * polled for a pending interrupt.
 *
 * The function returns when a shutdown interrupt arrives or when a fatal
 * guest error is detected (Fl_Err is set in the flags register for the
 * latter). Guest memory is released before returning.
 *
 * boot_drive: drive providing `block_size` and `read`; must not be NULL.
 * io_device:  device providing `wait_for_interrupt` and
 *             `maybe_next_interrupt`; must not be NULL.
 */
void vm_start(Drive* boot_drive, IODevice* io_device)
{
    const uint16_t seg_count = 16;
    const uint16_t seg_size = 4096;
    const uint16_t boot_size = 512;
    /* NOTE(review): this looks like a leftover debugging guard that kills
     * any guest whose instruction pointer reaches this value - confirm
     * whether it should stay. */
    const uint16_t rip_kill_limit = 300;
    /* Instructions to execute before another key event may be dispatched. */
    const int interrupt_cooldown = 10;

    /* Allocate guest memory up front so that a failed allocation is
     * reported instead of being dereferenced by the boot loader below. */
    uint8_t* mem = calloc((size_t)seg_count * seg_size, sizeof(uint8_t));
    if (mem == NULL) {
        fprintf(stderr, "error: could not allocate vm memory\n");
        return;
    }

    VM vm_inst = {
        .regs = { 0 },
        .mem = mem,
        .boot_drive = boot_drive,
        .seg_count = seg_count,
        .seg_size = seg_size,
        .io_device = io_device,
        .int_table = 0,
        .interrupt_timeout = interrupt_cooldown,
        .keyboard_port_input = 0,
    };

    VM* vm = &vm_inst;

    uint16_t* rip = &vm->regs[Rip];
    uint16_t* rsp = &vm->regs[Rsp];
    uint16_t* rfl = &vm->regs[Rfl];
    uint16_t* rcs = &vm->regs[Rcs];

    const size_t reg_count = sizeof vm->regs / sizeof vm->regs[0];

    /* Copy the boot sector, block by block, to the start of guest memory. */
    const uint16_t block_size = vm->boot_drive->block_size;
    if (block_size == 0) {
        /* A zero block size would make the loop below spin forever. */
        fprintf(stderr, "error: boot drive reports a block size of 0\n");
        goto halt_execution;
    }
    for (uint16_t i = 0; i * block_size < boot_size; ++i) {
        vm->boot_drive->read(vm->boot_drive, &vm->mem[i * block_size], i);
    }

    while (true) {
        uint16_t ins = eat_uint16(vm);
        Op op = ins_op(ins);

        /*printf("[%3d] = %3d %s\n", *rip - 2, op, op_str(op));*/

        if (*rip >= rip_kill_limit) {
            /* The message now reports the limit that is actually checked. */
            printf("killed: rip >= %u\n", (unsigned)rip_kill_limit);
            exit(0);
        }

        switch (op) {
            case Op_Nop:
                break;
            case Op_Hlt:
                /* Block until the device has an interrupt, and make sure it
                 * can be dispatched right away. */
                vm->io_device->wait_for_interrupt(vm->io_device);
                vm->interrupt_timeout = 0;
                break;
            case Op_Jmp: {
                bool is_farjump = ins >> 7 & 1;
                if (is_farjump) {
                    /* The segment operand is decoded (and its immediate, if
                     * any, consumed) before the target operand. */
                    uint16_t cs = ins_reg_val_or_imm(vm, ins, 8, 13, 0x7);
                    uint16_t op1 = ins_op1_or_imm(vm, ins);
                    *rcs = cs;
                    *rip = op1;
                } else {
                    uint16_t op1 = ins_op1_or_imm(vm, ins);
                    *rip = op1;
                }
                break;
            }
            case Op_Jnz: {
                uint16_t op1 = ins_op1(vm, ins);
                uint16_t op2 = ins_op2_or_imm(vm, ins);

                if (op1 != 0) {
                    *rip = op2;
                }
                break;
            }
            case Op_Test: {
                uint16_t op1 = ins_op1(vm, ins);

                if (op1 == 0) {
                    *rfl |= 1u << Fl_Zero;
                } else {
                    *rfl &= (uint16_t)~(1u << Fl_Zero);
                }
                break;
            }
            case Op_Cmp: {
                uint16_t op1 = ins_op1(vm, ins);
                uint16_t op2 = ins_op2_or_imm(vm, ins);

                /* Equality. */
                if (op1 == op2) {
                    *rfl |= 1u << Fl_Eq;
                } else {
                    *rfl &= (uint16_t)~(1u << Fl_Eq);
                }
                /* Unsigned less-than. */
                if (op1 < op2) {
                    *rfl |= 1u << Fl_Be;
                } else {
                    *rfl &= (uint16_t)~(1u << Fl_Be);
                }
                /* Signed less-than. */
                if ((int16_t)op1 < (int16_t)op2) {
                    *rfl |= 1u << Fl_Lt;
                } else {
                    *rfl &= (uint16_t)~(1u << Fl_Lt);
                }
                break;
            }
            case Op_Mov8: {
                bool is_memory = ins >> 10 & 1;
                bool addr_is_reg = ins >> 11 & 1;
                bool is_store = ins >> 12 & 1;

                if (!is_memory) {
                    uint16_t src = ins_reg_val_or_imm(vm, ins, 6, 7, 0xf);
                    Reg dst = ins >> 12 & 0xf;
                    /* The 4-bit field can encode indices past the end of
                     * the register file; reject them instead of writing
                     * out of bounds. */
                    if ((size_t)dst >= reg_count) {
                        fprintf(stderr,
                            "error: invalid destination register, halting "
                            "execution.\n");
                        vm->regs[Rfl] |= 1 << Fl_Err;
                        goto halt_execution;
                    }
                    vm->regs[dst] = (uint8_t)src;
                    break;
                }

                uint16_t addr;
                if (addr_is_reg) {
                    /* Stores address through the dst field, loads through
                     * the op2 field; an offset word always follows. */
                    Reg reg = is_store ? ins_dst_reg(ins) : ins_op2_reg(ins);
                    uint16_t offset = eat_uint16(vm);
                    addr = vm->regs[reg] + offset;
                } else {
                    addr = eat_uint16(vm);
                }

                if (is_store) {
                    uint16_t src = ins_op2_or_imm(vm, ins);
                    vm->mem[addr] = (uint8_t)src;

                    maybe_update_vcd(vm, addr);
                } else {
                    Reg reg = ins_dst_reg(ins);
                    vm->regs[reg] = (uint16_t)vm->mem[addr];
                }
                break;
            }
            case Op_Mov16: {
                bool is_memory = ins >> 10 & 1;
                bool addr_is_reg = ins >> 11 & 1;
                bool is_store = ins >> 12 & 1;

                if (!is_memory) {
                    uint16_t src = ins_reg_val_or_imm(vm, ins, 6, 7, 0xf);
                    Reg dst = ins >> 12 & 0xf;
                    /* Same out-of-range guard as for mov8. */
                    if ((size_t)dst >= reg_count) {
                        fprintf(stderr,
                            "error: invalid destination register, halting "
                            "execution.\n");
                        vm->regs[Rfl] |= 1 << Fl_Err;
                        goto halt_execution;
                    }
                    vm->regs[dst] = src;
                    break;
                }

                uint16_t addr;
                if (addr_is_reg) {
                    Reg reg = is_store ? ins_dst_reg(ins) : ins_op2_reg(ins);
                    uint16_t offset = eat_uint16(vm);
                    addr = vm->regs[reg] + offset;
                } else {
                    addr = eat_uint16(vm);
                }

                if (addr % 2 != 0) {
                    fprintf(stderr,
                        "error: invalid address alignment, halting "
                        "execution.\n");
                    vm->regs[Rfl] |= 1 << Fl_Err;
                    goto halt_execution;
                }

                if (is_store) {
                    uint16_t src = ins_op2_or_imm(vm, ins);
                    *(uint16_t*)(&vm->mem[addr]) = src;

                    /* Both bytes of the word may be display cells. */
                    maybe_update_vcd(vm, addr);
                    maybe_update_vcd(vm, addr + 1);
                } else {
                    Reg reg = ins_dst_reg(ins);
                    vm->regs[reg] = *(uint16_t*)(&vm->mem[addr]);
                }
                break;
            }
            case Op_In: {
                Reg dst_reg = ins_dst_reg(ins);
                uint16_t device_id = ins_op1_or_imm(vm, ins);

                switch (device_id) {
                    case 0:
                        vm->regs[dst_reg] = vm->keyboard_port_input;
                        break;
                    default:
                        fprintf(
                            stderr, "warning: no input device %d\n", device_id);
                        break;
                }
                break;
            }
            case Op_Out: {
                /* No output devices exist yet, so the value is decoded but
                 * not used. */
                uint16_t value = ins_op2(vm, ins);
                uint16_t device_id = ins_op1_or_imm(vm, ins);
                (void)value;

                switch (device_id) {
                    default:
                        fprintf(stderr,
                            "warning: no output device %d\n",
                            device_id);
                        break;
                }
                break;
            }
            case Op_Lit: {
                uint16_t op2 = ins_op1_or_imm(vm, ins);
                vm->int_table = op2;
                break;
            }
            case Op_Int: {
                uint8_t int_id = (uint8_t)(ins >> 8 & 0xff);

                int res = jump_to_interrupt(vm, int_id);
                if (res != 0) {
                    goto halt_execution;
                }
                break;
            }
            case Op_IRet: {
                /* Pops in the reverse order of jump_to_interrupt's pushes:
                 * instruction pointer first, then code segment.
                 * NOTE(review): rsp is not checked for alignment or range
                 * here - confirm whether guests can be trusted on this. */
                *rip = *(uint16_t*)&vm->mem[*rsp];
                *rsp -= 2;
                *rcs = *(uint16_t*)&vm->mem[*rsp];
                *rsp -= 2;
                break;
            }

            case Op_Or:
            case Op_Xor:
            case Op_And:
            case Op_Shl:
            case Op_RShl:
            case Op_Shr:
            case Op_RShr:
            case Op_Add:
            case Op_Sub:
            case Op_RSub:
            case Op_Mul:
            case Op_IMul:
            case Op_Div:
            case Op_IDiv:
            case Op_RDiv:
            case Op_RIDiv:
            case Op_Mod:
            case Op_RMod:
                run_arithm_ins(vm, ins);
                break;

            default:
                /* Unknown opcodes are ignored, as before. */
                break;
        }

        Interrupt interrupt
            = vm->io_device->maybe_next_interrupt(vm->io_device);
        switch (interrupt.type) {
            case InterruptType_None:
                break;
            case InterruptType_Shutdown:
                goto halt_execution;
            case InterruptType_KeyEvent: {
                /* Key events are dropped unless the cooldown has elapsed
                 * and the guest has interrupts enabled. */
                if (vm->interrupt_timeout <= 0 && (*rfl >> Fl_Int & 1)) {
                    int res = jump_to_interrupt(vm, 0);
                    if (res != 0) {
                        goto halt_execution;
                    }
                    vm->keyboard_port_input = interrupt.keycode;
                    vm->interrupt_timeout = interrupt_cooldown;
                }
                break;
            }
            default:
                break;
        }
        /* Stop counting at zero: only "<= 0" is ever tested, and an
         * unbounded decrement would eventually overflow the int. */
        if (vm->interrupt_timeout > 0) {
            vm->interrupt_timeout -= 1;
        }
    }

halt_execution:
    free(vm->mem);
    vm->mem = NULL;
    return;
}

/*
 * Executes one arithmetic/logic instruction.
 *
 * Operand 1 is always a register, operand 2 is a register or a trailing
 * immediate word, and the 16-bit result is written to the destination
 * register. The "R" variants (rshl, rshr, rsub, rdiv, ridiv, rmod) swap
 * the two operands.
 *
 * Host undefined behaviour is avoided as follows:
 *  - Division or modulo by zero prints a warning and stores 0.
 *  - Shift counts of 16 or more yield 0.
 *  - Multiplication and left shifts are carried out in 32-bit types so the
 *    intermediate result cannot overflow before truncation to 16 bits.
 *
 * Opcodes that are not arithmetic instructions are ignored.
 */
static inline void run_arithm_ins(VM* vm, uint16_t ins)
{
    typedef uint16_t u;
    typedef int16_t s;
    typedef uint32_t wide_u;
    typedef int32_t wide_s;

    Op op = ins_op(ins);

    /* Decode order matters: op2 may consume an immediate word. */
    uint16_t op1 = ins_op1(vm, ins);
    uint16_t op2 = ins_op2_or_imm(vm, ins);
    Reg dst_reg = ins_dst_reg(ins);

    uint16_t* dst = &vm->regs[dst_reg];

    /* Catch a zero divisor before the host CPU does. */
    const bool divides_by_op2
        = op == Op_Div || op == Op_IDiv || op == Op_Mod;
    const bool divides_by_op1
        = op == Op_RDiv || op == Op_RIDiv || op == Op_RMod;
    if ((divides_by_op2 && op2 == 0) || (divides_by_op1 && op1 == 0)) {
        fprintf(stderr, "warning: division by zero\n");
        *dst = 0;
        return;
    }

    switch (op) {
        case Op_Or:
            *dst = op1 | op2;
            break;
        case Op_Xor:
            *dst = op1 ^ op2;
            break;
        case Op_And:
            *dst = op1 & op2;
            break;
        case Op_Shl:
            *dst = op2 < 16 ? (u)((wide_u)op1 << op2) : 0;
            break;
        case Op_RShl:
            *dst = op1 < 16 ? (u)((wide_u)op2 << op1) : 0;
            break;
        case Op_Shr:
            *dst = op2 < 16 ? (u)(op1 >> op2) : 0;
            break;
        case Op_RShr:
            *dst = op1 < 16 ? (u)(op2 >> op1) : 0;
            break;
        case Op_Add:
            *dst = (u)(op1 + op2);
            break;
        case Op_Sub:
            *dst = (u)(op1 - op2);
            break;
        case Op_RSub:
            *dst = (u)(op2 - op1);
            break;
        case Op_Mul:
            /* uint16_t operands would promote to signed int and overflow
             * for large values, so multiply as unsigned 32-bit. */
            *dst = (u)((wide_u)op1 * (wide_u)op2);
            break;
        case Op_IMul:
            *dst = (u)((wide_s)(s)op1 * (wide_s)(s)op2);
            break;
        case Op_Div:
            *dst = (u)(op1 / op2);
            break;
        case Op_IDiv:
            /* Done in 32 bits so that INT16_MIN / -1 is representable. */
            *dst = (u)((wide_s)(s)op1 / (wide_s)(s)op2);
            break;
        case Op_RDiv:
            *dst = (u)(op2 / op1);
            break;
        case Op_RIDiv:
            *dst = (u)((wide_s)(s)op2 / (wide_s)(s)op1);
            break;
        case Op_Mod:
            *dst = (u)(op1 % op2);
            break;
        case Op_RMod:
            *dst = (u)(op2 % op1);
            break;
        default:
            break;
    }
}

/*
 * Transfers control to interrupt handler `int_id`.
 *
 * The interrupt table lives at guest address vm->int_table: a 16-bit entry
 * count followed by one 16-bit handler address per entry. On success the
 * current code segment and instruction pointer are pushed (in that order,
 * stack pointer pre-incremented by 2 for each), the code segment is set to
 * 0 and the instruction pointer to the handler address.
 *
 * Every guest address involved is validated against the size of guest
 * memory before it is used, and no VM state is modified when validation
 * fails.
 *
 * Returns 0 on success. Returns -1, after printing an error and setting
 * Fl_Err in the flags register, when interrupts are disabled, when `int_id`
 * is outside the table, or when the table or the stack slots fall outside
 * guest memory.
 */
static inline int jump_to_interrupt(VM* vm, uint16_t int_id)
{
    uint16_t* rip = &vm->regs[Rip];
    uint16_t* rsp = &vm->regs[Rsp];
    uint16_t* rfl = &vm->regs[Rfl];
    uint16_t* rcs = &vm->regs[Rcs];

    if ((*rfl >> Fl_Int & 1) == 0) {
        fprintf(stderr, "error: interrupt with unset flag\n");
        vm->regs[Rfl] |= 1 << Fl_Err;
        return -1;
    }

    const uint32_t mem_size = (uint32_t)vm->seg_count * vm->seg_size;
    const uint32_t word = sizeof(uint16_t);

    /* The entry-count word itself must lie inside guest memory. */
    const uint32_t table_addr = vm->int_table;
    if (table_addr + word > mem_size) {
        fprintf(stderr, "error: interrupt table outside memory\n");
        vm->regs[Rfl] |= 1 << Fl_Err;
        return -1;
    }

    /* memcpy instead of a pointer cast: the table address is guest
     * controlled and may be odd. */
    uint16_t int_table_size;
    memcpy(&int_table_size, &vm->mem[table_addr], sizeof int_table_size);

    if (int_id >= int_table_size) {
        fprintf(stderr, "error: interrupt outside table (%d)\n", int_id);
        vm->regs[Rfl] |= 1 << Fl_Err;
        return -1;
    }

    const uint32_t entry_addr = table_addr + (uint32_t)int_id * word + word;
    if (entry_addr + word > mem_size) {
        fprintf(stderr, "error: interrupt table outside memory\n");
        vm->regs[Rfl] |= 1 << Fl_Err;
        return -1;
    }

    uint16_t int_addr;
    memcpy(&int_addr, &vm->mem[entry_addr], sizeof int_addr);

    /* The stack pointer wraps at 16 bits, exactly like "*rsp += 2". */
    const uint16_t cs_slot = (uint16_t)(*rsp + 2);
    const uint16_t ip_slot = (uint16_t)(*rsp + 4);
    if ((uint32_t)cs_slot + word > mem_size
        || (uint32_t)ip_slot + word > mem_size) {
        fprintf(stderr, "error: interrupt stack outside memory\n");
        vm->regs[Rfl] |= 1 << Fl_Err;
        return -1;
    }

    memcpy(&vm->mem[cs_slot], rcs, sizeof *rcs);
    memcpy(&vm->mem[ip_slot], rip, sizeof *rip);
    *rsp = ip_slot;

    *rcs = 0;
    *rip = int_addr;

    return 0;
}

/*
 * Forwards a memory write to the io device when it touched the display
 * buffer.
 *
 * Nothing happens unless an io device is attached, the Fl_Vcd flag is set
 * in the flags register, and `addr` lies inside
 * [VCD_BUFFER_OFFSET, VCD_BUFFER_OFFSET + VCD_BUFFER_SIZE). Otherwise the
 * byte now stored at `addr` is passed to the device's set_char together
 * with its offset into the buffer.
 */
static inline void maybe_update_vcd(VM* vm, uint16_t addr)
{
    const bool vcd_enabled
        = vm->io_device && ((vm->regs[Rfl] >> Fl_Vcd) & 1) != 0;
    if (!vcd_enabled) {
        return;
    }

    const bool inside_buffer = addr >= VCD_BUFFER_OFFSET
        && addr < VCD_BUFFER_OFFSET + VCD_BUFFER_SIZE;
    if (inside_buffer) {
        uint16_t cell = addr - VCD_BUFFER_OFFSET;
        vm->io_device->set_char(vm->io_device, cell, vm->mem[addr]);
    }
}

/*
 * Fetches the 16-bit word at the instruction pointer (relative to the
 * current code segment) and advances the instruction pointer by two bytes.
 */
static inline uint16_t eat_uint16(VM* vm)
{
    const uint16_t fetched = read_seg_uint16(vm, vm->regs[Rip]);
    vm->regs[Rip] += 2;
    return fetched;
}

/*
 * Reads the 16-bit word at offset `ptr` inside the current code segment,
 * i.e. at guest address Rcs * seg_size + ptr, in host byte order.
 *
 * The code segment register is guest controlled, so the computed address
 * can exceed the size of guest memory; it is wrapped around (modulo the
 * memory size) rather than read out of bounds. The two bytes are copied
 * with memcpy so that odd addresses do not cause a misaligned load.
 *
 * Returns 0 if the VM has no memory configured.
 */
static inline uint16_t read_seg_uint16(VM* vm, uint16_t ptr)
{
    const uint32_t mem_size = (uint32_t)vm->seg_count * vm->seg_size;
    if (mem_size == 0) {
        return 0;
    }

    const uint32_t addr = (uint32_t)vm->regs[Rcs] * vm->seg_size + ptr;
    const uint8_t bytes[2] = {
        vm->mem[addr % mem_size],
        vm->mem[(addr + 1) % mem_size],
    };

    uint16_t value;
    memcpy(&value, bytes, sizeof value);
    return value;
}

/* Extracts the opcode, stored in the low six bits of an instruction word. */
static inline Op ins_op(uint16_t ins)
{
    const uint16_t opcode_mask = 0x3F;
    return ins & opcode_mask;
}

/* Extracts the destination register index from bits 13..15. */
static inline Reg ins_dst_reg(uint16_t ins)
{
    const unsigned field_shift = 13;
    const unsigned field_mask = 0x7;
    return (ins >> field_shift) & field_mask;
}

/* Extracts the first operand's register index from bits 10..12. */
static inline Reg ins_op1_reg(uint16_t ins)
{
    const unsigned field_shift = 10;
    const unsigned field_mask = 0x7;
    return (ins >> field_shift) & field_mask;
}

/* Extracts the second operand's register index from bits 7..9. */
static inline Reg ins_op2_reg(uint16_t ins)
{
    const unsigned field_shift = 7;
    const unsigned field_mask = 0x7;
    return (ins >> field_shift) & field_mask;
}

/* Returns the value of the register named by the first operand field. */
static inline uint16_t ins_op1(VM* vm, uint16_t ins)
{
    const Reg source = ins_op1_reg(ins);
    return vm->regs[source];
}

/* Returns the value of the register named by the second operand field. */
static inline uint16_t ins_op2(VM* vm, uint16_t ins)
{
    const Reg source = ins_op2_reg(ins);
    return vm->regs[source];
}

/*
 * Decodes an operand that is either an immediate or a register value.
 *
 * If bit `is_imm_bit` of `ins` is set, the next word of the instruction
 * stream is consumed and returned. Otherwise the register whose index is
 * the field at bit `reg_bit`, masked with `mask`, is read.
 */
static inline uint16_t ins_reg_val_or_imm(
    VM* vm, uint16_t ins, uint16_t is_imm_bit, uint16_t reg_bit, uint16_t mask)
{
    if (((ins >> is_imm_bit) & 1) != 0) {
        return eat_uint16(vm);
    }
    return vm->regs[(ins >> reg_bit) & mask];
}

/*
 * Decodes the first operand: an immediate word when bit 6 is set, otherwise
 * the register named by the 3-bit field at bit 10.
 */
static inline uint16_t ins_op1_or_imm(VM* vm, uint16_t ins)
{
    const uint16_t imm_flag_bit = 6;
    const uint16_t reg_field_bit = 10;
    const uint16_t reg_field_mask = 0x7;
    return ins_reg_val_or_imm(
        vm, ins, imm_flag_bit, reg_field_bit, reg_field_mask);
}

/*
 * Decodes the second operand: an immediate word when bit 6 is set,
 * otherwise the register named by the 3-bit field at bit 7.
 */
static inline uint16_t ins_op2_or_imm(VM* vm, uint16_t ins)
{
    const uint16_t imm_flag_bit = 6;
    const uint16_t reg_field_bit = 7;
    const uint16_t reg_field_mask = 0x7;
    return ins_reg_val_or_imm(
        vm, ins, imm_flag_bit, reg_field_bit, reg_field_mask);
}

/*
 * Returns the lower-case mnemonic for `op`, or "---" when `op` is not a
 * known opcode. The returned string is a literal and must not be freed.
 */
const char* op_str(Op op)
{
    /* Sparse lookup table; slots without a mnemonic stay NULL. */
    static const char* const mnemonics[] = {
        [Op_Nop] = "nop",
        [Op_Hlt] = "hlt",
        [Op_Jmp] = "jmp",
        [Op_Jnz] = "jnz",
        [Op_Test] = "test",
        [Op_Cmp] = "cmp",
        [Op_Mov8] = "mov8",
        [Op_Mov16] = "mov16",
        [Op_In] = "in",
        [Op_Out] = "out",
        [Op_Lit] = "lit",
        [Op_Int] = "int",
        [Op_IRet] = "iret",
        [Op_Or] = "or",
        [Op_Xor] = "xor",
        [Op_And] = "and",
        [Op_Shl] = "shl",
        [Op_RShl] = "rshl",
        [Op_Shr] = "shr",
        [Op_RShr] = "rshr",
        [Op_Add] = "add",
        [Op_Sub] = "sub",
        [Op_RSub] = "rsub",
        [Op_Mul] = "mul",
        [Op_IMul] = "imul",
        [Op_Div] = "div",
        [Op_IDiv] = "idiv",
        [Op_RDiv] = "rdiv",
        [Op_RIDiv] = "ridiv",
        [Op_Mod] = "mod",
        [Op_RMod] = "rmod",
    };
    const size_t slot_count = sizeof mnemonics / sizeof mnemonics[0];

    /* A negative value converts to a huge size_t and fails the range
     * check, so it ends up at the fallback as well. */
    const size_t slot = (size_t)op;
    if (slot < slot_count && mnemonics[slot] != NULL) {
        return mnemonics[slot];
    }
    return "---";
}