c_with_templates/main.cpp
2026-01-05 16:33:31 +01:00

632 lines
15 KiB
C++

#include <stdio.h>
#include <stdlib.h>
#include <errno.h>
#include <string.h>
#include <ctype.h>
// Outcome of a fallible operation: either a value of type V or an error of
// type E, sharing storage. `ok` tells which union member is meaningful.
template <typename V, typename E>
struct Result {
// true  -> `value` was stored (see result_ok)
// false -> `error` was stored (see result_error)
bool ok;
union {
V value;
E error;
};
};
// Builds a successful Result: `ok` is true and `value` holds a copy of the argument.
// Both template arguments must be spelled out by the caller, since E cannot be deduced.
template <typename V, typename E>
Result<V, E> result_ok(V value)
{
    Result<V, E> outcome = { .ok = true, .value = value };
    return outcome;
}
// Builds a failed Result: `ok` is false and `error` holds a copy of the argument.
// Both template arguments must be spelled out by the caller, since V cannot be deduced.
template <typename V, typename E>
Result<V, E> result_error(E error)
{
    Result<V, E> outcome = { .ok = false, .error = error };
    return outcome;
}
// Pointer to a function that takes a T*. Own<T> stores one of these and
// own_dealloc calls it with the owned pointer.
template <typename T>
using Deallocator = void (*)(T* ptr);
// A pointer paired with the function used to dispose of it.
// Plain aggregate: copying an Own copies the pointer, nothing is released
// automatically; release happens only through own_dealloc.
template <typename T>
struct Own {
T* ptr;
// Invoked as dealloc(ptr) by own_dealloc.
Deallocator<T> dealloc;
};
// Wraps an existing pointer together with its deallocator.
//
// The deallocator is taken by value as a Deallocator<T> (a function pointer).
// The previous signature took `Deallocator<T>*`, a pointer to a function
// pointer, which cannot initialize Own<T>::dealloc, so every instantiation
// failed to compile.
template <typename T>
Own<T> own(T* ptr, Deallocator<T> dealloc)
{
    return Own<T> { ptr, dealloc };
}
// Allocates storage for one T with malloc, copies `init` into it and returns
// the pointer paired with `dealloc`.
//
// Aborts the process with a diagnostic if the allocation fails; the previous
// version wrote through the null pointer in that case.
template <typename T>
Own<T> make_own(T init, Deallocator<T> dealloc)
{
    T* ptr = (T*)malloc(sizeof(T));
    if (!ptr) {
        fprintf(stderr, "error: out of memory\n");
        abort();
    }
    *ptr = init;
    return Own<T> { ptr, dealloc };
}
// Disposes of an owned object by calling its stored deallocator on its pointer.
// `own` is taken by value, so the caller's copy still holds the (now stale) pointer.
template <typename T>
void own_dealloc(Own<T> own)
{
    T* const target = own.ptr;
    own.dealloc(target);
}
// Growable array of T managed with malloc/realloc/free (see vec_reserve,
// vec_deinit). Elements are copied with plain assignment.
template <typename T>
struct Vec {
// Heap buffer, or nullptr before the first reservation.
T* data;
// Number of T slots allocated in `data`.
size_t capacity;
// Number of slots currently in use.
size_t len;
};
// Puts `vec` into the empty state: no buffer, zero capacity, zero length.
// Does not free anything, so it must not be used on a Vec that owns a buffer.
template <typename T>
void vec_init(Vec<T>* vec)
{
    vec->data = nullptr;
    vec->capacity = 0;
    vec->len = 0;
}
// Ensures `vec` has room for at least `min_size` elements.
//
// A Vec without a buffer always gets an initial capacity of 8; after that the
// capacity doubles until it covers `min_size`. Existing elements and `len`
// are preserved.
//
// Aborts the process with a diagnostic if the required size overflows size_t
// or if the allocation fails. Previously the malloc/realloc results were
// used unchecked, and a failed realloc also lost the old buffer.
template <typename T>
void vec_reserve(Vec<T>* vec, size_t min_size)
{
    const size_t max_size = (size_t)-1;
    const bool has_buffer = vec->data != nullptr;

    // Start from the current capacity, or 8 for a fresh (or zero-capacity) Vec.
    size_t wanted = (has_buffer && vec->capacity > 0) ? vec->capacity : 8;
    while (wanted < min_size) {
        if (wanted > max_size / 2) {
            fprintf(stderr, "error: out of memory\n");
            abort();
        }
        wanted *= 2;
    }

    if (has_buffer && wanted == vec->capacity)
        return;

    if (wanted > max_size / sizeof(T)) {
        fprintf(stderr, "error: out of memory\n");
        abort();
    }

    // realloc(nullptr, n) behaves like malloc(n), so one call covers both cases.
    T* grown = (T*)realloc(vec->data, sizeof(T) * wanted);
    if (!grown) {
        fprintf(stderr, "error: out of memory\n");
        abort();
    }
    vec->data = grown;
    vec->capacity = wanted;
}
// Appends a copy of `v` to the end of `vec`, growing the buffer if needed.
template <typename T>
void vec_push(Vec<T>* vec, T v)
{
    const size_t slot = vec->len;
    vec_reserve(vec, slot + 1);
    vec->data[slot] = v;
    vec->len = slot + 1;
}
// Releases the buffer owned by `vec` and returns it to the empty state.
//
// Elements are not individually destroyed; callers that store owning
// elements must release them first. Resetting the fields makes a repeated
// deinit harmless; the previous version left a dangling `data` pointer,
// so a second call freed the same buffer twice.
template <typename T>
void vec_deinit(Vec<T>* vec)
{
    free(vec->data); // free(nullptr) is a no-op
    vec->data = nullptr;
    vec->capacity = 0;
    vec->len = 0;
}
// Owned character buffer. The string_* helpers and format() write a '\0'
// at data[len], so `data` can be handed to C string functions.
typedef Vec<char> String;
// Initializes `str` as an empty, NUL-terminated string: a buffer is
// allocated, data[0] is '\0' and len is 0.
//
// The previous version pushed the terminator with vec_push, which left
// len == 1 and counted the '\0' as content. string_push then wrote after
// that terminator, and string_cat_fmt reported a length one too large.
void string_init(String* str)
{
    vec_init(str);
    vec_reserve(str, 1);
    str->data[0] = '\0';
}
// Releases the character buffer owned by `str`.
void string_deinit(String* str)
{
    vec_deinit<char>(str);
}
// Appends one character at index `len` and re-terminates the buffer.
// Room is reserved for the new character plus the trailing '\0'.
void string_push(String* str, char ch)
{
    const size_t at = str->len;
    vec_reserve(str, at + 2);
    char* tail = str->data + at;
    tail[0] = ch;
    tail[1] = '\0';
    str->len = at + 1;
}
// Read-only pointer + length pair describing `len` consecutive T values.
// The struct holds no ownership information; nothing here frees `data`.
template <typename T>
struct Span {
const T* data;
size_t len;
};
// Convenience constructor for Span<T> that lets T be deduced from `data`.
template <typename T>
Span<T> span(const T* data, size_t len)
{
    Span<T> view;
    view.data = data;
    view.len = len;
    return view;
}
// Read-only view over characters. Carries an explicit length; the helpers
// that build views (string_slice, strview_slice) do not add a terminator.
typedef Span<char> StrView;
// Returns a view over the first `len` characters of `str`.
// The view points into `str`'s buffer and does not copy it.
StrView string_to_view(const String* str)
{
    StrView whole = { str->data, str->len };
    return whole;
}
// Returns a view of `count` characters of `str` starting at index `begin`.
// No bounds checking is performed.
StrView string_slice(const String* str, size_t begin, size_t count)
{
    StrView piece;
    piece.data = str->data + begin;
    piece.len = count;
    return piece;
}
// Returns a sub-view of `count` characters of `str` starting at index `begin`.
// No bounds checking is performed.
StrView strview_slice(StrView str, size_t begin, size_t count)
{
    // `str` is a by-value copy, so it can be narrowed in place.
    str.data += begin;
    str.len = count;
    return str;
}
// Overwrites `string` with the characters of `view` and NUL-terminates it.
// Copying uses strncpy, so a '\0' inside the view stops the copy and the
// rest of the `view.len` bytes are zero-filled.
void string_from_strview(String* string, StrView view)
{
    const size_t count = view.len;
    vec_reserve(string, count + 1);
    char* dest = string->data;
    strncpy(dest, view.data, count);
    string->len = count;
    dest[count] = '\0';
}
// Replaces the contents of `str` with the printf-style expansion of `fmt`.
//
// snprintf runs twice: once with a null buffer to measure the output, then
// again into the reserved buffer. Returns 0 on success and 1 if the measuring
// call reports an error, in which case `str` is left untouched.
template <typename... Args>
int format(String* str, const char* fmt, Args... args)
{
    const int needed = snprintf(nullptr, 0, fmt, args...);
    if (needed >= 0) {
        const size_t size = (size_t)needed;
        vec_reserve(str, size + 1);
        snprintf(str->data, size + 1, fmt, args...);
        str->len = size;
        str->data[size] = '\0';
        return 0;
    }
    return 1;
}
// Appends the printf-style expansion of `fmt` to `str`.
//
// The text is formatted into a temporary String and then appended with
// strncat, which writes at the first '\0' in `str->data`.
//
// If formatting fails, `str` is left unchanged. The previous version ignored
// format()'s return value and still adjusted `str->len` from the temporary's
// stale length.
template <typename... Args>
void string_cat_fmt(String* str, const char* fmt, Args... args)
{
    String formatted;
    string_init(&formatted);
    if (format(&formatted, fmt, args...) == 0) {
        vec_reserve(str, str->len + formatted.len + 1);
        strncat(str->data, formatted.data, formatted.len);
        str->len += formatted.len;
        str->data[str->len] = '\0';
    }
    string_deinit(&formatted);
}
// Reads the whole file `filename` into `str` (NUL-terminated).
//
// Returns an ok Result carrying the number of bytes read, or an error Result
// carrying a message String that the caller must release with string_deinit.
// On a short read `str` still holds the bytes that were read.
//
// fseek/ftell failures are now reported as errors. Previously a failing
// ftell returned -1, which was cast to a huge size_t and used as the
// allocation and read size.
Result<size_t, String> read_file_to_string(String* str, const char* filename)
{
    FILE* file = fopen(filename, "r");
    if (!file) {
        String error_str;
        string_init(&error_str);
        format(&error_str, "couldn't open file '%s': %s", filename, strerror(errno));
        return result_error<size_t, String>(error_str);
    }

    // Measure the file by seeking to its end.
    long end_pos = -1;
    if (fseek(file, 0, SEEK_END) == 0)
        end_pos = ftell(file);
    if (end_pos < 0) {
        const int saved_errno = errno; // fclose may overwrite errno
        fclose(file);
        String error_str;
        string_init(&error_str);
        format(&error_str, "couldn't determine size of file '%s': %s", filename, strerror(saved_errno));
        return result_error<size_t, String>(error_str);
    }
    rewind(file);

    const size_t file_size = (size_t)end_pos;
    vec_reserve(str, file_size + 1);
    const size_t bytes_read = fread(str->data, 1, file_size, file);
    str->len = bytes_read;
    str->data[str->len] = '\0';
    fclose(file);

    if (bytes_read != file_size) {
        String error_str;
        string_init(&error_str);
        format(&error_str, "failed to read file");
        return result_error<size_t, String>(error_str);
    }
    return result_ok<size_t, String>(bytes_read);
}
// Token kinds produced by the lexer. Punctuation tokens use the character's
// own code as the enumerator value, which lets lexer_next cast the matched
// character straight to a TokTy.
enum TokTy {
TT_Eof,
TT_Ident,
TT_Int,
TT_LParen = '(',
TT_RParen = ')',
TT_Plus = '+',
TT_Comma = ',',
};
// One lexed token.
struct Tok {
TokTy ty;
// Slice of the lexer's input covering the token (see lexer_tok).
StrView text;
// Line number recorded when lexing of the token started.
int line;
};
// Cursor state for scanning a StrView one character at a time.
struct Lexer {
// Input being scanned; not owned by the lexer.
StrView text;
// Index of the current character within `text`.
size_t idx;
// Current line number; starts at 1 and is bumped when stepping past '\n'.
int line;
// Cached copy of the character at `idx`.
char ch;
};
// Positions `lexer` at the start of `text`, on line 1.
//
// For an empty view the cached character is set to '\0' instead of reading
// text.data[0], which would be outside the view and may be a null pointer.
void lexer_init(Lexer* lexer, StrView text)
{
    *lexer = Lexer {
        .text = text,
        .idx = 0,
        .line = 1,
        .ch = text.len > 0 ? text.data[0] : '\0',
    };
}
// True once the cursor has moved past the last character of the input.
static bool lexer_done(const Lexer* lexer)
{
    const size_t remaining_from = lexer->idx;
    return !(remaining_from < lexer->text.len);
}
// Advances the cursor by one character, counting a line when the character
// being left is '\n'. Does nothing when already at the end of the input.
//
// When the step lands on the end of the input the cached character becomes
// '\0'. The previous version read text.data[len], one element past the view,
// which is only valid if the view is backed by a NUL-terminated buffer.
static void lexer_step(Lexer* lexer)
{
    if (lexer_done(lexer))
        return;
    if (lexer->ch == '\n')
        lexer->line += 1;
    lexer->idx += 1;
    lexer->ch = lexer_done(lexer) ? '\0' : lexer->text.data[lexer->idx];
}
// Builds a token of kind `ty` whose text runs from `idx` up to (not
// including) the lexer's current position, tagged with `line`.
static Tok lexer_tok(Lexer* lexer, TokTy ty, size_t idx, int line)
{
    Tok tok;
    tok.ty = ty;
    tok.text = strview_slice(lexer->text, idx, lexer->idx - idx);
    tok.line = line;
    return tok;
}
// Scans and returns the next token.
//
// Whitespace is skipped. Runs of digits become TT_Int, a letter followed by
// letters/digits becomes TT_Ident, and each of "()+," becomes the token whose
// TokTy value equals the character. Any other character is reported on stderr
// and skipped. At the end of the input a TT_Eof token is returned.
//
// Changes from the previous version:
//  - characters are converted to unsigned char before being passed to the
//    <ctype.h> classifiers; passing a negative plain char is undefined
//    behavior;
//  - skipping whitespace or illegal characters loops instead of recursing,
//    so long runs of illegal characters cannot exhaust the stack;
//  - the punctuation table length is a compile-time constant rather than a
//    strlen call on every loop iteration.
Tok lexer_next(Lexer* lexer)
{
    static const char static_toks[] = "()+,";
    const size_t static_tok_count = sizeof(static_toks) - 1;

    for (;;) {
        const size_t idx = lexer->idx;
        const int line = lexer->line;

        if (lexer_done(lexer))
            return lexer_tok(lexer, TT_Eof, idx, line);

        const unsigned char first = (unsigned char)lexer->ch;

        if (isspace(first)) {
            while (!lexer_done(lexer) && isspace((unsigned char)lexer->ch))
                lexer_step(lexer);
            continue; // restart so idx/line describe the next token
        }
        if (isdigit(first)) {
            while (!lexer_done(lexer) && isdigit((unsigned char)lexer->ch))
                lexer_step(lexer);
            return lexer_tok(lexer, TT_Int, idx, line);
        }
        if (isalpha(first)) {
            while (!lexer_done(lexer) && isalnum((unsigned char)lexer->ch))
                lexer_step(lexer);
            return lexer_tok(lexer, TT_Ident, idx, line);
        }

        for (size_t i = 0; i < static_tok_count; ++i) {
            if (lexer->ch == static_toks[i]) {
                lexer_step(lexer);
                return lexer_tok(lexer, (TokTy)static_toks[i], idx, line);
            }
        }

        fprintf(stderr, "error: illegal character '%c' on line %d\n", lexer->ch, line);
        lexer_step(lexer);
    }
}
// Discriminator for Expr: selects which member of the Expr union is in use.
enum ExprTy {
ET_Error,
ET_Ident,
ET_Int,
ET_Call,
ET_Add,
};
// Forward declaration: CallExpr and AddExpr hold Own<Expr> before Expr is defined.
struct Expr;
// Payload of an ET_Call expression.
struct CallExpr {
// Expression being called.
Own<Expr> expr;
// Argument expressions in source order.
Vec<Own<Expr>> args;
};
// Payload of an ET_Add expression: the two operands.
struct AddExpr {
Own<Expr> left;
Own<Expr> right;
};
// AST node. `ty` selects which union member is meaningful.
struct Expr {
ExprTy ty;
// Source line recorded by the parser for this node.
int line;
union {
// Placeholder set to 0 by the parser for ET_Error nodes.
int nothing;
// ET_Ident: owned identifier text (released in expr_deinit).
String ident_value;
// ET_Int: parsed integer value.
int64_t int_value;
// ET_Call
CallExpr call_expr;
// ET_Add
AddExpr add_expr;
};
};
static void expr_vec_deinit(Vec<Own<Expr>>* exprs);

