#include "hashtable.h"

// A lexical scope. Scopes form a chain through `parent`; lookups walk from the
// innermost scope outwards (see find_type / find_symbol).
typedef struct Scope {
    size_t id;            // Unique id taken from scope_gen_id on allocation.
    struct Scope *parent; // Enclosing scope, NULL for the global scope.
    HashTable *symbols;   // Node (symbol name) -> Symbol.
    HashTable *types;     // StringView (type name) -> Type.
} Scope;

// Result of the semantic analysis pass: the roots that were analysed plus
// every scope created while analysing them.
typedef struct ParseTree {
    Root *roots;
    Scope **scopes;
} ParseTree;

typedef struct Type {
    StringView name;
    size_t size; // (bytes)
} Type;

// Indices into default_types.
typedef enum DefaultType {
    TYPE_VOID,
    TYPE_BOOL,
    TYPE_STR,
    TYPE_U8,
    TYPE_U16,
    TYPE_U32,
    TYPE_U64,
    TYPE_S8,
    TYPE_S16,
    TYPE_S32,
    TYPE_S64,
    TYPE_F32,
    TYPE_F64,
} DefaultType;

// Built-in types. Type identity in this file is frequently checked by
// comparing against the address of an entry in this table.
static Type default_types[] = {
    [TYPE_VOID] = {STRING("void"), 0},
    [TYPE_BOOL] = {STRING("bool"), 1},
    [TYPE_STR] = {STRING("str"), 16}, // size (8) + pointer to data (8).
    [TYPE_U8] = {STRING("u8"), 1},
    [TYPE_U16] = {STRING("u16"), 2},
    [TYPE_U32] = {STRING("u32"), 4},
    [TYPE_U64] = {STRING("u64"), 8},
    [TYPE_S8] = {STRING("s8"), 1},
    [TYPE_S16] = {STRING("s16"), 2},
    [TYPE_S32] = {STRING("s32"), 4},
    [TYPE_S64] = {STRING("s64"), 8},
    [TYPE_F32] = {STRING("f32"), 4},
    [TYPE_F64] = {STRING("f64"), 8},
};

typedef enum SymbolType {
    SYMBOL_VAR,
    SYMBOL_PAR,
    SYMBOL_FUN,
} SymbolType;

// Value stored in a scope's symbol table. `var` is used for SYMBOL_VAR and
// SYMBOL_PAR, `fun` for SYMBOL_FUN.
typedef struct Symbol {
    Node *name;
    SymbolType type;
    union {
        struct {
            Node *type;
        } var;
        struct {
            Node **param_types;
            Node *return_type;
        } fun;
    };
} Symbol;

// Next id handed out by alloc_scope.
static size_t scope_gen_id = 0;

// Allocates a Symbol with the given name and kind. The union payload is
// zero-initialised; the caller fills in `var` or `fun`. Never returns NULL:
// callers dereference the result immediately, so running out of memory aborts
// instead of handing back a NULL pointer.
Symbol *
alloc_symval(Node *name, SymbolType type) {
    Symbol *val = calloc(1, sizeof *val);
    if (val == NULL) {
        abort();
    }
    val->name = name;
    val->type = type;
    return val;
}

// Hash function for the symbol tables: hashes the string of the Node key.
u64
sym_hash(const struct HashTable *table, void *bytes) {
    Node *symbol = bytes;
    u64 hash = _xor_shift_hash(symbol->string.start, symbol->string.n);
    hash = _fibonacci_hash(hash, table->shift_amount);
    return hash;
}

// Key equality for the symbol tables: both keys must be NODE_SYMBOL nodes and
// are compared by their string.
bool
sym_eq(void *a, void *b) {
    Node *a_node = a;
    Node *b_node = b;
    assert(a_node->type == NODE_SYMBOL);
    assert(b_node->type == NODE_SYMBOL);
    return sv_equal(&a_node->string, &b_node->string);
}

// Hash function for the type tables: hashes the StringView key.
u64
type_hash(const struct HashTable *table, void *bytes) {
    StringView *type = bytes;
    u64 hash = _xor_shift_hash(type->start, type->n);
    hash = _fibonacci_hash(hash, table->shift_amount);
    return hash;
}

// Key equality for the type tables: compares the two StringView keys.
bool
type_eq(void *a, void *b) {
    StringView *a_type = a;
    StringView *b_type = b;
    return sv_equal(a_type, b_type);
}

