#include "lex.h"
#include "report.h"
#include "str.h"
#include <string.h>

/* Initializes `lexer` to scan the NUL-terminated string `text` from its first
 * character. `filename` is kept only for error reports.
 *
 * Neither string is copied: the lexer keeps the pointers, so both must stay
 * valid for as long as the lexer is used. Scanning starts at offset 0, line 1,
 * column 1. Any Lexer field not named below is zeroed by the compound literal. */
void lexer_construct(Lexer* lexer, const char* filename, const char* text)
{
    const size_t length = strlen(text);

    *lexer = (Lexer) {
        .text = text,
        .text_len = length,
        .ch = *text,
        .filename = filename,
        .idx = 0,
        .line = 1,
        .col = 1,
        .error_occured = false,
    };
}

/* Reports whether every character of the input has been consumed. */
static inline bool lexer_done(const Lexer* lexer)
{
    const bool has_more = lexer->idx < lexer->text_len;
    return !has_more;
}

/* Advances the lexer by one character and keeps line/column in sync.
 * Does nothing once the input is exhausted.
 *
 * Position bookkeeping is driven by the character being stepped over:
 * leaving a '\n' moves to column 1 of the next line; anything else
 * advances the column. */
static inline void lexer_step(Lexer* lexer)
{
    if (!lexer_done(lexer)) {
        const bool leaving_newline = lexer->ch == '\n';
        if (leaving_newline) {
            lexer->col = 1;
            lexer->line++;
        } else {
            lexer->col++;
        }
        lexer->idx++;
        // When idx reaches text_len this reads the string's NUL terminator.
        lexer->ch = lexer->text[lexer->idx];
    }
}

static inline Loc lexer_loc(const Lexer* lexer)
{
    return (Loc) { .idx = lexer->idx, .line = lexer->line, .col = lexer->col };
}

/* Builds a token of type `ty` that starts at `loc` and extends up to, but
 * not including, the lexer's current offset. */
static inline Tok lexer_tok(const Lexer* lexer, TokTy ty, Loc loc)
{
    Tok tok = {
        .loc = loc,
        .ty = ty,
        .len = lexer->idx - loc.idx,
    };
    return tok;
}

/* Consumes one character of a character or string literal. A backslash also
 * consumes the character that follows it, whatever that character is.
 *
 * Returns 0 on success. Returns -1 when a backslash is the last character of
 * the input, so there is nothing left to escape. */
static inline int lexer_skip_literal_char(Lexer* lexer)
{
    const bool is_escape = lexer->ch == '\\';
    lexer_step(lexer);

    if (!is_escape) {
        return 0;
    }
    if (lexer_done(lexer)) {
        return -1;
    }
    lexer_step(lexer);
    return 0;
}

/* Reports a lexical error: marks the lexer as having failed, prints `msg` as
 * an error, then prints the source location `loc` within the lexer's text. */
static inline void lexer_report(Lexer* lexer, const char* msg, Loc loc)
{
    // Flag the failure first so it is recorded whatever the reporting does.
    lexer->error_occured = true;

    // `msg` goes through "%s" so it is never interpreted as a format string.
    REPORTF_ERROR("%s", msg);
    print_report_loc(
        lexer->filename, lexer->text, lexer->text_len, loc);
}

