From 4ba7a7509d398df55e10274a24985e63e6723ad9 Mon Sep 17 00:00:00 2001 From: Bad Diode Date: Fri, 28 Jun 2024 13:50:56 +0200 Subject: Add bytecode compilation for strings and booleans --- Makefile | 2 +- src/badlib.h | 1 + src/main.c | 1301 ++-------------------------------------- src/parser.c | 5 +- src/semantic.c | 1592 +++++++++++++++++++++++++++++++++++-------------- src/vm.c | 90 ++- tests/compilation.bad | 10 + 7 files changed, 1257 insertions(+), 1744 deletions(-) create mode 100644 tests/compilation.bad diff --git a/Makefile b/Makefile index ff8a6c2..c3ab930 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,7 @@ BUILD_DIR := build TESTS_DIR := tests TEST_FILES := $(wildcard $(TESTS_DIR)/*.bad) SRC_MAIN := $(SRC_DIR)/main.c -SRC_RUN := tests/semantics.bad +SRC_RUN := tests/compilation.bad WATCH_SRC := $(shell find $(SRC_DIR) -name "*.c" -or -name "*.s" -or -name "*.h") INC_DIRS := $(shell find $(SRC_DIR) -type d) INC_FLAGS := $(addprefix -I,$(INC_DIRS)) diff --git a/src/badlib.h b/src/badlib.h index e6f0df6..25e4914 100644 --- a/src/badlib.h +++ b/src/badlib.h @@ -955,6 +955,7 @@ SETDEF(StrSet, strset, Str, str_hash, str_eq) MAPDEF(StrIntMap, strintmap, Str, sz, str_hash, str_eq) SETDEF(IntSet, intset, sz, _int_hash, _int_eq) MAPDEF(IntStrMap, intstrmap, sz, Str, _int_hash, _int_eq) +MAPDEF(IntIntMap, intintmap, sz, sz, _int_hash, _int_eq) // // Dynamic arrays. diff --git a/src/main.c b/src/main.c index cb93649..fdbf4e0 100644 --- a/src/main.c +++ b/src/main.c @@ -5,6 +5,7 @@ #include "badlib.h" #include "lexer.c" #include "parser.c" +#include "semantic.c" #include "vm.c" // TODO: unions @@ -27,1238 +28,6 @@ init(void) { log_init_default(); } -typedef enum { - SYM_UNKNOWN, - SYM_BUILTIN_FUN, - SYM_BUILTIN_TYPE, - SYM_FUN, - SYM_VAR, - SYM_PARAM, - SYM_ENUM, - SYM_ENUM_FIELD, - SYM_STRUCT, - SYM_STRUCT_FIELD, -} SymbolKind; - -Str sym_kind_str[] = { - [SYM_UNKNOWN] = cstr("UNKNOWN "), - [SYM_BUILTIN_FUN] = cstr("BUILTIN FUN "), - [SYM_BUILTIN_TYPE] = cstr("BUILTIN TYPE "), - [SYM_FUN] = cstr("FUNCTION "), - [SYM_VAR] = cstr("VARIABLE "), - [SYM_PARAM] = cstr("PARAMETER "), - [SYM_ENUM] = cstr("ENUM "), - [SYM_ENUM_FIELD] = cstr("ENUM FIELD "), - [SYM_STRUCT] = cstr("STRUCT "), - [SYM_STRUCT_FIELD] = cstr("STRUCT FIELD "), -}; - -typedef struct Symbol { - Str name; - SymbolKind kind; -} Symbol; - -typedef struct Fun { - Str name; - Str param_type; - Str return_type; -} Fun; - -typedef struct Enum { - Str name; - Node *val; -} Enum; - -typedef struct Struct { - Str name; - Str type; - Node *val; -} Struct; - -MAPDEF(SymbolMap, symmap, Str, Symbol, str_hash, str_eq) -MAPDEF(FunMap, funmap, Str, Fun, str_hash, str_eq) -MAPDEF(EnumMap, enummap, Str, Enum, str_hash, str_eq) -MAPDEF(StructMap, structmap, Str, Struct, str_hash, str_eq) - -typedef struct Scope { - sz id; - sz depth; - Str name; - SymbolMap *symbols; - FunMap *funcs; - EnumMap *enums; - StructMap *structs; - struct Scope *parent; -} Scope; - -typedef struct Analyzer { - Arena *storage; - Str file_name; - sz typescope_gen; - Scope **scopes; - StrSet *numeric_types; - StrSet *integer_types; - bool err; -} Analyzer; - -Scope * -typescope_alloc(Analyzer *a, Scope *parent) { - Scope *scope = arena_calloc(sizeof(Scope), a->storage); - scope->parent = parent; - scope->id = a->typescope_gen++; - scope->depth = parent == NULL ? 0 : parent->depth + 1; - array_push(a->scopes, scope, a->storage); - return scope; -} - -SymbolMap * -find_type(Scope *scope, Str type) { - while (scope != NULL) { - SymbolMap *val = symmap_lookup(&scope->symbols, type); - if (val != NULL) { - return val; - } - scope = scope->parent; - } - return NULL; -} - -FunMap * -find_fun(Scope *scope, Str type) { - while (scope != NULL) { - FunMap *val = funmap_lookup(&scope->funcs, type); - if (val != NULL) { - return val; - } - scope = scope->parent; - } - return NULL; -} - -typedef struct FindEnumResult { - EnumMap *map; - Scope *scope; -} FindEnumResult; - -FindEnumResult -find_enum(Scope *scope, Str type) { - while (scope != NULL) { - EnumMap *val = enummap_lookup(&scope->enums, type); - if (val != NULL) { - return (FindEnumResult){.map = val, .scope = scope}; - } - scope = scope->parent; - } - return (FindEnumResult){0}; -} - -typedef struct FindStructResult { - StructMap *map; - Scope *scope; -} FindStructResult; - -FindStructResult -find_struct(Scope *scope, Str type) { - while (scope != NULL) { - StructMap *val = structmap_lookup(&scope->structs, type); - if (val != NULL) { - return (FindStructResult){.map = val, .scope = scope}; - } - scope = scope->parent; - } - return (FindStructResult){0}; -} - -void -graph_typescope(Scope *scope, Arena a) { - if (!scope->symbols) { - return; - } - SymbolMapIter iter = symmap_iterator(scope->symbols, &a); - SymbolMap *type = symmap_next(&iter, &a); - print( - "%d[shape=\"none\" label=<", - scope->id); - print( - "" - "" - "" - ""); - while (type) { - print( - "" - "" - "" - "", - type->key, type->val.name); - type = symmap_next(&iter, &a); - } - println("
NAME TYPE
%s %s
>];"); - - sz this_id = scope->id; - while (scope->parent) { - if (scope->parent->symbols) { - println("%d:e->%d:w;", this_id, scope->parent->id); - break; - } else { - scope = scope->parent; - } - } -} - -void -graph_functions(Scope *scope, Arena a) { - if (!scope->funcs) { - return; - } - FunMapIter iter = funmap_iterator(scope->funcs, &a); - FunMap *func = funmap_next(&iter, &a); - print( - "fun_%d[shape=\"none\" label=<", - scope->id); - print( - "" - "" - "" - "" - ""); - while (func) { - print( - "" - "" - "" - "" - "", - func->val.name, func->val.name, func->val.param_type, - func->val.return_type); - func = funmap_next(&iter, &a); - } - println("
NAME PARAMS RETURN
%s %s %s
>];"); - sz this_id = scope->id; - while (scope->parent) { - if (scope->parent->symbols) { - println("fun_%d:e->fun_%d:%s:w;", this_id, scope->parent->id, - scope->name); - break; - } else { - scope = scope->parent; - } - } -} - -void -graph_types(Scope **scopes, Arena a) { - if (scopes == NULL) return; - println("digraph types {"); - println("rankdir=LR;"); - println("ranksep=\"0.95 equally\";"); - println("nodesep=\"0.5 equally\";"); - println("overlap=scale;"); - println("bgcolor=\"transparent\";"); - for (sz i = 0; i < array_size(scopes); i++) { - Scope *scope = scopes[i]; - if (!scope) { - continue; - } - println("subgraph %d {", i); - graph_typescope(scope, a); - graph_functions(scope, a); - println("}"); - } - println("}"); -} - -void -emit_semantic_error(Analyzer *a, Node *n, Str msg) { - eprintln("%s:%d:%d: error: %s", a->file_name, n->line, n->col, msg); - a->err = true; -} - -Str type_inference(Analyzer *a, Node *node, Scope *scope); - -void -typecheck_field(Analyzer *a, Node *node, Scope *scope, Str symbol) { - if (node->field_type->kind == NODE_COMPOUND_TYPE) { - Str field_name = str_concat(symbol, cstr("."), a->storage); - field_name = str_concat(field_name, node->value.str, a->storage); - if (structmap_lookup(&scope->structs, field_name)) { - eprintln("%s:%d:%d: error: struct field '%s' already exists", - a->file_name, node->line, node->col, field_name); - a->err = true; - } - Str type = cstr("\\{ "); - for (sz i = 0; i < array_size(node->field_type->elements); i++) { - Node *field = node->field_type->elements[i]; - typecheck_field(a, field, scope, field_name); - type = str_concat(type, field->type, a->storage); - type = str_concat(type, cstr(" "), a->storage); - } - type = str_concat(type, cstr("\\}"), a->storage); - node->type = type; - } else { - Str field_name = str_concat(symbol, cstr("."), a->storage); - field_name = str_concat(field_name, node->value.str, a->storage); - Str field_type = node->field_type->value.str; - if (!find_type(scope, field_type)) { - eprintln("%s:%d:%d: error: unknown type '%s'", a->file_name, - node->field_type->line, node->field_type->col, field_type); - a->err = true; - } - if (node->field_type->is_ptr) { - field_type = str_concat(cstr("@"), field_type, a->storage); - } - if (node->field_type->kind == NODE_ARR_TYPE) { - field_type = str_concat(cstr("@"), field_type, a->storage); - } - if (structmap_lookup(&scope->structs, field_name)) { - eprintln("%s:%d:%d: error: struct field '%s' already exists", - a->file_name, node->line, node->col, field_name); - a->err = true; - } - if (node->field_val) { - Str type = type_inference(a, node->field_val, scope); - if (!str_eq(type, field_type)) { - eprintln( - "%s:%d:%d: error: mismatched types in struct " - "value " - "for '%s': %s expected %s", - a->file_name, node->line, node->col, field_name, type, - field_type); - a->err = true; - } - } - structmap_insert(&scope->structs, field_name, - (Struct){ - .name = field_name, - .type = field_type, - .val = node->field_val, - }, - a->storage); - symmap_insert(&scope->symbols, field_name, - (Symbol){.name = field_type, .kind = SYM_STRUCT_FIELD}, - a->storage); - node->type = field_type; - } -} - -void -typecheck_lit_field(Analyzer *a, Node *node, Scope *scope, Str symbol) { - if (node->field_val->kind == NODE_COMPOUND_TYPE) { - Str type = cstr("\\{ "); - for (sz i = 0; i < array_size(node->field_val->elements); i++) { - Node *field = node->field_val->elements[i]; - Str field_name = str_concat(symbol, cstr("."), a->storage); - field_name = str_concat(field_name, field->value.str, a->storage); - typecheck_lit_field(a, field, scope, field_name); - type = str_concat(type, field->type, a->storage); - type = str_concat(type, cstr(" "), a->storage); - } - type = str_concat(type, cstr("\\}"), a->storage); - node->type = type; - } else { - StructMap *s = structmap_lookup(&scope->structs, symbol); - if (!s) { - eprintln("%s:%d:%d: error: unknown struct field '%s'", a->file_name, - node->line, node->col, symbol); - a->err = true; - return; - } - Str field_type = s->val.type; - Str type = type_inference(a, node->field_val, scope); - if (!str_eq(type, field_type)) { - eprintln( - "%s:%d:%d: error: mismatched types in struct " - "value " - "for '%s': %s expected %s", - a->file_name, node->line, node->col, symbol, type, field_type); - a->err = true; - } - node->type = field_type; - } -} - -void -typecheck_returns(Analyzer *a, Node *node, Str expected) { - if (!node) { - return; - } - // Traverse the tree again. - switch (node->kind) { - case NODE_COND: - case NODE_MATCH: { - for (sz i = 0; i < array_size(node->match_cases); i++) { - Node *next = node->match_cases[i]; - typecheck_returns(a, next, expected); - } - } break; - case NODE_RETURN: { - bool err = !str_eq(node->type, expected); - if (err) { - eprintln( - "%s:%d:%d: error: mismatched return type %s, expected %s", - a->file_name, node->line, node->col, node->type, expected); - a->err = true; - } - } break; - case NODE_BLOCK: { - for (sz i = 0; i < array_size(node->elements); i++) { - Node *next = node->elements[i]; - typecheck_returns(a, next, expected); - } - } break; - case NODE_IF: { - if (node->cond_expr) { - typecheck_returns(a, node->cond_expr, expected); - } - if (node->cond_else) { - typecheck_returns(a, node->cond_else, expected); - } - } break; - case NODE_SET: - case NODE_LET: { - if (node->var_val) { - typecheck_returns(a, node->var_val, expected); - } - } break; - case NODE_ADD: - case NODE_SUB: - case NODE_DIV: - case NODE_MUL: - case NODE_MOD: - case NODE_NOT: - case NODE_AND: - case NODE_OR: - case NODE_EQ: - case NODE_NEQ: - case NODE_LT: - case NODE_GT: - case NODE_LE: - case NODE_GE: - case NODE_BITNOT: - case NODE_BITAND: - case NODE_BITOR: - case NODE_BITLSHIFT: - case NODE_BITRSHIFT: { - if (node->left) { - typecheck_returns(a, node->left, expected); - } - if (node->right) { - typecheck_returns(a, node->right, expected); - } - } break; - default: break; - } -} - -Str -type_inference(Analyzer *a, Node *node, Scope *scope) { - assert(a); - assert(scope); - if (!node) { - return cstr(""); - } - // NOTE: For now we are not going to do implicit numeric conversions. - switch (node->kind) { - case NODE_LET: { - node->type = cstr("nil"); - Str symbol = node->var_name->value.str; - if (symmap_lookup(&scope->symbols, symbol)) { - eprintln( - "%s:%d:%d: error: symbol '%s' already exists in current " - "scope ", - a->file_name, node->var_name->line, node->var_name->col, - symbol); - a->err = true; - return cstr(""); - } - if (node->var_type) { - Str type_name = node->var_type->value.str; - SymbolMap *type = find_type(scope, type_name); - if (type == NULL) { - eprintln("%s:%d:%d: error: unknown type '%s'", a->file_name, - node->var_type->line, node->var_type->col, - type_name); - a->err = true; - return cstr(""); - } - if (node->var_type->is_ptr) { - type_name = str_concat(cstr("@"), type_name, a->storage); - } - if (node->var_type->kind == NODE_ARR_TYPE) { - type_name = str_concat(cstr("@"), type_name, a->storage); - // TODO: typecheck size - // TODO: register array in scope - } - if (node->var_val) { - Str type = type_inference(a, node->var_val, scope); - if (!type.size) { - eprintln( - "%s:%d:%d: error: can't bind `nil` to variable " - "'%s'", - a->file_name, node->var_type->line, - node->var_type->col, symbol); - a->err = true; - return cstr(""); - } - // TODO: Consider compatible types. - if (!str_eq(type, type_name)) { - // Special case, enums can be treated as ints. - FindEnumResult res = find_enum(scope, type_name); - if (!(res.map && str_eq(type, cstr("int")))) { - eprintln( - "%s:%d:%d: error: type mismatch, trying to " - "assing " - "%s" - " to a variable of type %s", - a->file_name, node->var_type->line, - node->var_type->col, type, type_name); - a->err = true; - return cstr(""); - } - } - } - symmap_insert(&scope->symbols, symbol, - (Symbol){ - .name = type_name, - .kind = SYM_VAR, - }, - a->storage); - return node->type; - } - - // We don't know the type for this symbol, perform inference. - Str type = type_inference(a, node->var_val, scope); - if (type.size) { - symmap_insert(&scope->symbols, symbol, - (Symbol){.name = type, .kind = SYM_VAR}, - a->storage); - node->var_name->type = type; - } - return node->type; - } break; - case NODE_SET: { - Str name = type_inference(a, node->var_name, scope); - Str val = type_inference(a, node->var_val, scope); - if (!str_eq(name, val)) { - eprintln( - "%s:%d:%d: error: type mismatch, trying to assing " - "%s" - " to a variable of type %s", - a->file_name, node->line, node->col, val, name); - a->err = true; - return cstr(""); - } - node->type = cstr("nil"); - return node->type; - } break; - case NODE_STRUCT: { - node->type = cstr("nil"); - Str symbol = node->value.str; - if (symmap_lookup(&scope->symbols, symbol) != NULL) { - eprintln( - "%s:%d:%d: error: struct '%s' already exists in current " - "scope", - a->file_name, node->line, node->col, symbol); - a->err = true; - return cstr(""); - } - structmap_insert(&scope->structs, symbol, (Struct){.name = symbol}, - a->storage); - for (sz i = 0; i < array_size(node->struct_field); i++) { - Node *field = node->struct_field[i]; - typecheck_field(a, field, scope, symbol); - } - symmap_insert(&scope->symbols, symbol, - (Symbol){.name = symbol, .kind = SYM_STRUCT}, - a->storage); - return node->type; - } break; - case NODE_ENUM: { - node->type = cstr("nil"); - Str symbol = node->value.str; - if (symmap_lookup(&scope->symbols, symbol) != NULL) { - eprintln( - "%s:%d:%d: error: enum '%s' already exists in current " - "scope", - a->file_name, node->line, node->col, symbol); - a->err = true; - return cstr(""); - } - enummap_insert(&scope->enums, symbol, - (Enum){ - .name = symbol, - .val = node->field_val, - }, - a->storage); - for (sz i = 0; i < array_size(node->struct_field); i++) { - Node *field = node->struct_field[i]; - Str field_name = str_concat(symbol, cstr("."), a->storage); - field_name = - str_concat(field_name, field->value.str, a->storage); - if (enummap_lookup(&scope->enums, field_name)) { - eprintln("%s:%d:%d: error: enum field '%s' already exists", - a->file_name, field->line, field->col, field_name); - a->err = true; - } - if (field->field_val) { - Str type = type_inference(a, field->field_val, scope); - if (!str_eq(type, cstr("int"))) { - eprintln( - "%s:%d:%d: error: non int enum value for '%s.%s'", - a->file_name, field->line, field->col, symbol, - field_name); - a->err = true; - } - } - enummap_insert(&scope->enums, field_name, - (Enum){.name = field_name}, a->storage); - symmap_insert( - &scope->symbols, field_name, - (Symbol){.name = field_name, .kind = SYM_ENUM_FIELD}, - a->storage); - field->type = symbol; - } - symmap_insert(&scope->symbols, symbol, - (Symbol){.name = symbol, .kind = SYM_ENUM}, - a->storage); - return node->type; - } break; - case NODE_IF: { - Str cond_type = type_inference(a, node->cond_if, scope); - if (!str_eq(cond_type, cstr("bool"))) { - emit_semantic_error( - a, node->cond_if, - cstr("non boolean expression on if condition")); - return cstr(""); - } - if (node->cond_expr->kind == NODE_BLOCK) { - node->type = type_inference(a, node->cond_expr, scope); - } else { - Scope *next = typescope_alloc(a, scope); - node->type = type_inference(a, node->cond_expr, next); - } - if (node->cond_else) { - Str else_type; - if (node->cond_else->kind == NODE_BLOCK) { - else_type = type_inference(a, node->cond_else, scope); - } else { - Scope *next = typescope_alloc(a, scope); - else_type = type_inference(a, node->cond_else, next); - } - if (!str_eq(node->type, else_type)) { - emit_semantic_error( - a, node, cstr("mismatch types for if/else branches")); - return cstr(""); - } - } - return node->type; - } break; - case NODE_WHILE: { - Str cond_type = type_inference(a, node->while_cond, scope); - if (!str_eq(cond_type, cstr("bool"))) { - emit_semantic_error( - a, node->cond_if, - cstr("non boolean expression on while condition")); - return cstr(""); - } - if (node->while_expr->kind != NODE_BLOCK) { - scope = typescope_alloc(a, scope); - } - type_inference(a, node->while_expr, scope); - node->type = cstr("nil"); - return node->type; - } break; - case NODE_COND: { - Str previous = cstr(""); - for (sz i = 0; i < array_size(node->match_cases); i++) { - Node *expr = node->match_cases[i]; - Str next = type_inference(a, expr, scope); - if (i != 0 && !str_eq(next, previous)) { - emit_semantic_error( - a, node, - cstr("non-matching types for cond expressions")); - return cstr(""); - } - previous = next; - } - node->type = previous; - return node->type; - } break; - case NODE_MATCH: { - Str e = type_inference(a, node->match_expr, scope); - if (str_eq(e, cstr("int"))) { - // Integer matching. - for (sz i = 0; i < array_size(node->match_cases); i++) { - Node *field = node->match_cases[i]; - if (field->case_value) { - if (field->case_value->kind != NODE_NUM_INT && - field->case_value->kind != NODE_NUM_UINT) { - emit_semantic_error( - a, field->case_value, - cstr( - "non-integer or enum types on match case")); - } - } - } - } else { - // Get enum type and de-structure the match. - FindEnumResult res = find_enum(scope, e); - Str enum_prefix = - str_concat(res.map->val.name, cstr("."), a->storage); - for (sz i = 0; i < array_size(node->match_cases); i++) { - Node *field = node->match_cases[i]; - if (field->case_value) { - Str field_name = str_concat( - enum_prefix, field->case_value->value.str, - a->storage); - if (!enummap_lookup(&res.scope->enums, field_name)) { - eprintln("%s:%d:%d: error: unknown enum field '%s'", - a->file_name, field->case_value->line, - field->case_value->col, field_name); - a->err = true; - } - } - } - } - Str previous = cstr(""); - for (sz i = 0; i < array_size(node->match_cases); i++) { - Node *expr = node->match_cases[i]; - Str next = type_inference(a, expr, scope); - if (i != 0 && !str_eq(next, previous)) { - emit_semantic_error( - a, node, - cstr("non-matching types for match expressions")); - return cstr(""); - } - previous = next; - } - node->type = previous; - return node->type; - } break; - case NODE_CASE_MATCH: { - if (node->case_expr->kind != NODE_BLOCK) { - scope = typescope_alloc(a, scope); - } - node->type = type_inference(a, node->case_expr, scope); - return node->type; - } break; - case NODE_CASE_COND: { - if (node->case_value) { - Str cond = type_inference(a, node->case_value, scope); - if (!str_eq(cond, cstr("bool"))) { - emit_semantic_error(a, node, - cstr("non-boolean case condition")); - } - } - if (node->case_expr->kind != NODE_BLOCK) { - scope = typescope_alloc(a, scope); - } - node->type = type_inference(a, node->case_expr, scope); - return node->type; - } break; - case NODE_TRUE: - case NODE_FALSE: { - node->type = cstr("bool"); - return node->type; - } break; - case NODE_NIL: { - node->type = cstr("nil"); - return node->type; - } break; - case NODE_NOT: - case NODE_AND: - case NODE_OR: { - Str left = type_inference(a, node->left, scope); - if (!str_eq(left, cstr("bool"))) { - emit_semantic_error(a, node, - cstr("expected bool on logic expression")); - return cstr(""); - } - if (node->right) { - Str right = type_inference(a, node->right, scope); - if (!str_eq(right, cstr("bool"))) { - emit_semantic_error( - a, node, cstr("expected bool on logic expression")); - return cstr(""); - } - } - node->type = cstr("bool"); - return node->type; - } break; - case NODE_EQ: - case NODE_NEQ: - case NODE_LT: - case NODE_GT: - case NODE_LE: - case NODE_GE: { - Str left = type_inference(a, node->left, scope); - Str right = type_inference(a, node->right, scope); - if (!str_eq(left, right)) { - emit_semantic_error( - a, node, cstr("mismatched types on binary expression")); - return cstr(""); - } - node->type = cstr("bool"); - return node->type; - } break; - case NODE_BITNOT: { - Str left = type_inference(a, node->left, scope); - if (!strset_lookup(&a->integer_types, left)) { - emit_semantic_error( - a, node, cstr("non integer type on bit twiddling expr")); - return cstr(""); - } - node->type = left; - return node->type; - } break; - case NODE_BITAND: - case NODE_BITOR: - case NODE_BITLSHIFT: - case NODE_BITRSHIFT: { - Str left = type_inference(a, node->left, scope); - Str right = type_inference(a, node->right, scope); - if (!strset_lookup(&a->integer_types, left) || - !strset_lookup(&a->integer_types, right)) { - emit_semantic_error( - a, node, cstr("non integer type on bit twiddling expr")); - return cstr(""); - } - node->type = left; - return node->type; - } break; - case NODE_ADD: - case NODE_SUB: - case NODE_DIV: - case NODE_MUL: - case NODE_MOD: { - Str left = type_inference(a, node->left, scope); - Str right = type_inference(a, node->right, scope); - if (!strset_lookup(&a->numeric_types, left) || - !strset_lookup(&a->numeric_types, right)) { - emit_semantic_error( - a, node, cstr("non numeric type on arithmetic expr")); - return cstr(""); - } - if (!str_eq(left, right)) { - emit_semantic_error( - a, node, cstr("mismatched types on binary expression")); - return cstr(""); - } - node->type = left; - return node->type; - } break; - case NODE_NUM_UINT: { - node->type = cstr("uint"); - return node->type; - } break; - case NODE_NUM_INT: { - node->type = cstr("int"); - return node->type; - } break; - case NODE_NUM_FLOAT: { - node->type = cstr("f64"); - return node->type; - } break; - case NODE_STRING: { - node->type = cstr("str"); - return node->type; - } break; - case NODE_ARR_TYPE: - case NODE_TYPE: { - SymbolMap *type = find_type(scope, node->value.str); - if (!type) { - emit_semantic_error(a, node, cstr("unknown type")); - return cstr(""); - } - node->type = type->val.name; - return node->type; - } break; - case NODE_SYMBOL_IDX: - case NODE_SYMBOL: { - Str symbol = node->value.str; - SymbolMap *type = find_type(scope, symbol); - if (!type) { - eprintln("%s:%d:%d: error: couldn't resolve symbol '%s'", - a->file_name, node->line, node->col, symbol); - a->err = true; - return cstr(""); - } - Str type_name = type->val.name; - if (node->kind == NODE_SYMBOL_IDX) { - Str idx_type = type_inference(a, node->arr_size, scope); - if (!strset_lookup(&a->integer_types, idx_type)) { - emit_semantic_error( - a, node, cstr("can't resolve non integer index")); - return cstr(""); - } - type_name = str_remove_prefix(type_name, cstr("@")); - } - if (node->is_ptr) { - type_name = str_concat(cstr("@"), type_name, a->storage); - } - - FindEnumResult e = find_enum(scope, type_name); - if (e.map && str_eq(symbol, type_name)) { - if (!node->next) { - eprintln( - "%s:%d:%d: error: unspecified enum field for symbol " - "'%s'", - a->file_name, node->line, node->col, symbol); - a->err = true; - return cstr(""); - } - // Check if there is a next and it matches the enum field. - Str field = str_concat(type_name, cstr("."), a->storage); - field = str_concat(field, node->next->value.str, a->storage); - if (!enummap_lookup(&e.scope->enums, field)) { - eprintln( - "%s:%d:%d: error: unknown enum field for " - "'%s': %s", - a->file_name, node->line, node->col, symbol, - node->next->value.str); - a->err = true; - return cstr(""); - } - node->next->type = type_name; - node->type = type_name; - return node->next->type; - } - - FindStructResult s = find_struct(scope, type_name); - if (s.map) { - if (str_eq(symbol, type_name)) { - eprintln( - "%s:%d:%d: error: struct incomplete struct literal " - "'%s', did you mean to use %s:{}?", - a->file_name, node->line, node->col, symbol, symbol); - a->err = true; - return cstr(""); - } else { - if (node->next) { - Str chain = type_name; - Node *next = node; - while (next->next) { - next = next->next; - chain = str_concat(chain, cstr("."), a->storage); - chain = - str_concat(chain, next->value.str, a->storage); - } - StructMap *field = - structmap_lookup(&s.scope->structs, chain); - if (!field) { - eprintln( - "%s:%d:%d: error: unknown struct field '%s'", - a->file_name, node->line, node->col, chain); - a->err = true; - return cstr(""); - } - Str field_type = field->val.type; - if (next->kind == NODE_SYMBOL_IDX) { - Str idx_type = - type_inference(a, next->arr_size, scope); - if (!strset_lookup(&a->integer_types, idx_type)) { - emit_semantic_error( - a, next, - cstr("can't resolve non integer index")); - return cstr(""); - } - field_type = - str_remove_prefix(field_type, cstr("@")); - } - node->type = field_type; - return node->type; - } - } - } - node->type = type_name; - return node->type; - } break; - case NODE_STRUCT_LIT: { - Str name = node->value.str; - FindStructResult s = find_struct(scope, name); - if (!s.map) { - eprintln("%s:%d:%d: error: unknown struct type '%s'", - a->file_name, node->line, node->col, name); - a->err = true; - return cstr(""); - } - - StrSet *set = NULL; - for (sz i = 0; i < array_size(node->elements); i++) { - Node *next = node->elements[i]; - Str field_name = str_concat(name, cstr("."), a->storage); - field_name = - str_concat(field_name, next->value.str, a->storage); - - if (strset_lookup(&set, field_name)) { - eprintln( - "%s:%d:%d: error: field '%s' already present in struct " - "literal", - a->file_name, next->line, next->col, field_name); - a->err = true; - } else { - strset_insert(&set, field_name, a->storage); - } - typecheck_lit_field(a, next, s.scope, field_name); - } - node->type = name; - return node->type; - } break; - case NODE_FUNCALL: { - Str symbol = node->value.str; - FunMap *fun = find_fun(scope, symbol); - if (!fun) { - eprintln( - "%s:%d:%d: error: function '%s' doesn't exist in current " - "scope ", - a->file_name, node->line, node->col, symbol); - a->err = true; - return cstr(""); - } - // Check that actual parameters typecheck - Str args = cstr(""); - for (sz i = 0; i < array_size(node->elements); i++) { - Node *expr = node->elements[i]; - Str type = type_inference(a, expr, scope); - args = str_concat(args, type, a->storage); - if (i != array_size(node->elements) - 1) { - args = str_concat(args, cstr(","), a->storage); - } - } - if (!args.size) { - args = cstr("nil"); - } - Str expected = fun->val.param_type; - if (!str_eq(args, expected) && !str_eq(expected, cstr("..."))) { - eprintln( - "%s:%d:%d: error: mismatched parameter types: %s expected " - "%s", - a->file_name, node->line, node->col, args, expected); - a->err = true; - } - node->type = fun->val.return_type; - return node->type; - } break; - case NODE_BLOCK: { - scope = typescope_alloc(a, scope); - Str type; - for (sz i = 0; i < array_size(node->elements); i++) { - Node *expr = node->elements[i]; - type = type_inference(a, expr, scope); - } - node->type = type; - return node->type; - } break; - case NODE_RETURN: { - Str ret_type = cstr(""); - for (sz i = 0; i < array_size(node->elements); i++) { - Node *expr = node->elements[i]; - Str type = type_inference(a, expr, scope); - ret_type = str_concat(ret_type, type, a->storage); - if (i != array_size(node->elements) - 1) { - ret_type = str_concat(ret_type, cstr(","), a->storage); - } - } - if (!ret_type.size) { - ret_type = cstr("nil"); - } - node->type = ret_type; - return node->type; - } break; - case NODE_FUN: { - node->type = cstr("nil"); - Scope *prev_scope = scope; - scope = typescope_alloc(a, scope); - Str param_type = cstr(""); - for (sz i = 0; i < array_size(node->func_params); i++) { - Node *param = node->func_params[i]; - Str symbol = param->param_name->value.str; - Str type = param->param_type->value.str; - if (param->param_type->is_ptr) { - type = str_concat(cstr("@"), type, a->storage); - } - if (param->param_type->kind == NODE_ARR_TYPE) { - type = str_concat(cstr("@"), type, a->storage); - } - param->param_name->type = - type_inference(a, param->param_type, scope); - param->type = type; - symmap_insert(&scope->symbols, symbol, - (Symbol){.name = type, .kind = SYM_PARAM}, - a->storage); - param_type = str_concat(param_type, type, a->storage); - if (i != array_size(node->func_params) - 1) { - param_type = str_concat(param_type, cstr(","), a->storage); - } - } - if (!param_type.size) { - param_type = cstr("nil"); - } - node->fun_params = param_type; - - Str ret_type = cstr(""); - for (sz i = 0; i < array_size(node->func_ret); i++) { - Node *expr = node->func_ret[i]; - Str type = type_inference(a, expr, scope); - if (expr->is_ptr) { - type = str_concat(cstr("@"), type, a->storage); - } - if (expr->kind == NODE_ARR_TYPE) { - type = str_concat(cstr("@"), type, a->storage); - } - ret_type = str_concat(ret_type, type, a->storage); - if (i != array_size(node->func_ret) - 1) { - ret_type = str_concat(ret_type, cstr(","), a->storage); - } - } - if (!ret_type.size) { - ret_type = cstr("nil"); - } - node->fun_return = ret_type; - - Str symbol = node->func_name->value.str; - if (prev_scope->parent != NULL) { - if (symmap_lookup(&prev_scope->symbols, symbol)) { - eprintln( - "%s:%d:%d: error: function '%s' already defined in " - "current " - "scope ", - a->file_name, node->var_name->line, node->var_name->col, - symbol); - a->err = true; - return cstr(""); - } - symmap_insert(&scope->symbols, symbol, - (Symbol){.name = symbol, .kind = SYM_FUN}, - a->storage); - } - scope->name = symbol; - funmap_insert(&prev_scope->funcs, symbol, - (Fun){.name = symbol, - .param_type = param_type, - .return_type = ret_type}, - a->storage); - - if (node->func_body->kind == NODE_BLOCK) { - Str type; - for (sz i = 0; i < array_size(node->func_body->elements); i++) { - Node *expr = node->func_body->elements[i]; - type = type_inference(a, expr, scope); - } - if (!type.size) { - type = cstr("nil"); - } - node->func_body->type = type; - } else { - type_inference(a, node->func_body, scope); - } - - // Ensure main body return matches the prototype. - if (!str_eq(node->func_body->type, ret_type)) { - eprintln( - "%s:%d:%d: error: mismatched return type %s, expected %s", - a->file_name, node->line, node->col, node->func_body->type, - ret_type); - a->err = true; - } - - // Ensure ALL return statements match the function prototype. - typecheck_returns(a, node->func_body, ret_type); - - // TODO: should return statements be allowed on let blocks? - return node->type; - } break; - default: { - emit_semantic_error(a, node, - cstr("type inference not implemented for this " - "kind of expression")); - println("KIND: %s", node_str[node->kind]); - } break; - } - return cstr(""); -} - -void -symbolic_analysis(Analyzer *a, Parser *parser) { - Scope *scope = typescope_alloc(a, NULL); - assert(a); - assert(parser); - - // Fill builtin tables. - Str builtin_functions[] = { - cstr("print"), - cstr("println"), - }; - for (sz i = 0; i < LEN(builtin_functions); i++) { - Str symbol = builtin_functions[i]; - symmap_insert(&scope->symbols, symbol, - (Symbol){.name = symbol, .kind = SYM_BUILTIN_FUN}, - a->storage); - funmap_insert(&scope->funcs, symbol, - (Fun){.name = symbol, - .param_type = cstr("..."), - .return_type = cstr("nil")}, - a->storage); - } - Str builtin_types[] = { - cstr("u8"), cstr("s8"), cstr("u16"), cstr("s16"), - cstr("u32"), cstr("s32"), cstr("u64"), cstr("s64"), - cstr("f32"), cstr("f64"), cstr("ptr"), cstr("int"), - cstr("uint"), cstr("str"), cstr("bool"), cstr("nil")}; - for (sz i = 0; i < LEN(builtin_types); i++) { - Str type = builtin_types[i]; - symmap_insert(&scope->symbols, type, - (Symbol){.name = type, .kind = SYM_BUILTIN_TYPE}, - a->storage); - } - Str numeric_types[] = { - cstr("u8"), cstr("s8"), cstr("u16"), cstr("s16"), cstr("u32"), - cstr("s32"), cstr("u64"), cstr("s64"), cstr("f32"), cstr("f64"), - cstr("ptr"), cstr("int"), cstr("uint"), - }; - for (sz i = 0; i < LEN(numeric_types); i++) { - Str type = numeric_types[i]; - strset_insert(&a->numeric_types, type, a->storage); - } - Str integer_types[] = { - cstr("u8"), cstr("s8"), cstr("u16"), cstr("s16"), - cstr("u32"), cstr("s32"), cstr("u64"), cstr("s64"), - cstr("ptr"), cstr("int"), cstr("uint"), - }; - for (sz i = 0; i < LEN(integer_types); i++) { - Str type = integer_types[i]; - strset_insert(&a->integer_types, type, a->storage); - } - // Find top level function declarations. - for (sz i = 0; i < array_size(parser->nodes); i++) { - Node *root = parser->nodes[i]; - if (root->kind == NODE_FUN) { - Str symbol = root->func_name->value.str; - if (symmap_lookup(&scope->symbols, symbol)) { - eprintln( - "%s:%d:%d: error: function '%s' already defined in " - "current " - "scope ", - a->file_name, root->var_name->line, root->var_name->col, - symbol); - a->err = true; - } - symmap_insert(&scope->symbols, symbol, - (Symbol){.name = symbol, .kind = SYM_FUN}, - a->storage); - } - } - // Recursively fill symbol tables. - for (sz i = 0; i < array_size(parser->nodes); i++) { - Node *root = parser->nodes[i]; - type_inference(a, root, scope); - } -} - void process_file(Str path) { #if DEBUG == 1 @@ -1396,43 +165,43 @@ process_file(Str path) { // TODO: Type checking. // Compile roots. - // Arena bytecode_arena = arena_create(LEXER_MEM, os_allocator); - // Chunk chunk = {.file_name = path, .storage = &bytecode_arena}; - // array_zero(chunk.constants, 256, &bytecode_arena); - // array_zero(chunk.code, 0xffff, &bytecode_arena); - // sz n_roots = array_size(parser.nodes); - // CompResult res; - // for (sz i = 0; i < n_roots; i++) { - // // The parser stores the root nodes as a stack. - // Node *root = parser.nodes[i]; - // res = compile_expr(&chunk, root); - // } - // sz res_reg = 0; - // switch (res.type) { - // case COMP_CONST: { - // res_reg = chunk.reg_idx++; - // Instruction inst = - // (Instruction){.op = OP_LD64K, .dst = res_reg, .a = res.idx}; - // array_push(chunk.code, inst, chunk.storage); - // } break; - // case COMP_REG: { - // res_reg = res.idx; - // } break; - // default: break; - // } - // // After we are done move the last result to r0 for printing. - // Instruction halt = (Instruction){.op = OP_HALT, .dst = res_reg}; - // array_push(chunk.code, halt, &bytecode_arena); + Arena bytecode_arena = arena_create(LEXER_MEM, os_allocator); + Chunk chunk = {.file_name = path, .storage = &bytecode_arena}; + array_zero(chunk.constants, 256, &bytecode_arena); + array_zero(chunk.code, 0xffff, &bytecode_arena); + sz n_roots = array_size(parser.nodes); + CompResult res; + for (sz i = 0; i < n_roots; i++) { + // The parser stores the root nodes as a stack. + Node *root = parser.nodes[i]; + res = compile_expr(&chunk, root); + } + sz res_reg = 0; + switch (res.type) { + case COMP_CONST: { + res_reg = chunk.reg_idx++; + Instruction inst = + (Instruction){.op = OP_LD64K, .dst = res_reg, .a = res.idx}; + array_push(chunk.code, inst, chunk.storage); + } break; + case COMP_REG: { + res_reg = res.idx; + } break; + default: break; + } + // After we are done move the last result to r0 for printing. + Instruction halt = (Instruction){.op = OP_HALT, .dst = res_reg}; + array_push(chunk.code, halt, &bytecode_arena); - // // Run bytecode on VM. - // VM vm = {0}; - // vm_init(&vm, &chunk); - // // println("VM REGISTERS BEFORE:\n%{Mem}", - // // &(Array){.mem = (u8 *)&vm.regs, sizeof(vm.regs)}); - // vm_run(&vm); + // Run bytecode on VM. + VM vm = {0}; + vm_init(&vm, &chunk); + // println("VM REGISTERS BEFORE:\n%{Mem}", + // &(Array){.mem = (u8 *)&vm.regs, sizeof(vm.regs)}); + vm_run(&vm); // println("VM REGISTERS AFTER:\n%{Mem}", // &(Array){.mem = (u8 *)&vm.regs, sizeof(vm.regs)}); - // disassemble_chunk(chunk); + disassemble_chunk(chunk); #if DEBUG == 1 println("Space used: %{Arena}", &lexer_arena); diff --git a/src/parser.c b/src/parser.c index cdd3c47..f7d0d41 100644 --- a/src/parser.c +++ b/src/parser.c @@ -450,7 +450,10 @@ parse_literal(Parser *parser) { #endif Node *node = NULL; switch (prev.kind) { - case TOK_TRUE: node = node_alloc(parser, NODE_TRUE, prev); break; + case TOK_TRUE: { + node = node_alloc(parser, NODE_TRUE, prev); + node->value.i = 1; + } break; case TOK_FALSE: node = node_alloc(parser, NODE_FALSE, prev); break; case TOK_NIL: node = node_alloc(parser, NODE_NIL, prev); break; default: return; // Unreachable. diff --git a/src/semantic.c b/src/semantic.c index fe88249..428cc53 100644 --- a/src/semantic.c +++ b/src/semantic.c @@ -1,544 +1,1234 @@ -#include "hashtable.h" +#include "badlib.h" -typedef struct Scope { - size_t id; - struct Scope *parent; - HashTable *symbols; - HashTable *types; -} Scope; +typedef enum { + SYM_UNKNOWN, + SYM_BUILTIN_FUN, + SYM_BUILTIN_TYPE, + SYM_FUN, + SYM_VAR, + SYM_PARAM, + SYM_ENUM, + SYM_ENUM_FIELD, + SYM_STRUCT, + SYM_STRUCT_FIELD, +} SymbolKind; -typedef struct ParseTree { - Root *roots; - Scope **scopes; -} ParseTree; - -typedef struct Type { - StringView name; - size_t size; // (bytes) -} Type; - -typedef enum DefaultType { - TYPE_VOID, - TYPE_BOOL, - TYPE_STR, - TYPE_U8, - TYPE_U16, - TYPE_U32, - TYPE_U64, - TYPE_S8, - TYPE_S16, - TYPE_S32, - TYPE_S64, - TYPE_F32, - TYPE_F64, -} DefaultType; - -static Type default_types[] = { - [TYPE_VOID] = {STRING("void"), 0}, - [TYPE_BOOL] = {STRING("bool"), 1}, - [TYPE_STR] = {STRING("str"), 16}, // size (8) + pointer to data (8). - [TYPE_U8] = {STRING("u8"), 1}, - [TYPE_U16] = {STRING("u16"), 2}, - [TYPE_U32] = {STRING("u32"), 4}, - [TYPE_U64] = {STRING("u64"), 8}, - [TYPE_S8] = {STRING("s8"), 1}, - [TYPE_S16] = {STRING("s16"), 2}, - [TYPE_S32] = {STRING("s32"), 4}, - [TYPE_S64] = {STRING("s64"), 8}, - [TYPE_F32] = {STRING("f32"), 4}, - [TYPE_F64] = {STRING("f64"), 8}, +Str sym_kind_str[] = { + [SYM_UNKNOWN] = cstr("UNKNOWN "), + [SYM_BUILTIN_FUN] = cstr("BUILTIN FUN "), + [SYM_BUILTIN_TYPE] = cstr("BUILTIN TYPE "), + [SYM_FUN] = cstr("FUNCTION "), + [SYM_VAR] = cstr("VARIABLE "), + [SYM_PARAM] = cstr("PARAMETER "), + [SYM_ENUM] = cstr("ENUM "), + [SYM_ENUM_FIELD] = cstr("ENUM FIELD "), + [SYM_STRUCT] = cstr("STRUCT "), + [SYM_STRUCT_FIELD] = cstr("STRUCT FIELD "), }; -typedef enum SymbolType { - SYMBOL_VAR, - SYMBOL_PAR, - SYMBOL_FUN, -} SymbolType; - typedef struct Symbol { - Node *name; - SymbolType type; - - union { - struct { - Node *type; - } var; - - struct { - Node **param_types; - Node *return_type; - } fun; - }; + Str name; + SymbolKind kind; } Symbol; -static size_t scope_gen_id = 0; +typedef struct Fun { + Str name; + Str param_type; + Str return_type; +} Fun; -Symbol * -alloc_symval(Node *name, SymbolType type) { - Symbol *val = malloc(sizeof(Symbol)); - val->name = name; - val->type = type; - return val; -} +typedef struct Enum { + Str name; + Node *val; +} Enum; -u64 sym_hash(const struct HashTable *table, void *bytes) { - Node *symbol = bytes; - u64 hash = _xor_shift_hash(symbol->string.start, symbol->string.n); - hash = _fibonacci_hash(hash, table->shift_amount); - return hash; -} +typedef struct Struct { + Str name; + Str type; + Node *val; +} Struct; -bool sym_eq(void *a, void *b) { - Node *a_node = a; - Node *b_node = b; - assert(a_node->type == NODE_SYMBOL); - assert(b_node->type == NODE_SYMBOL); - return sv_equal(&a_node->string, &b_node->string); -} +MAPDEF(SymbolMap, symmap, Str, Symbol, str_hash, str_eq) +MAPDEF(FunMap, funmap, Str, Fun, str_hash, str_eq) +MAPDEF(EnumMap, enummap, Str, Enum, str_hash, str_eq) +MAPDEF(StructMap, structmap, Str, Struct, str_hash, str_eq) -u64 type_hash(const struct HashTable *table, void *bytes) { - StringView *type = bytes; - u64 hash = _xor_shift_hash(type->start, type->n); - hash = _fibonacci_hash(hash, table->shift_amount); - return hash; -} +typedef struct Scope { + sz id; + sz depth; + Str name; + SymbolMap *symbols; + FunMap *funcs; + EnumMap *enums; + StructMap *structs; + struct Scope *parent; +} Scope; -bool type_eq(void *a, void *b) { - StringView *a_type = a; - StringView *b_type = b; - return sv_equal(a_type, b_type); -} +typedef struct Analyzer { + Arena *storage; + Str file_name; + sz typescope_gen; + Scope **scopes; + StrSet *numeric_types; + StrSet *integer_types; + bool err; +} Analyzer; Scope * -alloc_scope(Scope *parent) { - Scope *scope = malloc(sizeof(Scope)); - scope->id = scope_gen_id++; +typescope_alloc(Analyzer *a, Scope *parent) { + Scope *scope = arena_calloc(sizeof(Scope), a->storage); scope->parent = parent; - scope->symbols = ht_init(sym_hash, sym_eq); - scope->types = ht_init(type_hash, type_eq); + scope->id = a->typescope_gen++; + scope->depth = parent == NULL ? 0 : parent->depth + 1; + array_push(a->scopes, scope, a->storage); return scope; } -Type * -find_type(Scope *scope, Node *type) { - // TODO: Normally default types will be used more often. Since we don't - // allow type shadowing, we should search first on the global scope. +SymbolMap * +find_type(Scope *scope, Str type) { while (scope != NULL) { - Type *ret = ht_lookup(scope->types, &type->string); - if (ret != NULL) { - return ret; + SymbolMap *val = symmap_lookup(&scope->symbols, type); + if (val != NULL) { + return val; } scope = scope->parent; } - push_error(ERR_TYPE_SEMANTIC, ERR_UNKNOWN_TYPE, type->line, type->col); return NULL; } -bool -insert_symbol(Scope *scope, Node *symbol, Symbol *val) { - // Check if symbol already exists. - HashTable *symbols = scope->symbols; - if (ht_lookup(symbols, symbol) != NULL) { - push_error(ERR_TYPE_SEMANTIC, ERR_SYMBOL_REDEF, symbol->line, symbol->col); - return false; +FunMap * +find_fun(Scope *scope, Str type) { + while (scope != NULL) { + FunMap *val = funmap_lookup(&scope->funcs, type); + if (val != NULL) { + return val; + } + scope = scope->parent; } - ht_insert(symbols, symbol, val); - return true; + return NULL; } -Type * -coerce_numeric_types(Type *a, Type *b) { - // TODO: Decide what to do with mixed numeric types. What are the promotion - // rules, etc. - if (a == &default_types[TYPE_U8]) { - if (b == &default_types[TYPE_U16] || - b == &default_types[TYPE_U32] || - b == &default_types[TYPE_U64]) { - return b; - } - } else if (a == &default_types[TYPE_U16]) { - if (b == &default_types[TYPE_U32] || - b == &default_types[TYPE_U64]) { - return b; - } - } else if (a == &default_types[TYPE_U32]) { - if (b == &default_types[TYPE_U64]) { - return b; +typedef struct FindEnumResult { + EnumMap *map; + Scope *scope; +} FindEnumResult; + +FindEnumResult +find_enum(Scope *scope, Str type) { + while (scope != NULL) { + EnumMap *val = enummap_lookup(&scope->enums, type); + if (val != NULL) { + return (FindEnumResult){.map = val, .scope = scope}; } - } else if (a == &default_types[TYPE_S8]) { - if (b == &default_types[TYPE_S16] || - b == &default_types[TYPE_S32] || - b == &default_types[TYPE_S64]) { - return b; + scope = scope->parent; + } + return (FindEnumResult){0}; +} + +typedef struct FindStructResult { + StructMap *map; + Scope *scope; +} FindStructResult; + +FindStructResult +find_struct(Scope *scope, Str type) { + while (scope != NULL) { + StructMap *val = structmap_lookup(&scope->structs, type); + if (val != NULL) { + return (FindStructResult){.map = val, .scope = scope}; } - } else if (a == &default_types[TYPE_S16]) { - if (b == &default_types[TYPE_S32] || - b == &default_types[TYPE_S64]) { - return b; + scope = scope->parent; + } + return (FindStructResult){0}; +} + +void +graph_typescope(Scope *scope, Arena a) { + if (!scope->symbols) { + return; + } + SymbolMapIter iter = symmap_iterator(scope->symbols, &a); + SymbolMap *type = symmap_next(&iter, &a); + print( + "%d[shape=\"none\" label=<", + scope->id); + print( + "" + "" + "" + ""); + while (type) { + print( + "" + "" + "" + "", + type->key, type->val.name); + type = symmap_next(&iter, &a); + } + println("
NAME TYPE
%s %s
>];"); + + sz this_id = scope->id; + while (scope->parent) { + if (scope->parent->symbols) { + println("%d:e->%d:w;", this_id, scope->parent->id); + break; + } else { + scope = scope->parent; } - } else if (a == &default_types[TYPE_S32]) { - if (b == &default_types[TYPE_S64]) { - return b; + } +} + +void +graph_functions(Scope *scope, Arena a) { + if (!scope->funcs) { + return; + } + FunMapIter iter = funmap_iterator(scope->funcs, &a); + FunMap *func = funmap_next(&iter, &a); + print( + "fun_%d[shape=\"none\" label=<", + scope->id); + print( + "" + "" + "" + "" + ""); + while (func) { + print( + "" + "" + "" + "" + "", + func->val.name, func->val.name, func->val.param_type, + func->val.return_type); + func = funmap_next(&iter, &a); + } + println("
NAME PARAMS RETURN
%s %s %s
>];"); + sz this_id = scope->id; + while (scope->parent) { + if (scope->parent->symbols) { + println("fun_%d:e->fun_%d:%s:w;", this_id, scope->parent->id, + scope->name); + break; + } else { + scope = scope->parent; } - } else if (a == &default_types[TYPE_F32]) { - if (b == &default_types[TYPE_F64]) { - return b; + } +} + +void +graph_types(Scope **scopes, Arena a) { + if (scopes == NULL) return; + println("digraph types {"); + println("rankdir=LR;"); + println("ranksep=\"0.95 equally\";"); + println("nodesep=\"0.5 equally\";"); + println("overlap=scale;"); + println("bgcolor=\"transparent\";"); + for (sz i = 0; i < array_size(scopes); i++) { + Scope *scope = scopes[i]; + if (!scope) { + continue; } + println("subgraph %d {", i); + graph_typescope(scope, a); + graph_functions(scope, a); + println("}"); } - return a; + println("}"); +} + +void +emit_semantic_error(Analyzer *a, Node *n, Str msg) { + eprintln("%s:%d:%d: error: %s", a->file_name, n->line, n->col, msg); + a->err = true; } -bool -type_is_numeric(Type *t) { - if (t == &default_types[TYPE_U8] || - t == &default_types[TYPE_U16] || - t == &default_types[TYPE_U32] || - t == &default_types[TYPE_U64] || - t == &default_types[TYPE_S8] || - t == &default_types[TYPE_S16] || - t == &default_types[TYPE_S32] || - t == &default_types[TYPE_S64] || - t == &default_types[TYPE_F32] || - t == &default_types[TYPE_F64]) { - return true; +Str type_inference(Analyzer *a, Node *node, Scope *scope); + +void +typecheck_field(Analyzer *a, Node *node, Scope *scope, Str symbol) { + if (node->field_type->kind == NODE_COMPOUND_TYPE) { + Str field_name = str_concat(symbol, cstr("."), a->storage); + field_name = str_concat(field_name, node->value.str, a->storage); + if (structmap_lookup(&scope->structs, field_name)) { + eprintln("%s:%d:%d: error: struct field '%s' already exists", + a->file_name, node->line, node->col, field_name); + a->err = true; + } + Str type = cstr("\\{ "); + for (sz i = 0; i < array_size(node->field_type->elements); i++) { + Node *field = node->field_type->elements[i]; + typecheck_field(a, field, scope, field_name); + type = str_concat(type, field->type, a->storage); + type = str_concat(type, cstr(" "), a->storage); + } + type = str_concat(type, cstr("\\}"), a->storage); + node->type = type; + } else { + Str field_name = str_concat(symbol, cstr("."), a->storage); + field_name = str_concat(field_name, node->value.str, a->storage); + Str field_type = node->field_type->value.str; + if (!find_type(scope, field_type)) { + eprintln("%s:%d:%d: error: unknown type '%s'", a->file_name, + node->field_type->line, node->field_type->col, field_type); + a->err = true; + } + if (node->field_type->is_ptr) { + field_type = str_concat(cstr("@"), field_type, a->storage); + } + if (node->field_type->kind == NODE_ARR_TYPE) { + field_type = str_concat(cstr("@"), field_type, a->storage); + } + if (structmap_lookup(&scope->structs, field_name)) { + eprintln("%s:%d:%d: error: struct field '%s' already exists", + a->file_name, node->line, node->col, field_name); + a->err = true; + } + if (node->field_val) { + Str type = type_inference(a, node->field_val, scope); + if (!str_eq(type, field_type)) { + eprintln( + "%s:%d:%d: error: mismatched types in struct " + "value " + "for '%s': %s expected %s", + a->file_name, node->line, node->col, field_name, type, + field_type); + a->err = true; + } + } + structmap_insert(&scope->structs, field_name, + (Struct){ + .name = field_name, + .type = field_type, + .val = node->field_val, + }, + a->storage); + symmap_insert(&scope->symbols, field_name, + (Symbol){.name = field_type, .kind = SYM_STRUCT_FIELD}, + a->storage); + node->type = field_type; } - return false; } -Symbol * -find_symbol(Scope *scope, Node *node) { - while (scope != NULL) { - Symbol *val = ht_lookup(scope->symbols, node); - if (val != NULL) { - return val; +void +typecheck_lit_field(Analyzer *a, Node *node, Scope *scope, Str symbol) { + if (node->field_val->kind == NODE_COMPOUND_TYPE) { + Str type = cstr("\\{ "); + for (sz i = 0; i < array_size(node->field_val->elements); i++) { + Node *field = node->field_val->elements[i]; + Str field_name = str_concat(symbol, cstr("."), a->storage); + field_name = str_concat(field_name, field->value.str, a->storage); + typecheck_lit_field(a, field, scope, field_name); + type = str_concat(type, field->type, a->storage); + type = str_concat(type, cstr(" "), a->storage); } - scope = scope->parent; + type = str_concat(type, cstr("\\}"), a->storage); + node->type = type; + } else { + StructMap *s = structmap_lookup(&scope->structs, symbol); + if (!s) { + eprintln("%s:%d:%d: error: unknown struct field '%s'", a->file_name, + node->line, node->col, symbol); + a->err = true; + return; + } + Str field_type = s->val.type; + Str type = type_inference(a, node->field_val, scope); + if (!str_eq(type, field_type)) { + eprintln( + "%s:%d:%d: error: mismatched types in struct " + "value " + "for '%s': %s expected %s", + a->file_name, node->line, node->col, symbol, type, field_type); + a->err = true; + } + node->type = field_type; } - push_error(ERR_TYPE_SEMANTIC, ERR_UNKNOWN_SYMBOL, node->line, node->col); - return NULL; } -bool -resolve_type(ParseTree *ast, Scope *scope, Node *node) { - if (node->expr_type != NULL) { - return true; +void +typecheck_returns(Analyzer *a, Node *node, Str expected) { + if (!node) { + return; } - switch (node->type) { - case NODE_BUILTIN: { - for (size_t i = 0; i < array_size(node->builtin.args); ++i) { - Node *arg = node->builtin.args[i]; - if (!resolve_type(ast, scope, arg)) { - return false; - } + // Traverse the tree again. + switch (node->kind) { + case NODE_COND: + case NODE_MATCH: { + for (sz i = 0; i < array_size(node->match_cases); i++) { + Node *next = node->match_cases[i]; + typecheck_returns(a, next, expected); } - switch (node->builtin.type) { - // Numbers. - case TOKEN_ADD: - case TOKEN_SUB: - case TOKEN_MUL: - case TOKEN_DIV: - case TOKEN_MOD: { - Type *type = NULL; - for (size_t i = 0; i < array_size(node->builtin.args); ++i) { - Node *arg = node->builtin.args[i]; - - // Check that all arguments are numbers. - if (!type_is_numeric(arg->expr_type)) { - push_error(ERR_TYPE_SEMANTIC, ERR_WRONG_TYPE_NUM, - arg->line, arg->col); - return false; - } + } break; + case NODE_RETURN: { + bool err = !str_eq(node->type, expected); + if (err) { + eprintln( + "%s:%d:%d: error: mismatched return type %s, expected %s", + a->file_name, node->line, node->col, node->type, expected); + a->err = true; + } + } break; + case NODE_BLOCK: { + for (sz i = 0; i < array_size(node->elements); i++) { + Node *next = node->elements[i]; + typecheck_returns(a, next, expected); + } + } break; + case NODE_IF: { + if (node->cond_expr) { + typecheck_returns(a, node->cond_expr, expected); + } + if (node->cond_else) { + typecheck_returns(a, node->cond_else, expected); + } + } break; + case NODE_SET: + case NODE_LET: { + if (node->var_val) { + typecheck_returns(a, node->var_val, expected); + } + } break; + case NODE_ADD: + case NODE_SUB: + case NODE_DIV: + case NODE_MUL: + case NODE_MOD: + case NODE_NOT: + case NODE_AND: + case NODE_OR: + case NODE_EQ: + case NODE_NEQ: + case NODE_LT: + case NODE_GT: + case NODE_LE: + case NODE_GE: + case NODE_BITNOT: + case NODE_BITAND: + case NODE_BITOR: + case NODE_BITLSHIFT: + case NODE_BITRSHIFT: { + if (node->left) { + typecheck_returns(a, node->left, expected); + } + if (node->right) { + typecheck_returns(a, node->right, expected); + } + } break; + default: break; + } +} - if (type == NULL) { - type = arg->expr_type; - } else if (type != arg->expr_type) { - type = coerce_numeric_types(type, arg->expr_type); - } +Str +type_inference(Analyzer *a, Node *node, Scope *scope) { + assert(a); + assert(scope); + if (!node) { + return cstr(""); + } + // NOTE: For now we are not going to do implicit numeric conversions. + switch (node->kind) { + case NODE_LET: { + node->type = cstr("nil"); + Str symbol = node->var_name->value.str; + if (symmap_lookup(&scope->symbols, symbol)) { + eprintln( + "%s:%d:%d: error: symbol '%s' already exists in current " + "scope ", + a->file_name, node->var_name->line, node->var_name->col, + symbol); + a->err = true; + return cstr(""); + } + if (node->var_type) { + Str type_name = node->var_type->value.str; + SymbolMap *type = find_type(scope, type_name); + if (type == NULL) { + eprintln("%s:%d:%d: error: unknown type '%s'", a->file_name, + node->var_type->line, node->var_type->col, + type_name); + a->err = true; + return cstr(""); + } + if (node->var_type->is_ptr) { + type_name = str_concat(cstr("@"), type_name, a->storage); + } + if (node->var_type->kind == NODE_ARR_TYPE) { + type_name = str_concat(cstr("@"), type_name, a->storage); + // TODO: typecheck size + // TODO: register array in scope + } + if (node->var_val) { + Str type = type_inference(a, node->var_val, scope); + if (!type.size) { + eprintln( + "%s:%d:%d: error: can't bind `nil` to variable " + "'%s'", + a->file_name, node->var_type->line, + node->var_type->col, symbol); + a->err = true; + return cstr(""); } - node->expr_type = type; - } break; - // Bools. - case TOKEN_NOT: - // TODO: not should only take one argument and - // return the inverse. - case TOKEN_AND: - case TOKEN_OR: { - // Check that all arguments are boolean. - for (size_t i = 0; i < array_size(node->builtin.args); ++i) { - Node *arg = node->builtin.args[i]; - if (arg->expr_type != &default_types[TYPE_BOOL]) { - push_error(ERR_TYPE_SEMANTIC, ERR_WRONG_TYPE_BOOL, - arg->line, arg->col); - return false; + // TODO: Consider compatible types. + if (!str_eq(type, type_name)) { + // Special case, enums can be treated as ints. + FindEnumResult res = find_enum(scope, type_name); + if (!(res.map && str_eq(type, cstr("int")))) { + eprintln( + "%s:%d:%d: error: type mismatch, trying to " + "assing " + "%s" + " to a variable of type %s", + a->file_name, node->var_type->line, + node->var_type->col, type, type_name); + a->err = true; + return cstr(""); } } - node->expr_type = &default_types[TYPE_BOOL]; - } break; - case TOKEN_EQ: - case TOKEN_LT: - case TOKEN_GT: - case TOKEN_LE: - case TOKEN_GE: { - // Check that all arguments are nums. - for (size_t i = 0; i < array_size(node->builtin.args); ++i) { - Node *arg = node->builtin.args[i]; - // TODO: Make sure all arguments have the same type, - // like with numeric expressions. - if (!type_is_numeric(arg->expr_type)) { - push_error(ERR_TYPE_SEMANTIC, ERR_WRONG_TYPE_NUM, - arg->line, arg->col); - return false; - } + } + symmap_insert(&scope->symbols, symbol, + (Symbol){ + .name = type_name, + .kind = SYM_VAR, + }, + a->storage); + return node->type; + } + + // We don't know the type for this symbol, perform inference. + Str type = type_inference(a, node->var_val, scope); + if (type.size) { + symmap_insert(&scope->symbols, symbol, + (Symbol){.name = type, .kind = SYM_VAR}, + a->storage); + node->var_name->type = type; + } + return node->type; + } break; + case NODE_SET: { + Str name = type_inference(a, node->var_name, scope); + Str val = type_inference(a, node->var_val, scope); + if (!str_eq(name, val)) { + eprintln( + "%s:%d:%d: error: type mismatch, trying to assing " + "%s" + " to a variable of type %s", + a->file_name, node->line, node->col, val, name); + a->err = true; + return cstr(""); + } + node->type = cstr("nil"); + return node->type; + } break; + case NODE_STRUCT: { + node->type = cstr("nil"); + Str symbol = node->value.str; + if (symmap_lookup(&scope->symbols, symbol) != NULL) { + eprintln( + "%s:%d:%d: error: struct '%s' already exists in current " + "scope", + a->file_name, node->line, node->col, symbol); + a->err = true; + return cstr(""); + } + structmap_insert(&scope->structs, symbol, (Struct){.name = symbol}, + a->storage); + for (sz i = 0; i < array_size(node->struct_field); i++) { + Node *field = node->struct_field[i]; + typecheck_field(a, field, scope, symbol); + } + symmap_insert(&scope->symbols, symbol, + (Symbol){.name = symbol, .kind = SYM_STRUCT}, + a->storage); + return node->type; + } break; + case NODE_ENUM: { + node->type = cstr("nil"); + Str symbol = node->value.str; + if (symmap_lookup(&scope->symbols, symbol) != NULL) { + eprintln( + "%s:%d:%d: error: enum '%s' already exists in current " + "scope", + a->file_name, node->line, node->col, symbol); + a->err = true; + return cstr(""); + } + enummap_insert(&scope->enums, symbol, + (Enum){ + .name = symbol, + .val = node->field_val, + }, + a->storage); + for (sz i = 0; i < array_size(node->struct_field); i++) { + Node *field = node->struct_field[i]; + Str field_name = str_concat(symbol, cstr("."), a->storage); + field_name = + str_concat(field_name, field->value.str, a->storage); + if (enummap_lookup(&scope->enums, field_name)) { + eprintln("%s:%d:%d: error: enum field '%s' already exists", + a->file_name, field->line, field->col, field_name); + a->err = true; + } + if (field->field_val) { + Str type = type_inference(a, field->field_val, scope); + if (!str_eq(type, cstr("int"))) { + eprintln( + "%s:%d:%d: error: non int enum value for '%s.%s'", + a->file_name, field->line, field->col, symbol, + field_name); + a->err = true; } - node->expr_type = &default_types[TYPE_BOOL]; - } break; - default: break; + } + enummap_insert(&scope->enums, field_name, + (Enum){.name = field_name}, a->storage); + symmap_insert( + &scope->symbols, field_name, + (Symbol){.name = field_name, .kind = SYM_ENUM_FIELD}, + a->storage); + field->type = symbol; } + symmap_insert(&scope->symbols, symbol, + (Symbol){.name = symbol, .kind = SYM_ENUM}, + a->storage); + return node->type; } break; - case NODE_SYMBOL: { - Symbol *val = find_symbol(scope, node); - if (val == NULL) { - return false; + case NODE_IF: { + Str cond_type = type_inference(a, node->cond_if, scope); + if (!str_eq(cond_type, cstr("bool"))) { + emit_semantic_error( + a, node->cond_if, + cstr("non boolean expression on if condition")); + return cstr(""); } - - Type *type = NULL; - switch (val->type) { - case SYMBOL_PAR: - case SYMBOL_VAR: { - type = find_type(scope, val->var.type); - } break; - case SYMBOL_FUN: { - type = find_type(scope, val->fun.return_type); - } break; + if (node->cond_expr->kind == NODE_BLOCK) { + node->type = type_inference(a, node->cond_expr, scope); + } else { + Scope *next = typescope_alloc(a, scope); + node->type = type_inference(a, node->cond_expr, next); } - if (type == NULL) { - return false; + if (node->cond_else) { + Str else_type; + if (node->cond_else->kind == NODE_BLOCK) { + else_type = type_inference(a, node->cond_else, scope); + } else { + Scope *next = typescope_alloc(a, scope); + else_type = type_inference(a, node->cond_else, next); + } + if (!str_eq(node->type, else_type)) { + emit_semantic_error( + a, node, cstr("mismatch types for if/else branches")); + return cstr(""); + } } - node->expr_type = type; + return node->type; } break; - case NODE_FUN: { - // TODO: don't allow parameters of type void - node->expr_type = &default_types[TYPE_VOID]; - - // Fill up new scope with parameters - scope = alloc_scope(scope); - array_push(ast->scopes, scope); - - // Parameters. - for (size_t i = 0; i < array_size(node->fun.param_names); ++i) { - Node *param = node->fun.param_names[i]; - Node *type = node->fun.param_types[i]; - Symbol *var = alloc_symval(param, SYMBOL_PAR); - var->var.type = type; - if (!insert_symbol(scope, param, var)) { - return false; + case NODE_WHILE: { + Str cond_type = type_inference(a, node->while_cond, scope); + if (!str_eq(cond_type, cstr("bool"))) { + emit_semantic_error( + a, node->cond_if, + cstr("non boolean expression on while condition")); + return cstr(""); + } + if (node->while_expr->kind != NODE_BLOCK) { + scope = typescope_alloc(a, scope); + } + type_inference(a, node->while_expr, scope); + node->type = cstr("nil"); + return node->type; + } break; + case NODE_COND: { + Str previous = cstr(""); + for (sz i = 0; i < array_size(node->match_cases); i++) { + Node *expr = node->match_cases[i]; + Str next = type_inference(a, expr, scope); + if (i != 0 && !str_eq(next, previous)) { + emit_semantic_error( + a, node, + cstr("non-matching types for cond expressions")); + return cstr(""); } + previous = next; } - - // Body. - Node *body = node->fun.body; - if (body->type == NODE_BLOCK) { - for (size_t i = 0; i < array_size(body->block.expr); ++i) { - Node *expr = body->block.expr[i]; - if (!resolve_type(ast, scope, expr)) { - return false; + node->type = previous; + return node->type; + } break; + case NODE_MATCH: { + Str e = type_inference(a, node->match_expr, scope); + if (str_eq(e, cstr("int"))) { + // Integer matching. + for (sz i = 0; i < array_size(node->match_cases); i++) { + Node *field = node->match_cases[i]; + if (field->case_value) { + if (field->case_value->kind != NODE_NUM_INT && + field->case_value->kind != NODE_NUM_UINT) { + emit_semantic_error( + a, field->case_value, + cstr( + "non-integer or enum types on match case")); + } } } - Node *last_expr = body->block.expr[array_size(body->block.expr) - 1]; - node->fun.body->expr_type = last_expr->expr_type; } else { - if (!resolve_type(ast, scope, body)) { - return false; + // Get enum type and de-structure the match. + FindEnumResult res = find_enum(scope, e); + Str enum_prefix = + str_concat(res.map->val.name, cstr("."), a->storage); + for (sz i = 0; i < array_size(node->match_cases); i++) { + Node *field = node->match_cases[i]; + if (field->case_value) { + Str field_name = str_concat( + enum_prefix, field->case_value->value.str, + a->storage); + if (!enummap_lookup(&res.scope->enums, field_name)) { + eprintln("%s:%d:%d: error: unknown enum field '%s'", + a->file_name, field->case_value->line, + field->case_value->col, field_name); + a->err = true; + } + } } } - - // Check that the type of body matches the return type. - StringView *type_body = &node->fun.body->expr_type->name; - StringView *return_type = &node->fun.return_type->string; - if (!sv_equal(type_body, return_type)) { - push_error(ERR_TYPE_SEMANTIC, ERR_WRONG_RET_TYPE, node->line, node->col); - return false; + Str previous = cstr(""); + for (sz i = 0; i < array_size(node->match_cases); i++) { + Node *expr = node->match_cases[i]; + Str next = type_inference(a, expr, scope); + if (i != 0 && !str_eq(next, previous)) { + emit_semantic_error( + a, node, + cstr("non-matching types for match expressions")); + return cstr(""); + } + previous = next; } + node->type = previous; + return node->type; } break; - case NODE_BLOCK: { - scope = alloc_scope(scope); - array_push(ast->scopes, scope); - for (size_t i = 0; i < array_size(node->block.expr); ++i) { - Node *expr = node->block.expr[i]; - if (!resolve_type(ast, scope, expr)) { - return false; + case NODE_CASE_MATCH: { + if (node->case_expr->kind != NODE_BLOCK) { + scope = typescope_alloc(a, scope); + } + node->type = type_inference(a, node->case_expr, scope); + return node->type; + } break; + case NODE_CASE_COND: { + if (node->case_value) { + Str cond = type_inference(a, node->case_value, scope); + if (!str_eq(cond, cstr("bool"))) { + emit_semantic_error(a, node, + cstr("non-boolean case condition")); } } - Node *last_expr = node->block.expr[array_size(node->block.expr) - 1]; - node->expr_type = last_expr->expr_type; + if (node->case_expr->kind != NODE_BLOCK) { + scope = typescope_alloc(a, scope); + } + node->type = type_inference(a, node->case_expr, scope); + return node->type; } break; - case NODE_IF: { - // TODO: If we don't have an else, ifexpr.expr_true - // must be void for consistency. - if (!resolve_type(ast, scope, node->ifexpr.cond)) { - return false; - } - if (!resolve_type(ast, scope, node->ifexpr.expr_true)) { - return false; - } - Type *type_true = node->ifexpr.expr_true->expr_type; - node->expr_type = type_true; - if (node->ifexpr.expr_false != NULL) { - if (!resolve_type(ast, scope, node->ifexpr.expr_false)) { - return false; + case NODE_TRUE: + case NODE_FALSE: { + node->type = cstr("bool"); + return node->type; + } break; + case NODE_NIL: { + node->type = cstr("nil"); + return node->type; + } break; + case NODE_NOT: + case NODE_AND: + case NODE_OR: { + Str left = type_inference(a, node->left, scope); + if (!str_eq(left, cstr("bool"))) { + emit_semantic_error(a, node, + cstr("expected bool on logic expression")); + return cstr(""); + } + if (node->right) { + Str right = type_inference(a, node->right, scope); + if (!str_eq(right, cstr("bool"))) { + emit_semantic_error( + a, node, cstr("expected bool on logic expression")); + return cstr(""); + } + } + node->type = cstr("bool"); + return node->type; + } break; + case NODE_EQ: + case NODE_NEQ: + case NODE_LT: + case NODE_GT: + case NODE_LE: + case NODE_GE: { + Str left = type_inference(a, node->left, scope); + Str right = type_inference(a, node->right, scope); + if (!str_eq(left, right)) { + emit_semantic_error( + a, node, cstr("mismatched types on binary expression")); + return cstr(""); + } + node->type = cstr("bool"); + return node->type; + } break; + case NODE_BITNOT: { + Str left = type_inference(a, node->left, scope); + if (!strset_lookup(&a->integer_types, left)) { + emit_semantic_error( + a, node, cstr("non integer type on bit twiddling expr")); + return cstr(""); + } + node->type = left; + return node->type; + } break; + case NODE_BITAND: + case NODE_BITOR: + case NODE_BITLSHIFT: + case NODE_BITRSHIFT: { + Str left = type_inference(a, node->left, scope); + Str right = type_inference(a, node->right, scope); + if (!strset_lookup(&a->integer_types, left) || + !strset_lookup(&a->integer_types, right)) { + emit_semantic_error( + a, node, cstr("non integer type on bit twiddling expr")); + return cstr(""); + } + node->type = left; + return node->type; + } break; + case NODE_ADD: + case NODE_SUB: + case NODE_DIV: + case NODE_MUL: + case NODE_MOD: { + Str left = type_inference(a, node->left, scope); + Str right = type_inference(a, node->right, scope); + if (!strset_lookup(&a->numeric_types, left) || + !strset_lookup(&a->numeric_types, right)) { + emit_semantic_error( + a, node, cstr("non numeric type on arithmetic expr")); + return cstr(""); + } + if (!str_eq(left, right)) { + emit_semantic_error( + a, node, cstr("mismatched types on binary expression")); + return cstr(""); + } + node->type = left; + return node->type; + } break; + case NODE_NUM_UINT: { + node->type = cstr("uint"); + return node->type; + } break; + case NODE_NUM_INT: { + node->type = cstr("int"); + return node->type; + } break; + case NODE_NUM_FLOAT: { + node->type = cstr("f64"); + return node->type; + } break; + case NODE_STRING: { + node->type = cstr("str"); + return node->type; + } break; + case NODE_ARR_TYPE: + case NODE_TYPE: { + SymbolMap *type = find_type(scope, node->value.str); + if (!type) { + emit_semantic_error(a, node, cstr("unknown type")); + return cstr(""); + } + node->type = type->val.name; + return node->type; + } break; + case NODE_SYMBOL_IDX: + case NODE_SYMBOL: { + Str symbol = node->value.str; + SymbolMap *type = find_type(scope, symbol); + if (!type) { + eprintln("%s:%d:%d: error: couldn't resolve symbol '%s'", + a->file_name, node->line, node->col, symbol); + a->err = true; + return cstr(""); + } + Str type_name = type->val.name; + if (node->kind == NODE_SYMBOL_IDX) { + Str idx_type = type_inference(a, node->arr_size, scope); + if (!strset_lookup(&a->integer_types, idx_type)) { + emit_semantic_error( + a, node, cstr("can't resolve non integer index")); + return cstr(""); } + type_name = str_remove_prefix(type_name, cstr("@")); + } + if (node->is_ptr) { + type_name = str_concat(cstr("@"), type_name, a->storage); } - // Check ifexpr.cond is a bool. - Type *type_cond = node->ifexpr.cond->expr_type; - if (!sv_equal(&type_cond->name, &default_types[TYPE_BOOL].name)) { - push_error(ERR_TYPE_SEMANTIC, ERR_WRONG_COND_TYPE, - node->line, node->col); - return false; + FindEnumResult e = find_enum(scope, type_name); + if (e.map && str_eq(symbol, type_name)) { + if (!node->next) { + eprintln( + "%s:%d:%d: error: unspecified enum field for symbol " + "'%s'", + a->file_name, node->line, node->col, symbol); + a->err = true; + return cstr(""); + } + // Check if there is a next and it matches the enum field. + Str field = str_concat(type_name, cstr("."), a->storage); + field = str_concat(field, node->next->value.str, a->storage); + if (!enummap_lookup(&e.scope->enums, field)) { + eprintln( + "%s:%d:%d: error: unknown enum field for " + "'%s': %s", + a->file_name, node->line, node->col, symbol, + node->next->value.str); + a->err = true; + return cstr(""); + } + node->next->type = type_name; + node->type = type_name; + return node->next->type; } - // Check if types of expr_true and expr_false match - if (node->ifexpr.expr_false != NULL) { - Type *type_false = node->ifexpr.expr_false->expr_type; - if (type_is_numeric(type_true) && type_is_numeric(type_false)) { - node->expr_type = coerce_numeric_types(type_true, type_false); - } else if (!sv_equal(&type_true->name, &type_false->name)) { - push_error(ERR_TYPE_SEMANTIC, ERR_WRONG_TYPE_T_F, - node->line, node->col); - return false; + FindStructResult s = find_struct(scope, type_name); + if (s.map) { + if (str_eq(symbol, type_name)) { + eprintln( + "%s:%d:%d: error: struct incomplete struct literal " + "'%s', did you mean to use %s:{}?", + a->file_name, node->line, node->col, symbol, symbol); + a->err = true; + return cstr(""); + } else { + if (node->next) { + Str chain = type_name; + Node *next = node; + while (next->next) { + next = next->next; + chain = str_concat(chain, cstr("."), a->storage); + chain = + str_concat(chain, next->value.str, a->storage); + } + StructMap *field = + structmap_lookup(&s.scope->structs, chain); + if (!field) { + eprintln( + "%s:%d:%d: error: unknown struct field '%s'", + a->file_name, node->line, node->col, chain); + a->err = true; + return cstr(""); + } + Str field_type = field->val.type; + if (next->kind == NODE_SYMBOL_IDX) { + Str idx_type = + type_inference(a, next->arr_size, scope); + if (!strset_lookup(&a->integer_types, idx_type)) { + emit_semantic_error( + a, next, + cstr("can't resolve non integer index")); + return cstr(""); + } + field_type = + str_remove_prefix(field_type, cstr("@")); + } + node->type = field_type; + return node->type; + } } } + node->type = type_name; + return node->type; } break; - case NODE_SET: { - node->expr_type = &default_types[TYPE_VOID]; - if (!resolve_type(ast, scope, node->set.symbol)) { - return false; - } - if (!resolve_type(ast, scope, node->set.value)) { - return false; - } - Node *symbol = node->set.symbol; - Node *value = node->set.value; - if (!sv_equal(&symbol->expr_type->name, &value->expr_type->name)) { - push_error(ERR_TYPE_SEMANTIC, ERR_TYPE_MISMATCH, - node->line, node->col); - return false; - } - } break; - case NODE_DEF: { - // TODO: don't allow assignment of expressions that return void - // Prepare value for symbol table. - Symbol *var = alloc_symval(node->def.symbol, SYMBOL_VAR); - var->var.type = node->def.type; - if (!insert_symbol(scope, node->def.symbol, var)) { - return false; - } - - Type *type = find_type(scope, node->def.type); - if (type == NULL) { - return false; - } - node->def.symbol->expr_type = type; - - node->expr_type = &default_types[TYPE_VOID]; - // TODO: type inference from right side when not annotated? - if (!resolve_type(ast, scope, node->def.value)) { - return false; - } - Node *symbol = node->def.symbol; - Node *value = node->def.value; - if (!sv_equal(&symbol->expr_type->name, &value->expr_type->name)) { - push_error(ERR_TYPE_SEMANTIC, ERR_TYPE_MISMATCH, - node->line, node->col); - return false; - } - } break; - case NODE_NUMBER: { - // TODO: Numbers are f64/s64 unless explicitely annotated. Annotated - // numbers must fit in the given range (e.g. no negative constants - // inside a u64, no numbers bigger than 255 in a u8, etc.). - if (node->number.fractional != 0) { - // TODO: 1.0 should also be checked as float. - node->expr_type = &default_types[TYPE_F64]; - } else { - node->expr_type = &default_types[TYPE_S64]; + case NODE_STRUCT_LIT: { + Str name = node->value.str; + FindStructResult s = find_struct(scope, name); + if (!s.map) { + eprintln("%s:%d:%d: error: unknown struct type '%s'", + a->file_name, node->line, node->col, name); + a->err = true; + return cstr(""); } + + StrSet *set = NULL; + for (sz i = 0; i < array_size(node->elements); i++) { + Node *next = node->elements[i]; + Str field_name = str_concat(name, cstr("."), a->storage); + field_name = + str_concat(field_name, next->value.str, a->storage); + + if (strset_lookup(&set, field_name)) { + eprintln( + "%s:%d:%d: error: field '%s' already present in struct " + "literal", + a->file_name, next->line, next->col, field_name); + a->err = true; + } else { + strset_insert(&set, field_name, a->storage); + } + typecheck_lit_field(a, next, s.scope, field_name); + } + node->type = name; + return node->type; } break; - case NODE_BOOL: { - node->expr_type = &default_types[TYPE_BOOL]; + case NODE_FUNCALL: { + Str symbol = node->value.str; + FunMap *fun = find_fun(scope, symbol); + if (!fun) { + eprintln( + "%s:%d:%d: error: function '%s' doesn't exist in current " + "scope ", + a->file_name, node->line, node->col, symbol); + a->err = true; + return cstr(""); + } + // Check that actual parameters typecheck + Str args = cstr(""); + for (sz i = 0; i < array_size(node->elements); i++) { + Node *expr = node->elements[i]; + Str type = type_inference(a, expr, scope); + args = str_concat(args, type, a->storage); + if (i != array_size(node->elements) - 1) { + args = str_concat(args, cstr(","), a->storage); + } + } + if (!args.size) { + args = cstr("nil"); + } + Str expected = fun->val.param_type; + if (!str_eq(args, expected) && !str_eq(expected, cstr("..."))) { + eprintln( + "%s:%d:%d: error: mismatched parameter types: %s expected " + "%s", + a->file_name, node->line, node->col, args, expected); + a->err = true; + } + node->type = fun->val.return_type; + return node->type; } break; - case NODE_STRING: { - node->expr_type = &default_types[TYPE_STR]; + case NODE_BLOCK: { + scope = typescope_alloc(a, scope); + Str type; + for (sz i = 0; i < array_size(node->elements); i++) { + Node *expr = node->elements[i]; + type = type_inference(a, expr, scope); + } + node->type = type; + return node->type; } break; - case NODE_FUNCALL: { - Symbol *val = find_symbol(scope, node->funcall.name); - if (!resolve_type(ast, scope, node->funcall.name)) { - return false; - } - if (val->type != SYMBOL_FUN) { - push_error(ERR_TYPE_SEMANTIC, ERR_WRONG_TYPE_FUN, - node->funcall.name->line, node->funcall.name->col); - return false; - } - if (array_size(node->funcall.args) != array_size(val->fun.param_types)) { - push_error(ERR_TYPE_SEMANTIC, ERR_BAD_ARGS, node->line, node->col); - return false; - } - node->expr_type = node->funcall.name->expr_type; - for (size_t i = 0; i < array_size(node->funcall.args); ++i) { - Node *arg = node->funcall.args[i]; - if (!resolve_type(ast, scope, arg)) { - return false; + case NODE_RETURN: { + Str ret_type = cstr(""); + for (sz i = 0; i < array_size(node->elements); i++) { + Node *expr = node->elements[i]; + Str type = type_inference(a, expr, scope); + ret_type = str_concat(ret_type, type, a->storage); + if (i != array_size(node->elements) - 1) { + ret_type = str_concat(ret_type, cstr(","), a->storage); + } + } + if (!ret_type.size) { + ret_type = cstr("nil"); + } + node->type = ret_type; + return node->type; + } break; + case NODE_FUN: { + node->type = cstr("nil"); + Scope *prev_scope = scope; + scope = typescope_alloc(a, scope); + Str param_type = cstr(""); + for (sz i = 0; i < array_size(node->func_params); i++) { + Node *param = node->func_params[i]; + Str symbol = param->param_name->value.str; + Str type = param->param_type->value.str; + if (param->param_type->is_ptr) { + type = str_concat(cstr("@"), type, a->storage); + } + if (param->param_type->kind == NODE_ARR_TYPE) { + type = str_concat(cstr("@"), type, a->storage); + } + param->param_name->type = + type_inference(a, param->param_type, scope); + param->type = type; + symmap_insert(&scope->symbols, symbol, + (Symbol){.name = type, .kind = SYM_PARAM}, + a->storage); + param_type = str_concat(param_type, type, a->storage); + if (i != array_size(node->func_params) - 1) { + param_type = str_concat(param_type, cstr(","), a->storage); + } + } + if (!param_type.size) { + param_type = cstr("nil"); + } + node->fun_params = param_type; + + Str ret_type = cstr(""); + for (sz i = 0; i < array_size(node->func_ret); i++) { + Node *expr = node->func_ret[i]; + Str type = type_inference(a, expr, scope); + if (expr->is_ptr) { + type = str_concat(cstr("@"), type, a->storage); + } + if (expr->kind == NODE_ARR_TYPE) { + type = str_concat(cstr("@"), type, a->storage); } - Node *expected = val->fun.param_types[i]; - if (!sv_equal(&arg->expr_type->name, &expected->string)) { - push_error(ERR_TYPE_SEMANTIC, ERR_TYPE_MISMATCH, - arg->line, arg->col); - return false; + ret_type = str_concat(ret_type, type, a->storage); + if (i != array_size(node->func_ret) - 1) { + ret_type = str_concat(ret_type, cstr(","), a->storage); } } + if (!ret_type.size) { + ret_type = cstr("nil"); + } + node->fun_return = ret_type; + + Str symbol = node->func_name->value.str; + if (prev_scope->parent != NULL) { + if (symmap_lookup(&prev_scope->symbols, symbol)) { + eprintln( + "%s:%d:%d: error: function '%s' already defined in " + "current " + "scope ", + a->file_name, node->var_name->line, node->var_name->col, + symbol); + a->err = true; + return cstr(""); + } + symmap_insert(&scope->symbols, symbol, + (Symbol){.name = symbol, .kind = SYM_FUN}, + a->storage); + } + scope->name = symbol; + funmap_insert(&prev_scope->funcs, symbol, + (Fun){.name = symbol, + .param_type = param_type, + .return_type = ret_type}, + a->storage); + + if (node->func_body->kind == NODE_BLOCK) { + Str type; + for (sz i = 0; i < array_size(node->func_body->elements); i++) { + Node *expr = node->func_body->elements[i]; + type = type_inference(a, expr, scope); + } + if (!type.size) { + type = cstr("nil"); + } + node->func_body->type = type; + } else { + type_inference(a, node->func_body, scope); + } + + // Ensure main body return matches the prototype. + if (!str_eq(node->func_body->type, ret_type)) { + eprintln( + "%s:%d:%d: error: mismatched return type %s, expected %s", + a->file_name, node->line, node->col, node->func_body->type, + ret_type); + a->err = true; + } + + // Ensure ALL return statements match the function prototype. + typecheck_returns(a, node->func_body, ret_type); + + // TODO: should return statements be allowed on let blocks? + return node->type; + } break; + default: { + emit_semantic_error(a, node, + cstr("type inference not implemented for this " + "kind of expression")); + println("KIND: %s", node_str[node->kind]); } break; - default: break; } - return true; + return cstr(""); } -ParseTree * -semantic_analysis(Root *roots) { - ParseTree *parse_tree = malloc(sizeof(ParseTree)); - parse_tree->roots = roots; - array_init(parse_tree->scopes, 0); - Scope *scope = alloc_scope(NULL); - array_push(parse_tree->scopes, scope); - - // Fill global scope with default types. - HashTable *types = scope->types; - for (size_t i = 0; i < sizeof(default_types)/sizeof(Type); ++i) { - Type *type = &default_types[i]; - ht_insert(types, &type->name, type); - } +void +symbolic_analysis(Analyzer *a, Parser *parser) { + Scope *scope = typescope_alloc(a, NULL); + assert(a); + assert(parser); - // Fill up global function symbols. - for (size_t i = 0; i < array_size(parse_tree->roots); ++i) { - Node *root = parse_tree->roots[i]; - if (root->type == NODE_FUN) { - Node *name = root->fun.name; - Symbol *fun = alloc_symval(root->fun.name, SYMBOL_FUN); - fun->fun.param_types = root->fun.param_types; - fun->fun.return_type = root->fun.return_type; - if (!insert_symbol(scope, name, fun)) { - return NULL; + // Fill builtin tables. + Str builtin_functions[] = { + cstr("print"), + cstr("println"), + }; + for (sz i = 0; i < LEN(builtin_functions); i++) { + Str symbol = builtin_functions[i]; + symmap_insert(&scope->symbols, symbol, + (Symbol){.name = symbol, .kind = SYM_BUILTIN_FUN}, + a->storage); + funmap_insert(&scope->funcs, symbol, + (Fun){.name = symbol, + .param_type = cstr("..."), + .return_type = cstr("nil")}, + a->storage); + } + Str builtin_types[] = { + cstr("u8"), cstr("s8"), cstr("u16"), cstr("s16"), + cstr("u32"), cstr("s32"), cstr("u64"), cstr("s64"), + cstr("f32"), cstr("f64"), cstr("ptr"), cstr("int"), + cstr("uint"), cstr("str"), cstr("bool"), cstr("nil")}; + for (sz i = 0; i < LEN(builtin_types); i++) { + Str type = builtin_types[i]; + symmap_insert(&scope->symbols, type, + (Symbol){.name = type, .kind = SYM_BUILTIN_TYPE}, + a->storage); + } + Str numeric_types[] = { + cstr("u8"), cstr("s8"), cstr("u16"), cstr("s16"), cstr("u32"), + cstr("s32"), cstr("u64"), cstr("s64"), cstr("f32"), cstr("f64"), + cstr("ptr"), cstr("int"), cstr("uint"), + }; + for (sz i = 0; i < LEN(numeric_types); i++) { + Str type = numeric_types[i]; + strset_insert(&a->numeric_types, type, a->storage); + } + Str integer_types[] = { + cstr("u8"), cstr("s8"), cstr("u16"), cstr("s16"), + cstr("u32"), cstr("s32"), cstr("u64"), cstr("s64"), + cstr("ptr"), cstr("int"), cstr("uint"), + }; + for (sz i = 0; i < LEN(integer_types); i++) { + Str type = integer_types[i]; + strset_insert(&a->integer_types, type, a->storage); + } + // Find top level function declarations. + for (sz i = 0; i < array_size(parser->nodes); i++) { + Node *root = parser->nodes[i]; + if (root->kind == NODE_FUN) { + Str symbol = root->func_name->value.str; + if (symmap_lookup(&scope->symbols, symbol)) { + eprintln( + "%s:%d:%d: error: function '%s' already defined in " + "current " + "scope ", + a->file_name, root->var_name->line, root->var_name->col, + symbol); + a->err = true; } + symmap_insert(&scope->symbols, symbol, + (Symbol){.name = symbol, .kind = SYM_FUN}, + a->storage); } } - - for (size_t i = 0; i < array_size(parse_tree->roots); ++i) { - // Fill up symbol tables in proper scope and resolve type of expression - // for all elements. - if (!resolve_type(parse_tree, scope, parse_tree->roots[i])) { - return NULL; - } + // Recursively fill symbol tables. + for (sz i = 0; i < array_size(parser->nodes); i++) { + Node *root = parser->nodes[i]; + type_inference(a, root, scope); } - - return parse_tree; } + diff --git a/src/vm.c b/src/vm.c index 574f4fa..cb1b6cd 100644 --- a/src/vm.c +++ b/src/vm.c @@ -17,10 +17,22 @@ typedef union Constant { typedef struct Chunk { Instruction *code; + + // Constant values that fit in 64 bits. Constant *constants; - Str file_name; - sz reg_idx; + IntIntMap *intmap; sz const_idx; + + // Constant strings. + Str *strings; + StrIntMap *strmap; + sz str_idx; + + // Number of registers currently used in this chunk. + sz reg_idx; + + // Debugging. + Str file_name; Arena *storage; // TODO: line/col info for debugging. } Chunk; @@ -205,6 +217,12 @@ disassemble_chunk(Chunk chunk) { chunk.constants[i]); } } + if (array_size(chunk.strings) > 0) { + println("%s: ========== strings =========", chunk.file_name); + for (sz i = 0; i < array_size(chunk.strings); i++) { + println("%s: %x{2}: %s", chunk.file_name, i, chunk.strings[i]); + } + } } #define N_CONST 256 @@ -292,6 +310,7 @@ vm_run(VM *vm) { typedef enum { COMP_CONST, + COMP_STRING, COMP_REG, COMP_ERR, } CompResultType; @@ -303,11 +322,19 @@ typedef struct CompResult { CompResult compile_expr(Chunk *chunk, Node *node); -// #define EMIT_OP(OP, CHUNK, ARENA) +#define EMIT_OP(OP, DST, A, B, NODE, CHUNK) \ + do { \ + Instruction inst = (Instruction){ \ + .op = (OP), \ + .dst = (DST), \ + .a = (A), \ + .b = (B), \ + }; \ + array_push((CHUNK)->code, inst, (CHUNK)->storage); \ + } while (0) CompResult compile_binary(OpCode op, Chunk *chunk, Node *node) { - sz reg_dst = chunk->reg_idx++; CompResult comp_a = compile_expr(chunk, node->left); CompResult comp_b = compile_expr(chunk, node->right); sz reg_a; @@ -315,9 +342,7 @@ compile_binary(OpCode op, Chunk *chunk, Node *node) { switch (comp_a.type) { case COMP_CONST: { reg_a = chunk->reg_idx++; - Instruction inst = - (Instruction){.op = OP_LD64K, .dst = reg_a, .a = comp_a.idx}; - array_push(chunk->code, inst, chunk->storage); + EMIT_OP(OP_LD64K, reg_a, comp_a.idx, 0, node, chunk); } break; case COMP_REG: { reg_a = comp_a.idx; @@ -329,9 +354,7 @@ compile_binary(OpCode op, Chunk *chunk, Node *node) { switch (comp_b.type) { case COMP_CONST: { reg_b = chunk->reg_idx++; - Instruction inst = - (Instruction){.op = OP_LD64K, .dst = reg_b, .a = comp_b.idx}; - array_push(chunk->code, inst, chunk->storage); + EMIT_OP(OP_LD64K, reg_b, comp_b.idx, 0, node, chunk); } break; case COMP_REG: { reg_b = comp_b.idx; @@ -340,9 +363,9 @@ compile_binary(OpCode op, Chunk *chunk, Node *node) { return (CompResult){.type = COMP_ERR}; } break; } - Instruction inst = - (Instruction){.op = op, .dst = reg_dst, .a = reg_a, .b = reg_b}; - array_push(chunk->code, inst, chunk->storage); + sz reg_dst = comp_a.idx; // Less registers + // sz reg_dst = chunk->reg_idx++; // Better for optimization + EMIT_OP(op, reg_dst, reg_a, reg_b, node, chunk); return (CompResult){.type = COMP_REG, .idx = reg_dst}; } @@ -354,26 +377,43 @@ compile_expr(Chunk *chunk, Node *node) { case NODE_MUL: return compile_binary(OP_MUL, chunk, node); break; case NODE_DIV: return compile_binary(OP_DIV, chunk, node); break; case NODE_MOD: return compile_binary(OP_MOD, chunk, node); break; + case NODE_TRUE: + case NODE_FALSE: case NODE_NUM_FLOAT: case NODE_NUM_INT: { + sz value = node->value.i; // Make sure we don't have duplicated constants. - for (sz i = 0; i < chunk->const_idx; i++) { - if (node->value.i == chunk->constants[i].i) { - return (CompResult){ - .type = COMP_CONST, - .idx = i, - }; - } + IntIntMap *map = intintmap_lookup(&chunk->intmap, value); + if (!map) { + map = intintmap_insert(&chunk->intmap, value, + chunk->const_idx++, chunk->storage); + Constant c = (Constant){.i = node->value.i}; + array_push(chunk->constants, c, chunk->storage); } - Constant c = (Constant){.i = node->value.i}; - array_push(chunk->constants, c, chunk->storage); return (CompResult){ .type = COMP_CONST, - .idx = chunk->const_idx++, + .idx = map->val, + }; + } break; + case NODE_STRING: { + Str string = node->value.str; + // Make sure we don't have duplicated strings. + StrIntMap *map = strintmap_lookup(&chunk->strmap, string); + if (!map) { + map = strintmap_insert(&chunk->strmap, string, chunk->str_idx++, + chunk->storage); + array_push(chunk->strings, string, chunk->storage); + } + return (CompResult){ + .type = COMP_STRING, + .idx = map->val, }; } break; - default: break; + default: { + eprintln("error: compilation not implemented for node %s", + node_str[node->kind]); + exit(EXIT_FAILURE); + } break; } return (CompResult){.type = COMP_ERR}; } - diff --git a/tests/compilation.bad b/tests/compilation.bad new file mode 100644 index 0000000..782e75f --- /dev/null +++ b/tests/compilation.bad @@ -0,0 +1,10 @@ +1 + 2 * 3 +true +false +0 1 +"hello" +"world" +; fun foo(): int { +; 32 +; } +; 1 + 2 + foo() -- cgit v1.2.1