// Allocates a new scope nested inside `parent` (NULL for the global scope),
// with a fresh id and empty symbol/type tables. Never returns NULL: running
// out of memory aborts.
Scope *
alloc_scope(Scope *parent) {
    Scope *scope = malloc(sizeof *scope);
    if (scope == NULL) {
        abort();
    }
    scope->id = scope_gen_id++;
    scope->parent = parent;
    scope->symbols = ht_init(sym_hash, sym_eq);
    scope->types = ht_init(type_hash, type_eq);
    return scope;
}

// Looks up the type named by `type->string`, starting at `scope` and walking
// up the parent chain. Pushes ERR_UNKNOWN_TYPE and returns NULL if not found.
Type *
find_type(Scope *scope, Node *type) {
    // TODO: Normally default types will be used more often. Since we don't
    // allow type shadowing, we should search first on the global scope.
    while (scope != NULL) {
        Type *ret = ht_lookup(scope->types, &type->string);
        if (ret != NULL) {
            return ret;
        }
        scope = scope->parent;
    }
    push_error(ERR_TYPE_PARSER, ERR_UNKNOWN_TYPE, type->line, type->col);
    return NULL;
}

// Inserts `symbol -> val` into `scope` (this scope only; parents are not
// consulted). If the symbol already exists in this scope, pushes
// ERR_SYMBOL_REDEF and returns false; in that case `val` is NOT stored and
// remains owned by the caller.
bool
insert_symbol(Scope *scope, Node *symbol, Symbol *val) {
    // Check if symbol already exists.
    HashTable *symbols = scope->symbols;
    if (ht_lookup(symbols, symbol) != NULL) {
        push_error(ERR_TYPE_PARSER, ERR_SYMBOL_REDEF, symbol->line,
                   symbol->col);
        return false;
    }
    ht_insert(symbols, symbol, val);
    return true;
}

// Picks the result type for an operation mixing numeric types `a` and `b`.
// When `b` is a wider type of the same family as `a` (unsigned, signed or
// float) `b` is returned; in every other case `a` is returned unchanged.
Type *
coerce_numeric_types(Type *a, Type *b) {
    // TODO: Decide what to do with mixed numeric types. What are the promotion
    // rules, etc.
    if (a == &default_types[TYPE_U8]) {
        if (b == &default_types[TYPE_U16] || b == &default_types[TYPE_U32] ||
            b == &default_types[TYPE_U64]) {
            return b;
        }
    } else if (a == &default_types[TYPE_U16]) {
        if (b == &default_types[TYPE_U32] || b == &default_types[TYPE_U64]) {
            return b;
        }
    } else if (a == &default_types[TYPE_U32]) {
        if (b == &default_types[TYPE_U64]) {
            return b;
        }
    } else if (a == &default_types[TYPE_S8]) {
        if (b == &default_types[TYPE_S16] || b == &default_types[TYPE_S32] ||
            b == &default_types[TYPE_S64]) {
            return b;
        }
    } else if (a == &default_types[TYPE_S16]) {
        if (b == &default_types[TYPE_S32] || b == &default_types[TYPE_S64]) {
            return b;
        }
    } else if (a == &default_types[TYPE_S32]) {
        if (b == &default_types[TYPE_S64]) {
            return b;
        }
    } else if (a == &default_types[TYPE_F32]) {
        if (b == &default_types[TYPE_F64]) {
            return b;
        }
    }
    return a;
}

// Returns true if `t` is one of the built-in integer or float types. Any other
// pointer, including NULL, yields false.
bool
type_is_numeric(Type *t) {
    if (t == &default_types[TYPE_U8] || t == &default_types[TYPE_U16] ||
        t == &default_types[TYPE_U32] || t == &default_types[TYPE_U64] ||
        t == &default_types[TYPE_S8] || t == &default_types[TYPE_S16] ||
        t == &default_types[TYPE_S32] || t == &default_types[TYPE_S64] ||
        t == &default_types[TYPE_F32] || t == &default_types[TYPE_F64]) {
        return true;
    }
    return false;
}

// Looks up `node` in the symbol tables, starting at `scope` and walking up
// the parent chain. Pushes ERR_UNKNOWN_SYMBOL and returns NULL if not found.
Symbol *
find_symbol(Scope *scope, Node *node) {
    while (scope != NULL) {
        Symbol *val = ht_lookup(scope->symbols, node);
        if (val != NULL) {
            return val;
        }
        scope = scope->parent;
    }
    push_error(ERR_TYPE_PARSER, ERR_UNKNOWN_SYMBOL, node->line, node->col);
    return NULL;
}