Tok lexer_next(Lexer* lexer)
{
    const char* ident_chars = "abcdefghijklmnopqrstuvwxyz"
                              "ABCDEFGHIJKLMNOPQRSTUVWXYZ_$";
    const char* int_chars = "1234567890";
    const char* hex_chars = "01234567889abcdefABCDEF";

    Loc loc = lexer_loc(lexer);
    if (lexer_done(lexer)) {
        return lexer_tok(lexer, TT_Eof, loc);
    }
    if (lexer->ch == '\n') {
        lexer_step(lexer);
        return lexer_tok(lexer, '\n', loc);
    } else if (str_includes(" \t", lexer->ch)) {
        while (!lexer_done(lexer) && str_includes(" \t", lexer->ch)) {
            lexer_step(lexer);
        }
        return lexer_next(lexer);
    } else if (str_includes(ident_chars, lexer->ch)) {
        while (!lexer_done(lexer)
            && (str_includes(ident_chars, lexer->ch)
                || str_includes(int_chars, lexer->ch))) {
            lexer_step(lexer);
        }
        return lexer_tok(lexer, TT_Ident, loc);
    } else if (str_includes(int_chars, lexer->ch) && lexer->ch != '0') {
        while (!lexer_done(lexer) && (str_includes(int_chars, lexer->ch))) {
            lexer_step(lexer);
        }
        return lexer_tok(lexer, TT_Int, loc);
    } else if (lexer->ch == ';') {
        while (!lexer_done(lexer) && lexer->ch != '\n') {
            lexer_step(lexer);
        }
        return lexer_next(lexer);
    } else if (lexer->ch == '0') {
        lexer_step(lexer);
        if (lexer->ch == 'b') {
            lexer_step(lexer);
            if (lexer_done(lexer) || !str_includes("01", lexer->ch)) {
                lexer_report(lexer, "malformed binary literal", loc);
                return lexer_tok(lexer, TT_Err, loc);
            }
            while (!lexer_done(lexer) && str_includes("01", lexer->ch)) {
                lexer_step(lexer);
            }
            return lexer_tok(lexer, TT_Binary, loc);
        } else if (lexer->ch == 'x') {
            lexer_step(lexer);
            if (lexer_done(lexer) || !str_includes(hex_chars, lexer->ch)) {
                lexer_report(lexer, "malformed hex literal", loc);
                return lexer_tok(lexer, TT_Err, loc);
            }
            while (!lexer_done(lexer) && str_includes(hex_chars, lexer->ch)) {
                lexer_step(lexer);
            }
            return lexer_tok(lexer, TT_Hex, loc);

        } else {
            return lexer_tok(lexer, TT_Int, loc);
        }
    } else if (lexer->ch == '\'') {
        lexer_step(lexer);
        lexer_skip_literal_char(lexer);
        if (lexer_done(lexer) || lexer->ch != '\'') {
            lexer_report(lexer, "malformed character literal", loc);
            return lexer_tok(lexer, TT_Err, loc);
        }
        lexer_step(lexer);
        return lexer_tok(lexer, TT_Char, loc);
    } else if (lexer->ch == '"') {
        lexer_step(lexer);
        while (!lexer_done(lexer) && lexer->ch != '"') {
            lexer_skip_literal_char(lexer);
        }
        if (lexer_done(lexer) || lexer->ch != '"') {
            lexer_report(lexer, "malformed string literal", loc);
            return lexer_tok(lexer, TT_Err, loc);
        }
        lexer_step(lexer);
        return lexer_tok(lexer, TT_Str, loc);
    } else if (lexer->ch == '<') {
        lexer_step(lexer);
        if (!lexer_done(lexer) && lexer->ch == '<') {
            lexer_step(lexer);
            return lexer_tok(lexer, TT_DoubleLt, loc);
        } else {
            lexer_report(lexer, "expected '<'", loc);
            return lexer_tok(lexer, TT_Err, loc);
        }
    } else if (lexer->ch == '>') {
        lexer_step(lexer);
        if (!lexer_done(lexer) && lexer->ch == '>') {
            lexer_step(lexer);
            return lexer_tok(lexer, TT_DoubleGt, loc);
        } else {
            lexer_report(lexer, "expected '>'", loc);
            return lexer_tok(lexer, TT_Err, loc);
        }
    } else if (str_includes("|^&+-*/%()[].,:!", lexer->ch)) {
        char ch = lexer->ch;
        lexer_step(lexer);
        return lexer_tok(lexer, (TokTy)ch, loc);
    } else {
        lexer_report(lexer, "illegal character", loc);
        lexer_step(lexer);
        return lexer_tok(lexer, TT_Err, loc);
    }
}