// Releases everything owned by `expr` (identifier text, child expressions)
// according to its kind. The Expr object itself is not freed; see expr_free.
void expr_deinit(Expr* expr)
{
    switch (expr->ty) {
    case ET_Ident:
        string_deinit(&expr->ident_value);
        return;
    case ET_Call: {
        CallExpr* call = &expr->call_expr;
        own_dealloc(call->expr);
        expr_vec_deinit(&call->args);
        return;
    }
    case ET_Add: {
        AddExpr* add = &expr->add_expr;
        own_dealloc(add->left);
        own_dealloc(add->right);
        return;
    }
    case ET_Error:
    case ET_Int:
        // These kinds own nothing.
        return;
    }
}
// Deallocates every owned expression in `exprs`, then releases the vector's buffer.
static void expr_vec_deinit(Vec<Own<Expr>>* exprs)
{
    Own<Expr>* it = exprs->data;
    Own<Expr>* const end = it + exprs->len;
    for (; it != end; ++it) {
        own_dealloc(*it);
    }
    vec_deinit(exprs);
}
// Deallocator for heap-allocated Expr nodes: releases what the node owns,
// then frees the node's own storage.
void expr_free(Expr* expr)
{
    expr_deinit(expr);
    free((void*)expr);
}
// Appends a textual dump of the expression tree rooted at `expr` to `str`.
//
// The integer case now casts to long long and uses %lld. int64_t is not
// `long` on every platform, so the previous %ld was a format/argument
// mismatch there.
void expr_to_string(String* str, const Expr* expr)
{
    switch (expr->ty) {
    case ET_Error:
        string_cat_fmt(str, "<error>");
        break;
    case ET_Ident:
        string_cat_fmt(str, "Ident(\"%s\")", expr->ident_value.data);
        break;
    case ET_Int:
        string_cat_fmt(str, "Int(%lld)", (long long)expr->int_value);
        break;
    case ET_Call: {
        const CallExpr* call = &expr->call_expr;
        string_cat_fmt(str, "Call { expr: ");
        expr_to_string(str, call->expr.ptr);
        string_cat_fmt(str, ", args: [");
        for (size_t i = 0; i < call->args.len; ++i) {
            if (i != 0) {
                string_cat_fmt(str, ", ");
            }
            expr_to_string(str, call->args.data[i].ptr);
        }
        string_cat_fmt(str, "] }");
        break;
    }
    case ET_Add:
        string_cat_fmt(str, "Add { left: ");
        expr_to_string(str, expr->add_expr.left.ptr);
        string_cat_fmt(str, ", right: ");
        expr_to_string(str, expr->add_expr.right.ptr);
        string_cat_fmt(str, " }");
        break;
    }
}
// Parser state: a lexer plus the token most recently read from it.
struct Parser {
Lexer lexer;
// Current token; replaced by parser_step.
Tok tok;
};
// Prepares `parser` to parse `text`: sets up its lexer and reads the first
// token so that `parser->tok` is immediately valid.
void parser_init(Parser* parser, StrView text)
{
    lexer_init(&parser->lexer, text);
    parser->tok = lexer_next(&parser->lexer);
}
// Forward declarations for the parsing functions, which call each other
// (prefix -> postfix -> operand, and postfix back into parser_parse_expr).
static Own<Expr> parser_parse_postfix(Parser* parser);
static Own<Expr> parser_parse_prefix(Parser* parser);
static Own<Expr> parser_parse_operand(Parser* parser);
// Discards the current token and reads the next one from the lexer.
static void parser_step(Parser* parser)
{
    Lexer* source = &parser->lexer;
    parser->tok = lexer_next(source);
}
// Entry point for parsing one expression; delegates to the prefix level.
// The caller owns the returned node and releases it with own_dealloc.
Own<Expr> parser_parse_expr(Parser* parser)
{
    Own<Expr> root = parser_parse_prefix(parser);
    return root;
}
// Parses a prefix-form addition: '+' followed by two prefix expressions
// becomes an ET_Add node. Anything else is handed to the postfix level.
static Own<Expr> parser_parse_prefix(Parser* parser)
{
    if (parser->tok.ty != '+')
        return parser_parse_postfix(parser);

    const int line = parser->tok.line;
    parser_step(parser);

    // Operands are parsed in source order: left first, then right.
    AddExpr add;
    add.left = parser_parse_prefix(parser);
    add.right = parser_parse_prefix(parser);

    Expr node = { .ty = ET_Add, .line = line, .add_expr = add };
    return make_own<Expr>(node, expr_free);
}
// Parses an operand followed by zero or more call suffixes `( args )`.
//
// Arguments are comma-separated expressions; a trailing comma before ')' is
// accepted. Each call suffix wraps the expression parsed so far in an ET_Call
// node. On a missing ',' or ')' an error is printed and an ET_Error node is
// returned.
//
// The error paths now release the callee expression and the arguments
// collected so far; previously both were leaked.
static Own<Expr> parser_parse_postfix(Parser* parser)
{
    int line = parser->tok.line;
    Own<Expr> expr = parser_parse_operand(parser);

    while (parser->tok.ty == '(') {
        parser_step(parser);

        Vec<Own<Expr>> args;
        vec_init(&args);

        // Set to the quoted token that was expected when a syntax error is found.
        const char* expected = nullptr;

        if (parser->tok.ty != TT_Eof && parser->tok.ty != ')') {
            vec_push(&args, parser_parse_expr(parser));
            while (parser->tok.ty != TT_Eof && parser->tok.ty != ')') {
                if (parser->tok.ty != TT_Comma) {
                    expected = "','";
                    break;
                }
                parser_step(parser);
                // Allow a trailing comma before the closing parenthesis.
                if (parser->tok.ty == TT_Eof || parser->tok.ty == ')')
                    break;
                vec_push(&args, parser_parse_expr(parser));
            }
        }
        if (!expected && parser->tok.ty != ')')
            expected = "')'";

        if (expected) {
            fprintf(stderr, "error: expected %s on line %d\n", expected, line);
            own_dealloc(expr);
            expr_vec_deinit(&args);
            return make_own<Expr>(
                Expr { .ty = ET_Error, .line = line, .nothing = 0 }, expr_free);
        }

        parser_step(parser);
        expr = make_own<Expr>(Expr {
            .ty = ET_Call,
            .line = line,
            .call_expr = CallExpr {
                .expr = expr,
                .args = args,
            },
        }, expr_free);
    }
    return expr;
}
// Parses a single operand: an identifier or an integer literal.
//
// Any other token (including end of input) is reported on stderr and
// consumed, and an ET_Error node is returned in its place.
static Own<Expr> parser_parse_operand(Parser* parser)
{
    const int line = parser->tok.line;

    switch (parser->tok.ty) {
    case TT_Ident: {
        // The node takes ownership of this copy of the identifier text.
        String name;
        string_init(&name);
        string_from_strview(&name, parser->tok.text);
        parser_step(parser);
        Expr node = { .ty = ET_Ident, .line = line, .ident_value = name };
        return make_own<Expr>(node, expr_free);
    }
    case TT_Int: {
        // strtoll needs a NUL-terminated string, so copy the token text first.
        String digits;
        string_init(&digits);
        string_from_strview(&digits, parser->tok.text);
        const int64_t parsed = strtoll(digits.data, nullptr, 10);
        string_deinit(&digits);
        parser_step(parser);
        Expr node = { .ty = ET_Int, .line = line, .int_value = parsed };
        return make_own<Expr>(node, expr_free);
    }
    case TT_Eof:
        fprintf(stderr, "error: expected expression, got EOF on line %d\n", parser->tok.line);
        break;
    default:
        fprintf(stderr, "error: expected expression, got '%.*s' on line %d\n", (int)parser->tok.text.len, parser->tok.text.data, parser->tok.line);
        break;
    }

    parser_step(parser);
    Expr failed = { .ty = ET_Error, .line = line, .nothing = 0 };
    return make_own<Expr>(failed, expr_free);
}
int64_t eval_expr(const Expr* expr)
{
switch (expr->ty) {
case ET_Error:
return -1;
case ET_Ident:
return -1;
case ET_Int:
return expr->int_value;
case ET_Call:
if (expr->call_expr.expr.ptr->ty == ET_Ident) {
if (strcmp(expr->call_expr.expr.ptr->ident_value.data, "print") == 0) {
int64_t value = eval_expr(expr->call_expr.args.data[0].ptr);
printf("%ld\n", value);
return 0;
}
}
return -1;
case ET_Add: {
int64_t left = eval_expr(expr->add_expr.left.ptr);
int64_t right = eval_expr(expr->add_expr.right.ptr);
return left + right;
}
}
return -1;
}
// Driver: reads the file named by argv[1], then prints its text, its token
// stream, the parsed AST and finally evaluates the AST.
//
// Returns EXIT_FAILURE when no filename is given or the file cannot be read.
// The read-error path now also releases `text`, which was leaked before, and
// the success path returns EXIT_SUCCESS explicitly.
int main(int argc, const char** argv)
{
    if (argc <= 1) {
        fprintf(stderr, "error: no filename\n");
        return EXIT_FAILURE;
    }
    const char* filename = argv[1];

    String text;
    string_init(&text);
    auto result = read_file_to_string(&text, filename);
    if (!result.ok) {
        fprintf(stderr, "error: %s\n", result.error.data);
        string_deinit(&result.error);
        string_deinit(&text);
        return EXIT_FAILURE;
    }
    printf("=== text ===\n%s\n", text.data);

    // Token dump uses its own lexer; the parser below lexes the text again.
    Lexer lexer;
    lexer_init(&lexer, string_to_view(&text));
    printf("=== tokens ===\n");
    Tok tok;
    while ((tok = lexer_next(&lexer)).ty != TT_Eof) {
        printf("%d\t\"%.*s\"\n", tok.ty, (int)tok.text.len, tok.text.data);
    }

    printf("=== ast ===\n");
    Parser parser;
    parser_init(&parser, string_to_view(&text));
    Own<Expr> ast = parser_parse_expr(&parser);
    String ast_string;
    string_init(&ast_string);
    expr_to_string(&ast_string, ast.ptr);
    printf("%s\n", ast_string.data);
    string_deinit(&ast_string);

    printf("=== eval ===\n");
    eval_expr(ast.ptr);

    own_dealloc(ast);
    string_deinit(&text);
    return EXIT_SUCCESS;
}