// Returns true when `type` has been resolved (non-NULL) and is called `name`.
// Some node kinds leave expr_type unset after a successful resolve_type, so
// the NULL check keeps name comparisons from dereferencing a missing type.
static bool
type_has_name(Type *type, StringView *name) {
    return type != NULL && sv_equal(&type->name, name);
}

// Returns true when both types are resolved (non-NULL) and share a name.
static bool
types_same_name(Type *a, Type *b) {
    return a != NULL && b != NULL && sv_equal(&a->name, &b->name);
}

// Type of a NODE_BLOCK whose expressions were already resolved: the type of
// its last expression, or void for an empty block (indexing `size - 1` on an
// empty array would wrap around and read out of bounds).
static Type *
block_result_type(Node *block) {
    size_t n = array_size(block->block.expr);
    if (n == 0) {
        return &default_types[TYPE_VOID];
    }
    return block->block.expr[n - 1]->expr_type;
}

// Resolves and checks the type of `node` (recursively), storing the result in
// node->expr_type. New scopes created for functions and blocks are appended to
// ast->scopes. Nodes that already have an expr_type are left untouched.
//
// Returns false as soon as a check fails; an error has been pushed with
// push_error in that case (either here or by the lookup helper that failed).
bool
resolve_type(ParseTree *ast, Scope *scope, Node *node) {
    if (node->expr_type != NULL) {
        return true;
    }
    switch (node->type) {
        case NODE_BUILTIN: {
            // Arguments first; the operator checks below need their types.
            for (size_t i = 0; i < array_size(node->builtin.args); ++i) {
                Node *arg = node->builtin.args[i];
                if (!resolve_type(ast, scope, arg)) {
                    return false;
                }
            }
            switch (node->builtin.type) {
                // Numbers.
                case TOKEN_ADD:
                case TOKEN_SUB:
                case TOKEN_MUL:
                case TOKEN_DIV:
                case TOKEN_MOD: {
                    Type *type = NULL;
                    for (size_t i = 0; i < array_size(node->builtin.args);
                         ++i) {
                        Node *arg = node->builtin.args[i];

                        // Check that all arguments are numbers.
                        if (!type_is_numeric(arg->expr_type)) {
                            push_error(ERR_TYPE_PARSER, ERR_WRONG_TYPE_NUM,
                                       arg->line, arg->col);
                            return false;
                        }
                        if (type == NULL) {
                            type = arg->expr_type;
                        } else if (type != arg->expr_type) {
                            type = coerce_numeric_types(type, arg->expr_type);
                        }
                    }
                    node->expr_type = type;
                } break;
                // Bools.
                case TOKEN_NOT:
                case TOKEN_AND:
                case TOKEN_OR: {
                    // Check that all arguments are boolean.
                    for (size_t i = 0; i < array_size(node->builtin.args);
                         ++i) {
                        Node *arg = node->builtin.args[i];
                        if (arg->expr_type != &default_types[TYPE_BOOL]) {
                            push_error(ERR_TYPE_PARSER, ERR_WRONG_TYPE_BOOL,
                                       arg->line, arg->col);
                            return false;
                        }
                    }
                    node->expr_type = &default_types[TYPE_BOOL];
                } break;
                case TOKEN_EQ:
                case TOKEN_LT:
                case TOKEN_GT:
                case TOKEN_LE:
                case TOKEN_GE: {
                    // Check that all arguments are nums.
                    for (size_t i = 0; i < array_size(node->builtin.args);
                         ++i) {
                        Node *arg = node->builtin.args[i];
                        if (!type_is_numeric(arg->expr_type)) {
                            push_error(ERR_TYPE_PARSER, ERR_WRONG_TYPE_NUM,
                                       arg->line, arg->col);
                            return false;
                        }
                    }
                    node->expr_type = &default_types[TYPE_BOOL];
                } break;
                default: break;
            }
        } break;
        case NODE_SYMBOL: {
            Symbol *val = find_symbol(scope, node);
            if (val == NULL) {
                return false;
            }

            // A variable/parameter has its declared type; a function symbol
            // takes the type of its return value.
            Type *type = NULL;
            switch (val->type) {
                case SYMBOL_PAR:
                case SYMBOL_VAR: {
                    type = find_type(scope, val->var.type);
                } break;
                case SYMBOL_FUN: {
                    type = find_type(scope, val->fun.return_type);
                } break;
            }
            if (type == NULL) {
                return false;
            }
            node->expr_type = type;
        } break;
        case NODE_FUN: {
            node->expr_type = &default_types[TYPE_VOID];

            // Fill up new scope with parameters
            scope = alloc_scope(scope);
            array_push(ast->scopes, scope);

            // Parameters.
            for (size_t i = 0; i < array_size(node->fun.param_names); ++i) {
                Node *param = node->fun.param_names[i];
                Node *type = node->fun.param_types[i];
                Symbol *var = alloc_symval(param, SYMBOL_PAR);
                var->var.type = type;
                if (!insert_symbol(scope, param, var)) {
                    // Not stored in the table, so it is still ours to free.
                    free(var);
                    return false;
                }
            }

            // Body. A block body is resolved expression by expression in the
            // function's own scope rather than through the NODE_BLOCK case, so
            // no additional scope is created for it.
            Node *body = node->fun.body;
            if (body->type == NODE_BLOCK) {
                for (size_t i = 0; i < array_size(body->block.expr); ++i) {
                    Node *expr = body->block.expr[i];
                    if (!resolve_type(ast, scope, expr)) {
                        return false;
                    }
                }
                body->expr_type = block_result_type(body);
            } else {
                if (!resolve_type(ast, scope, body)) {
                    return false;
                }
            }

            // Check that the type of body matches the return type.
            if (!type_has_name(body->expr_type,
                               &node->fun.return_type->string)) {
                push_error(ERR_TYPE_PARSER, ERR_WRONG_RET_TYPE, node->line,
                           node->col);
                return false;
            }
        } break;
        case NODE_BLOCK: {
            scope = alloc_scope(scope);
            array_push(ast->scopes, scope);
            for (size_t i = 0; i < array_size(node->block.expr); ++i) {
                Node *expr = node->block.expr[i];
                if (!resolve_type(ast, scope, expr)) {
                    return false;
                }
            }
            node->expr_type = block_result_type(node);
        } break;
        case NODE_IF: {
            if (!resolve_type(ast, scope, node->ifexpr.cond)) {
                return false;
            }
            if (!resolve_type(ast, scope, node->ifexpr.expr_true)) {
                return false;
            }
            Type *type_true = node->ifexpr.expr_true->expr_type;
            node->expr_type = type_true;
            if (node->ifexpr.expr_false != NULL) {
                if (!resolve_type(ast, scope, node->ifexpr.expr_false)) {
                    return false;
                }
            }

            // Check ifexpr.cond is a bool.
            Type *type_cond = node->ifexpr.cond->expr_type;
            if (!types_same_name(type_cond, &default_types[TYPE_BOOL])) {
                push_error(ERR_TYPE_PARSER, ERR_WRONG_COND_TYPE, node->line,
                           node->col);
                return false;
            }

            // Check if types of expr_true and expr_false match
            if (node->ifexpr.expr_false != NULL) {
                Type *type_false = node->ifexpr.expr_false->expr_type;
                if (type_is_numeric(type_true) &&
                    type_is_numeric(type_false)) {
                    node->expr_type =
                        coerce_numeric_types(type_true, type_false);
                } else if (!types_same_name(type_true, type_false)) {
                    push_error(ERR_TYPE_PARSER, ERR_WRONG_TYPE_T_F,
                               node->line, node->col);
                    return false;
                }
            }
        } break;
        case NODE_SET: {
            node->expr_type = &default_types[TYPE_VOID];
            if (!resolve_type(ast, scope, node->set.symbol)) {
                return false;
            }
            if (!resolve_type(ast, scope, node->set.value)) {
                return false;
            }
            Node *symbol = node->set.symbol;
            Node *value = node->set.value;
            if (!types_same_name(symbol->expr_type, value->expr_type)) {
                push_error(ERR_TYPE_PARSER, ERR_TYPE_MISMATCH, node->line,
                           node->col);
                return false;
            }
        } break;
        case NODE_DEF: {
            // Prepare value for symbol table.
            Symbol *var = alloc_symval(node->def.symbol, SYMBOL_VAR);
            var->var.type = node->def.type;
            if (!insert_symbol(scope, node->def.symbol, var)) {
                // Not stored in the table, so it is still ours to free.
                free(var);
                return false;
            }

            Type *type = find_type(scope, node->def.type);
            if (type == NULL) {
                return false;
            }
            node->def.symbol->expr_type = type;

            node->expr_type = &default_types[TYPE_VOID];
            // TODO: type inference from right side when not annotated?
            if (!resolve_type(ast, scope, node->def.value)) {
                return false;
            }
            Node *symbol = node->def.symbol;
            Node *value = node->def.value;
            if (!types_same_name(symbol->expr_type, value->expr_type)) {
                push_error(ERR_TYPE_PARSER, ERR_TYPE_MISMATCH, node->line,
                           node->col);
                return false;
            }
        } break;
        case NODE_NUMBER: {
            // TODO: Numbers are f64/s64 unless explicitly annotated. Annotated
            // numbers must fit in the given range (e.g. no negative constants
            // inside a u64, no numbers bigger than 255 in a u8, etc.).
            if (node->number.fractional != 0) {
                node->expr_type = &default_types[TYPE_F64];
            } else {
                node->expr_type = &default_types[TYPE_S64];
            }
        } break;
        case NODE_BOOL: {
            node->expr_type = &default_types[TYPE_BOOL];
        } break;
        case NODE_STRING: {
            node->expr_type = &default_types[TYPE_STR];
        } break;
        case NODE_FUNCALL: {
            Symbol *val = find_symbol(scope, node->funcall.name);
            if (val == NULL) {
                // find_symbol already reported the unknown symbol. Bailing out
                // here avoids reporting it a second time through resolve_type
                // and avoids dereferencing `val` below.
                return false;
            }
            if (!resolve_type(ast, scope, node->funcall.name)) {
                return false;
            }
            if (val->type != SYMBOL_FUN) {
                push_error(ERR_TYPE_PARSER, ERR_WRONG_TYPE_FUN,
                           node->funcall.name->line, node->funcall.name->col);
                return false;
            }
            if (array_size(node->funcall.args) !=
                array_size(val->fun.param_types)) {
                push_error(ERR_TYPE_PARSER, ERR_BAD_ARGS, node->line,
                           node->col);
                return false;
            }

            // The call has the type of the function's return value, which is
            // what resolving the name symbol produced.
            node->expr_type = node->funcall.name->expr_type;
            for (size_t i = 0; i < array_size(node->funcall.args); ++i) {
                Node *arg = node->funcall.args[i];
                if (!resolve_type(ast, scope, arg)) {
                    return false;
                }
                Node *expected = val->fun.param_types[i];
                if (!type_has_name(arg->expr_type, &expected->string)) {
                    push_error(ERR_TYPE_PARSER, ERR_TYPE_MISMATCH, arg->line,
                               arg->col);
                    return false;
                }
            }
        } break;
        default: break;
    }
    return true;
}

// Runs semantic analysis over `roots`: creates the global scope, registers
// the default types and every top-level function, then resolves the type of
// each root. Returns the resulting ParseTree, or NULL on failure (allocation
// failure, symbol redefinition, or any type error reported by resolve_type).
//
// NOTE(review): on the failure paths the ParseTree and the scopes created so
// far are not released; no teardown routine for the hash tables or the scopes
// array is visible here, so cleanup is left to be added alongside one.
ParseTree *
semantic_analysis(Root *roots) {
    ParseTree *parse_tree = malloc(sizeof *parse_tree);
    if (parse_tree == NULL) {
        return NULL;
    }
    parse_tree->roots = roots;
    array_init(parse_tree->scopes, 0);
    Scope *scope = alloc_scope(NULL);
    array_push(parse_tree->scopes, scope);

    // Fill global scope with default types.
    HashTable *types = scope->types;
    for (size_t i = 0; i < sizeof default_types / sizeof default_types[0];
         ++i) {
        Type *type = &default_types[i];
        ht_insert(types, &type->name, type);
    }

    // Fill up global function symbols. This happens before any body is
    // resolved so functions can be called ahead of their definition.
    for (size_t i = 0; i < array_size(parse_tree->roots); ++i) {
        Node *root = parse_tree->roots[i];
        if (root->type == NODE_FUN) {
            Node *name = root->fun.name;
            Symbol *fun = alloc_symval(root->fun.name, SYMBOL_FUN);
            fun->fun.param_types = root->fun.param_types;
            fun->fun.return_type = root->fun.return_type;
            if (!insert_symbol(scope, name, fun)) {
                // Not stored in the table, so it is still ours to free.
                free(fun);
                return NULL;
            }
        }
    }

    for (size_t i = 0; i < array_size(parse_tree->roots); ++i) {
        // Fill up symbol tables in proper scope and resolve type of expression
        // for all elements.
        if (!resolve_type(parse_tree, scope, parse_tree->roots[i])) {
            return NULL;
        }
    }

    return parse_tree;
}