From 3da041f2e17fdeb69bf345aadf89c5fcc1814260 Mon Sep 17 00:00:00 2001 From: Bad Diode Date: Mon, 18 Apr 2022 16:27:21 -0300 Subject: Move semantic analysis to separate file --- src/main.c | 9 +- src/nodes.c | 1 - src/nodes.h | 1 - src/parser.c | 525 +------------------------------------------------------- src/parser.h | 22 +-- src/semantic.c | 527 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ src/viz.c | 8 +- 7 files changed, 545 insertions(+), 548 deletions(-) create mode 100644 src/semantic.c (limited to 'src') diff --git a/src/main.c b/src/main.c index 863d9d1..cdf4167 100644 --- a/src/main.c +++ b/src/main.c @@ -9,6 +9,7 @@ #include "lexer.c" #include "nodes.c" #include "parser.c" +#include "semantic.c" #include "viz.c" void @@ -29,9 +30,13 @@ process_source(const StringView *source, const char *file_name) { // print_tokens(tokens); // Parser. - ParseTree *parse_tree = parse(tokens); + Root *roots = parse(tokens); check_errors(file_name); - viz_ast(parse_tree); + // viz_ast(roots); + + // Symbol table generation and type checking. + ParseTree *parse_tree = semantic_analysis(roots); + viz_ast(parse_tree->roots); } void diff --git a/src/nodes.c b/src/nodes.c index b8a5f09..6978acc 100644 --- a/src/nodes.c +++ b/src/nodes.c @@ -11,7 +11,6 @@ alloc_node(NodeType type) { node->type = type; node->line = 0; node->col = 0; - node->scope = NULL; node->expr_type = NULL; return node; } diff --git a/src/nodes.h b/src/nodes.h index 11b30dd..af10573 100644 --- a/src/nodes.h +++ b/src/nodes.h @@ -21,7 +21,6 @@ typedef struct Node { NodeType type; size_t line; size_t col; - struct Scope *scope; struct Type *expr_type; union { diff --git a/src/parser.c b/src/parser.c index 0cf70d7..3f15b47 100644 --- a/src/parser.c +++ b/src/parser.c @@ -1,59 +1,6 @@ #include "parser.h" #include "darray.h" -typedef enum DefaultType { - TYPE_VOID, - TYPE_BOOL, - TYPE_STR, - TYPE_U8, - TYPE_U16, - TYPE_U32, - TYPE_U64, - TYPE_S8, - TYPE_S16, - TYPE_S32, - TYPE_S64, - TYPE_F32, - TYPE_F64, -} DefaultType; - -static Type default_types[] = { - [TYPE_VOID] = {STRING("void"), 0}, - [TYPE_BOOL] = {STRING("bool"), 1}, - [TYPE_STR] = {STRING("str"), 16}, // size (8) + pointer to data (8). - [TYPE_U8] = {STRING("u8"), 1}, - [TYPE_U16] = {STRING("u16"), 2}, - [TYPE_U32] = {STRING("u32"), 4}, - [TYPE_U64] = {STRING("u64"), 8}, - [TYPE_S8] = {STRING("s8"), 1}, - [TYPE_S16] = {STRING("s16"), 2}, - [TYPE_S32] = {STRING("s32"), 4}, - [TYPE_S64] = {STRING("s64"), 8}, - [TYPE_F32] = {STRING("f32"), 4}, - [TYPE_F64] = {STRING("f64"), 8}, -}; - -typedef enum SymbolType { - SYMBOL_VAR, - SYMBOL_FUN, -} SymbolType; - -typedef struct SymbolValue { - Node *name; - SymbolType type; - - union { - struct { - Node *type; - } var; - - struct { - Node **param_types; - Node *return_type; - } fun; - }; -} SymbolValue; - Token next_token(Parser *parser) { return parser->tokens[parser->current_token++]; @@ -510,67 +457,6 @@ parse_next(Parser *parser) { } } -u64 sym_hash(const struct HashTable *table, void *bytes) { - Node *symbol = bytes; - u64 hash = _xor_shift_hash(symbol->string.start, symbol->string.n); - hash = _fibonacci_hash(hash, table->shift_amount); - return hash; -} - -bool sym_eq(void *a, void *b) { - Node *a_node = a; - Node *b_node = b; - assert(a_node->type == NODE_SYMBOL); - assert(b_node->type == NODE_SYMBOL); - return sv_equal(&a_node->string, &b_node->string); -} - -u64 type_hash(const struct HashTable *table, void *bytes) { - StringView *type = bytes; - u64 hash = _xor_shift_hash(type->start, type->n); - hash = _fibonacci_hash(hash, table->shift_amount); - return hash; -} - -bool type_eq(void *a, void *b) { - StringView *a_type = a; - StringView *b_type = b; - return sv_equal(a_type, b_type); -} - -Scope * -alloc_scope(Scope *parent) { - Scope *scope = malloc(sizeof(Scope)); - scope->parent = parent; - scope->symbols = ht_init(sym_hash, sym_eq); - scope->types = ht_init(type_hash, type_eq); - return scope; -} - -ParseTree * -alloc_parsetree(void) { - ParseTree *parse_tree = malloc(sizeof(ParseTree)); - array_init(parse_tree->roots, 0); - parse_tree->global_scope = alloc_scope(NULL); - parse_tree->current_scope = parse_tree->global_scope; - - // Fill global scope with default types. - HashTable *types = parse_tree->global_scope->types; - for (size_t i = 0; i < sizeof(default_types)/sizeof(Type); ++i) { - Type *type = &default_types[i]; - ht_insert(types, &type->name, type); - } - return parse_tree; -} - -SymbolValue * -alloc_symval(Node *name, SymbolType type) { - SymbolValue *val = malloc(sizeof(SymbolValue)); - val->name = name; - val->type = type; - return val; -} - bool parse_roots(Parser *parser) { while (has_next(parser)) { @@ -582,424 +468,21 @@ parse_roots(Parser *parser) { if (node == NULL) { return false; } - array_push(parser->parse_tree->roots, node); - } - return true; -} - -Type * -find_type(Scope *scope, Node *type) { - // TODO: Normally default types will be used more often. Since we don't - // allow type shadowing, we should search first on the global scope. - while (scope != NULL) { - Type *ret = ht_lookup(scope->types, &type->string); - if (ret != NULL) { - return ret; - } - scope = scope->parent; - } - push_error(ERR_TYPE_PARSER, ERR_UNKNOWN_TYPE, type->line, type->col); - return NULL; -} - -bool -insert_symbol(Scope *scope, Node *symbol, SymbolValue *val) { - // Check if symbol already exists. - HashTable *symbols = scope->symbols; - if (ht_lookup(symbols, symbol) != NULL) { - push_error(ERR_TYPE_PARSER, ERR_SYMBOL_REDEF, symbol->line, symbol->col); - return false; - } - ht_insert(symbols, symbol, val); - return true; -} - -Type * -coerce_numeric_types(Type *a, Type *b) { - // TODO: Decide what to do with mixed numeric types. What are the promotion - // rules, etc. - if (a == &default_types[TYPE_U8]) { - if (b == &default_types[TYPE_U16] || - b == &default_types[TYPE_U32] || - b == &default_types[TYPE_U64]) { - return b; - } - } else if (a == &default_types[TYPE_U16]) { - if (b == &default_types[TYPE_U32] || - b == &default_types[TYPE_U64]) { - return b; - } - } else if (a == &default_types[TYPE_U32]) { - if (b == &default_types[TYPE_U64]) { - return b; - } - } else if (a == &default_types[TYPE_S8]) { - if (b == &default_types[TYPE_S16] || - b == &default_types[TYPE_S32] || - b == &default_types[TYPE_S64]) { - return b; - } - } else if (a == &default_types[TYPE_S16]) { - if (b == &default_types[TYPE_S32] || - b == &default_types[TYPE_S64]) { - return b; - } - } else if (a == &default_types[TYPE_S32]) { - if (b == &default_types[TYPE_S64]) { - return b; - } - } else if (a == &default_types[TYPE_F32]) { - if (b == &default_types[TYPE_F64]) { - return b; - } - } - return a; -} - -bool -type_is_numeric(Type *t) { - if (t == &default_types[TYPE_U8] || - t == &default_types[TYPE_U16] || - t == &default_types[TYPE_U32] || - t == &default_types[TYPE_U64] || - t == &default_types[TYPE_S8] || - t == &default_types[TYPE_S16] || - t == &default_types[TYPE_S32] || - t == &default_types[TYPE_S64] || - t == &default_types[TYPE_F32] || - t == &default_types[TYPE_F64]) { - return true; - } - return false; -} - -SymbolValue * -find_symbol(Scope *scope, Node *node) { - while (scope != NULL) { - SymbolValue *val = ht_lookup(scope->symbols, node); - if (val != NULL) { - return val; - } - scope = scope->parent; - } - push_error(ERR_TYPE_PARSER, ERR_UNKNOWN_SYMBOL, node->line, node->col); - return NULL; -} - -bool -resolve_type(Scope *scope, Node *node) { - if (node->expr_type != NULL) { - return true; - } - switch (node->type) { - case NODE_BUILTIN: { - for (size_t i = 0; i < array_size(node->builtin.args); ++i) { - Node *arg = node->builtin.args[i]; - if (!resolve_type(scope, arg)) { - return false; - } - } - switch (node->builtin.type) { - // Numbers. - case TOKEN_ADD: - case TOKEN_SUB: - case TOKEN_MUL: - case TOKEN_DIV: - case TOKEN_MOD: { - Type *type = NULL; - for (size_t i = 0; i < array_size(node->builtin.args); ++i) { - Node *arg = node->builtin.args[i]; - - // Check that all arguments are numbers. - if (!type_is_numeric(arg->expr_type)) { - push_error(ERR_TYPE_PARSER, ERR_WRONG_TYPE_NUM, - arg->line, arg->col); - return false; - } - - if (type == NULL) { - type = arg->expr_type; - } else if (type != arg->expr_type) { - type = coerce_numeric_types(type, arg->expr_type); - } - } - node->expr_type = type; - } break; - // Bools. - case TOKEN_NOT: - case TOKEN_AND: - case TOKEN_OR: { - // Check that all arguments are boolean. - for (size_t i = 0; i < array_size(node->builtin.args); ++i) { - Node *arg = node->builtin.args[i]; - if (arg->expr_type != &default_types[TYPE_BOOL]) { - push_error(ERR_TYPE_PARSER, ERR_WRONG_TYPE_BOOL, - arg->line, arg->col); - return false; - } - } - node->expr_type = &default_types[TYPE_BOOL]; - } break; - case TOKEN_EQ: - case TOKEN_LT: - case TOKEN_GT: - case TOKEN_LE: - case TOKEN_GE: { - // Check that all arguments are nums. - for (size_t i = 0; i < array_size(node->builtin.args); ++i) { - Node *arg = node->builtin.args[i]; - if (!type_is_numeric(arg->expr_type)) { - push_error(ERR_TYPE_PARSER, ERR_WRONG_TYPE_NUM, - arg->line, arg->col); - return false; - } - } - node->expr_type = &default_types[TYPE_BOOL]; - } break; - default: break; - } - } break; - case NODE_SYMBOL: { - SymbolValue *val = find_symbol(scope, node); - if (val == NULL) { - return false; - } - - Type *type = NULL; - switch (val->type) { - case SYMBOL_VAR: { - type = find_type(scope, val->var.type); - } break; - case SYMBOL_FUN: { - type = find_type(scope, val->fun.return_type); - } break; - } - if (type == NULL) { - return false; - } - node->expr_type = type; - } break; - case NODE_FUN: { - // Fill up new scope with parameters - scope = alloc_scope(scope); - - // Parameters. - for (size_t i = 0; i < array_size(node->fun.param_names); ++i) { - Node *param = node->fun.param_names[i]; - Node *type = node->fun.param_types[i]; - SymbolValue *var = alloc_symval(param, SYMBOL_VAR); - var->var.type = type; - if (!insert_symbol(scope, param, var)) { - return false; - } - } - - // Body. - Node *body = node->fun.body; - if (body->type == NODE_BLOCK) { - body->scope = scope; - for (size_t i = 0; i < array_size(body->block.expr); ++i) { - Node *expr = body->block.expr[i]; - if (!resolve_type(scope, expr)) { - return false; - } - } - Node *last_expr = body->block.expr[array_size(body->block.expr) - 1]; - node->expr_type = last_expr->expr_type; - } else { - if (!resolve_type(scope, body)) { - return false; - } - } - - // Check that the type of body matches the return type. - StringView *type_body = &node->fun.body->expr_type->name; - StringView *return_type = &node->fun.return_type->string; - if (!sv_equal(type_body, return_type)) { - push_error(ERR_TYPE_PARSER, ERR_WRONG_RET_TYPE, node->line, node->col); - return false; - } - } break; - case NODE_BLOCK: { - scope = alloc_scope(scope); - for (size_t i = 0; i < array_size(node->block.expr); ++i) { - Node *expr = node->block.expr[i]; - if (!resolve_type(scope, expr)) { - return false; - } - } - Node *last_expr = node->block.expr[array_size(node->block.expr) - 1]; - node->expr_type = last_expr->expr_type; - } break; - case NODE_IF: { - if (!resolve_type(scope, node->ifexpr.cond)) { - return false; - } - if (!resolve_type(scope, node->ifexpr.expr_true)) { - return false; - } - Type *type_true = node->ifexpr.expr_true->expr_type; - node->expr_type = type_true; - if (node->ifexpr.expr_false != NULL) { - if (!resolve_type(scope, node->ifexpr.expr_false)) { - return false; - } - } - - // Check ifexpr.cond is a bool. - Type *type_cond = node->ifexpr.cond->expr_type; - if (!sv_equal(&type_cond->name, &default_types[TYPE_BOOL].name)) { - push_error(ERR_TYPE_PARSER, ERR_WRONG_COND_TYPE, - node->line, node->col); - return false; - } - - // Check if types of expr_true and expr_false match - if (node->ifexpr.expr_false != NULL) { - Type *type_false = node->ifexpr.expr_false->expr_type; - if (type_is_numeric(type_true) && type_is_numeric(type_false)) { - node->expr_type = coerce_numeric_types(type_true, type_false); - } else if (!sv_equal(&type_true->name, &type_false->name)) { - push_error(ERR_TYPE_PARSER, ERR_WRONG_TYPE_T_F, - node->line, node->col); - return false; - } - } - } break; - case NODE_SET: { - node->expr_type = &default_types[TYPE_VOID]; - if (!resolve_type(scope, node->set.symbol)) { - return false; - } - if (!resolve_type(scope, node->set.value)) { - return false; - } - Node *symbol = node->set.symbol; - Node *value = node->set.value; - if (!sv_equal(&symbol->expr_type->name, &value->expr_type->name)) { - push_error(ERR_TYPE_PARSER, ERR_TYPE_MISMATCH, - node->line, node->col); - return false; - } - } break; - case NODE_DEF: { - // Prepare value for symbol table. - SymbolValue *var = alloc_symval(node->def.symbol, SYMBOL_VAR); - var->var.type = node->def.type; - if (!insert_symbol(scope, node->def.symbol, var)) { - return false; - } - - Type *type = find_type(scope, node->def.type); - if (type == NULL) { - return false; - } - node->def.symbol->expr_type = type; - - node->expr_type = &default_types[TYPE_VOID]; - // TODO: type inference from right side when not annotated? - if (!resolve_type(scope, node->def.value)) { - return false; - } - Node *symbol = node->def.symbol; - Node *value = node->def.value; - if (!sv_equal(&symbol->expr_type->name, &value->expr_type->name)) { - push_error(ERR_TYPE_PARSER, ERR_TYPE_MISMATCH, - node->line, node->col); - return false; - } - } break; - case NODE_NUMBER: { - // TODO: Numbers are f64/s64 unless explicitely annotated. Annotated - // numbers must fit in the given range (e.g. no negative constants - // inside a u64, no numbers bigger than 255 in a u8, etc.). - if (node->number.fractional != 0) { - node->expr_type = &default_types[TYPE_F64]; - } else { - node->expr_type = &default_types[TYPE_S64]; - } - } break; - case NODE_BOOL: { - node->expr_type = &default_types[TYPE_BOOL]; - } break; - case NODE_STRING: { - node->expr_type = &default_types[TYPE_STR]; - } break; - case NODE_FUNCALL: { - SymbolValue *val = find_symbol(scope, node->funcall.name); - if (!resolve_type(scope, node->funcall.name)) { - return false; - } - if (val->type != SYMBOL_FUN) { - push_error(ERR_TYPE_PARSER, ERR_WRONG_TYPE_FUN, - node->funcall.name->line, node->funcall.name->col); - return false; - } - if (array_size(node->funcall.args) != array_size(val->fun.param_types)) { - push_error(ERR_TYPE_PARSER, ERR_BAD_ARGS, node->line, node->col); - return false; - } - node->expr_type = node->funcall.name->expr_type; - for (size_t i = 0; i < array_size(node->funcall.args); ++i) { - Node *arg = node->funcall.args[i]; - if (!resolve_type(scope, arg)) { - return false; - } - Node *expected = val->fun.param_types[i]; - if (!sv_equal(&arg->expr_type->name, &expected->string)) { - push_error(ERR_TYPE_PARSER, ERR_TYPE_MISMATCH, - arg->line, arg->col); - return false; - } - } - } break; - default: break; - } - return true; -} - -bool -semantic_analysis(Parser *parser) { - // Fill up global function symbols. - Scope *scope = parser->parse_tree->global_scope; - for (size_t i = 0; i < array_size(parser->parse_tree->roots); ++i) { - Node *root = parser->parse_tree->roots[i]; - if (root->type == NODE_FUN) { - Node *name = root->fun.name; - SymbolValue *fun = alloc_symval(root->fun.name, SYMBOL_FUN); - fun->fun.param_types = root->fun.param_types; - fun->fun.return_type = root->fun.return_type; - if (!insert_symbol(scope, name, fun)) { - return false; - } - } - } - - for (size_t i = 0; i < array_size(parser->parse_tree->roots); ++i) { - // Fill up symbol tables in proper scope and resolve type of expression - // for all elements. - if (!resolve_type(scope, parser->parse_tree->roots[i])) { - return false; - } + array_push(parser->roots, node); } - return true; } -ParseTree * +Root * parse(Token *tokens) { Parser parser = { .tokens = tokens, .current_token = 0, }; - parser.parse_tree = alloc_parsetree(); + array_init(parser.roots, 0); if (!parse_roots(&parser)) { return NULL; } - if (!semantic_analysis(&parser)) { - return NULL; - } - - return parser.parse_tree; + return parser.roots; } diff --git a/src/parser.h b/src/parser.h index cc3ba92..206ca4c 100644 --- a/src/parser.h +++ b/src/parser.h @@ -3,32 +3,16 @@ #include "lexer.h" #include "nodes.h" -#include "hashtable.h" -typedef struct Type { - StringView name; - size_t size; // (bytes) -} Type; - -typedef struct Scope { - struct Scope *parent; - HashTable *symbols; - HashTable *types; -} Scope; - -typedef struct ParseTree { - Node **roots; - Scope *global_scope; - Scope *current_scope; -} ParseTree; +typedef Node* Root; typedef struct Parser { Token *tokens; size_t current_token; - ParseTree *parse_tree; + Root *roots; } Parser; -ParseTree * parse(Token *tokens); +Root * parse(Token *tokens); Node * parse_next(Parser *parser); #endif // BDL_PARSER_H diff --git a/src/semantic.c b/src/semantic.c new file mode 100644 index 0000000..06958b9 --- /dev/null +++ b/src/semantic.c @@ -0,0 +1,527 @@ +#include "hashtable.h" + +typedef struct Scope { + struct Scope *parent; + HashTable *symbols; + HashTable *types; +} Scope; + +typedef struct ParseTree { + Root *roots; + Scope *global_scope; + Scope *current_scope; +} ParseTree; + +typedef struct Type { + StringView name; + size_t size; // (bytes) +} Type; + +typedef enum DefaultType { + TYPE_VOID, + TYPE_BOOL, + TYPE_STR, + TYPE_U8, + TYPE_U16, + TYPE_U32, + TYPE_U64, + TYPE_S8, + TYPE_S16, + TYPE_S32, + TYPE_S64, + TYPE_F32, + TYPE_F64, +} DefaultType; + +static Type default_types[] = { + [TYPE_VOID] = {STRING("void"), 0}, + [TYPE_BOOL] = {STRING("bool"), 1}, + [TYPE_STR] = {STRING("str"), 16}, // size (8) + pointer to data (8). + [TYPE_U8] = {STRING("u8"), 1}, + [TYPE_U16] = {STRING("u16"), 2}, + [TYPE_U32] = {STRING("u32"), 4}, + [TYPE_U64] = {STRING("u64"), 8}, + [TYPE_S8] = {STRING("s8"), 1}, + [TYPE_S16] = {STRING("s16"), 2}, + [TYPE_S32] = {STRING("s32"), 4}, + [TYPE_S64] = {STRING("s64"), 8}, + [TYPE_F32] = {STRING("f32"), 4}, + [TYPE_F64] = {STRING("f64"), 8}, +}; + +typedef enum SymbolType { + SYMBOL_VAR, + SYMBOL_FUN, +} SymbolType; + +typedef struct Symbol { + Node *name; + SymbolType type; + + union { + struct { + Node *type; + } var; + + struct { + Node **param_types; + Node *return_type; + } fun; + }; +} Symbol; + + +Symbol * +alloc_symval(Node *name, SymbolType type) { + Symbol *val = malloc(sizeof(Symbol)); + val->name = name; + val->type = type; + return val; +} + +u64 sym_hash(const struct HashTable *table, void *bytes) { + Node *symbol = bytes; + u64 hash = _xor_shift_hash(symbol->string.start, symbol->string.n); + hash = _fibonacci_hash(hash, table->shift_amount); + return hash; +} + +bool sym_eq(void *a, void *b) { + Node *a_node = a; + Node *b_node = b; + assert(a_node->type == NODE_SYMBOL); + assert(b_node->type == NODE_SYMBOL); + return sv_equal(&a_node->string, &b_node->string); +} + +u64 type_hash(const struct HashTable *table, void *bytes) { + StringView *type = bytes; + u64 hash = _xor_shift_hash(type->start, type->n); + hash = _fibonacci_hash(hash, table->shift_amount); + return hash; +} + +bool type_eq(void *a, void *b) { + StringView *a_type = a; + StringView *b_type = b; + return sv_equal(a_type, b_type); +} + +Scope * +alloc_scope(Scope *parent) { + Scope *scope = malloc(sizeof(Scope)); + scope->parent = parent; + scope->symbols = ht_init(sym_hash, sym_eq); + scope->types = ht_init(type_hash, type_eq); + return scope; +} + +Type * +find_type(Scope *scope, Node *type) { + // TODO: Normally default types will be used more often. Since we don't + // allow type shadowing, we should search first on the global scope. + while (scope != NULL) { + Type *ret = ht_lookup(scope->types, &type->string); + if (ret != NULL) { + return ret; + } + scope = scope->parent; + } + push_error(ERR_TYPE_PARSER, ERR_UNKNOWN_TYPE, type->line, type->col); + return NULL; +} + +bool +insert_symbol(Scope *scope, Node *symbol, Symbol *val) { + // Check if symbol already exists. + HashTable *symbols = scope->symbols; + if (ht_lookup(symbols, symbol) != NULL) { + push_error(ERR_TYPE_PARSER, ERR_SYMBOL_REDEF, symbol->line, symbol->col); + return false; + } + ht_insert(symbols, symbol, val); + return true; +} + +Type * +coerce_numeric_types(Type *a, Type *b) { + // TODO: Decide what to do with mixed numeric types. What are the promotion + // rules, etc. + if (a == &default_types[TYPE_U8]) { + if (b == &default_types[TYPE_U16] || + b == &default_types[TYPE_U32] || + b == &default_types[TYPE_U64]) { + return b; + } + } else if (a == &default_types[TYPE_U16]) { + if (b == &default_types[TYPE_U32] || + b == &default_types[TYPE_U64]) { + return b; + } + } else if (a == &default_types[TYPE_U32]) { + if (b == &default_types[TYPE_U64]) { + return b; + } + } else if (a == &default_types[TYPE_S8]) { + if (b == &default_types[TYPE_S16] || + b == &default_types[TYPE_S32] || + b == &default_types[TYPE_S64]) { + return b; + } + } else if (a == &default_types[TYPE_S16]) { + if (b == &default_types[TYPE_S32] || + b == &default_types[TYPE_S64]) { + return b; + } + } else if (a == &default_types[TYPE_S32]) { + if (b == &default_types[TYPE_S64]) { + return b; + } + } else if (a == &default_types[TYPE_F32]) { + if (b == &default_types[TYPE_F64]) { + return b; + } + } + return a; +} + +bool +type_is_numeric(Type *t) { + if (t == &default_types[TYPE_U8] || + t == &default_types[TYPE_U16] || + t == &default_types[TYPE_U32] || + t == &default_types[TYPE_U64] || + t == &default_types[TYPE_S8] || + t == &default_types[TYPE_S16] || + t == &default_types[TYPE_S32] || + t == &default_types[TYPE_S64] || + t == &default_types[TYPE_F32] || + t == &default_types[TYPE_F64]) { + return true; + } + return false; +} + +Symbol * +find_symbol(Scope *scope, Node *node) { + while (scope != NULL) { + Symbol *val = ht_lookup(scope->symbols, node); + if (val != NULL) { + return val; + } + scope = scope->parent; + } + push_error(ERR_TYPE_PARSER, ERR_UNKNOWN_SYMBOL, node->line, node->col); + return NULL; +} + +bool +resolve_type(Scope *scope, Node *node) { + if (node->expr_type != NULL) { + return true; + } + switch (node->type) { + case NODE_BUILTIN: { + for (size_t i = 0; i < array_size(node->builtin.args); ++i) { + Node *arg = node->builtin.args[i]; + if (!resolve_type(scope, arg)) { + return false; + } + } + switch (node->builtin.type) { + // Numbers. + case TOKEN_ADD: + case TOKEN_SUB: + case TOKEN_MUL: + case TOKEN_DIV: + case TOKEN_MOD: { + Type *type = NULL; + for (size_t i = 0; i < array_size(node->builtin.args); ++i) { + Node *arg = node->builtin.args[i]; + + // Check that all arguments are numbers. + if (!type_is_numeric(arg->expr_type)) { + push_error(ERR_TYPE_PARSER, ERR_WRONG_TYPE_NUM, + arg->line, arg->col); + return false; + } + + if (type == NULL) { + type = arg->expr_type; + } else if (type != arg->expr_type) { + type = coerce_numeric_types(type, arg->expr_type); + } + } + node->expr_type = type; + } break; + // Bools. + case TOKEN_NOT: + case TOKEN_AND: + case TOKEN_OR: { + // Check that all arguments are boolean. + for (size_t i = 0; i < array_size(node->builtin.args); ++i) { + Node *arg = node->builtin.args[i]; + if (arg->expr_type != &default_types[TYPE_BOOL]) { + push_error(ERR_TYPE_PARSER, ERR_WRONG_TYPE_BOOL, + arg->line, arg->col); + return false; + } + } + node->expr_type = &default_types[TYPE_BOOL]; + } break; + case TOKEN_EQ: + case TOKEN_LT: + case TOKEN_GT: + case TOKEN_LE: + case TOKEN_GE: { + // Check that all arguments are nums. + for (size_t i = 0; i < array_size(node->builtin.args); ++i) { + Node *arg = node->builtin.args[i]; + if (!type_is_numeric(arg->expr_type)) { + push_error(ERR_TYPE_PARSER, ERR_WRONG_TYPE_NUM, + arg->line, arg->col); + return false; + } + } + node->expr_type = &default_types[TYPE_BOOL]; + } break; + default: break; + } + } break; + case NODE_SYMBOL: { + Symbol *val = find_symbol(scope, node); + if (val == NULL) { + return false; + } + + Type *type = NULL; + switch (val->type) { + case SYMBOL_VAR: { + type = find_type(scope, val->var.type); + } break; + case SYMBOL_FUN: { + type = find_type(scope, val->fun.return_type); + } break; + } + if (type == NULL) { + return false; + } + node->expr_type = type; + } break; + case NODE_FUN: { + // Fill up new scope with parameters + scope = alloc_scope(scope); + + // Parameters. + for (size_t i = 0; i < array_size(node->fun.param_names); ++i) { + Node *param = node->fun.param_names[i]; + Node *type = node->fun.param_types[i]; + Symbol *var = alloc_symval(param, SYMBOL_VAR); + var->var.type = type; + if (!insert_symbol(scope, param, var)) { + return false; + } + } + + // Body. + Node *body = node->fun.body; + if (body->type == NODE_BLOCK) { + for (size_t i = 0; i < array_size(body->block.expr); ++i) { + Node *expr = body->block.expr[i]; + if (!resolve_type(scope, expr)) { + return false; + } + } + Node *last_expr = body->block.expr[array_size(body->block.expr) - 1]; + node->expr_type = last_expr->expr_type; + } else { + if (!resolve_type(scope, body)) { + return false; + } + } + + // Check that the type of body matches the return type. + StringView *type_body = &node->fun.body->expr_type->name; + StringView *return_type = &node->fun.return_type->string; + if (!sv_equal(type_body, return_type)) { + push_error(ERR_TYPE_PARSER, ERR_WRONG_RET_TYPE, node->line, node->col); + return false; + } + } break; + case NODE_BLOCK: { + scope = alloc_scope(scope); + for (size_t i = 0; i < array_size(node->block.expr); ++i) { + Node *expr = node->block.expr[i]; + if (!resolve_type(scope, expr)) { + return false; + } + } + Node *last_expr = node->block.expr[array_size(node->block.expr) - 1]; + node->expr_type = last_expr->expr_type; + } break; + case NODE_IF: { + if (!resolve_type(scope, node->ifexpr.cond)) { + return false; + } + if (!resolve_type(scope, node->ifexpr.expr_true)) { + return false; + } + Type *type_true = node->ifexpr.expr_true->expr_type; + node->expr_type = type_true; + if (node->ifexpr.expr_false != NULL) { + if (!resolve_type(scope, node->ifexpr.expr_false)) { + return false; + } + } + + // Check ifexpr.cond is a bool. + Type *type_cond = node->ifexpr.cond->expr_type; + if (!sv_equal(&type_cond->name, &default_types[TYPE_BOOL].name)) { + push_error(ERR_TYPE_PARSER, ERR_WRONG_COND_TYPE, + node->line, node->col); + return false; + } + + // Check if types of expr_true and expr_false match + if (node->ifexpr.expr_false != NULL) { + Type *type_false = node->ifexpr.expr_false->expr_type; + if (type_is_numeric(type_true) && type_is_numeric(type_false)) { + node->expr_type = coerce_numeric_types(type_true, type_false); + } else if (!sv_equal(&type_true->name, &type_false->name)) { + push_error(ERR_TYPE_PARSER, ERR_WRONG_TYPE_T_F, + node->line, node->col); + return false; + } + } + } break; + case NODE_SET: { + node->expr_type = &default_types[TYPE_VOID]; + if (!resolve_type(scope, node->set.symbol)) { + return false; + } + if (!resolve_type(scope, node->set.value)) { + return false; + } + Node *symbol = node->set.symbol; + Node *value = node->set.value; + if (!sv_equal(&symbol->expr_type->name, &value->expr_type->name)) { + push_error(ERR_TYPE_PARSER, ERR_TYPE_MISMATCH, + node->line, node->col); + return false; + } + } break; + case NODE_DEF: { + // Prepare value for symbol table. + Symbol *var = alloc_symval(node->def.symbol, SYMBOL_VAR); + var->var.type = node->def.type; + if (!insert_symbol(scope, node->def.symbol, var)) { + return false; + } + + Type *type = find_type(scope, node->def.type); + if (type == NULL) { + return false; + } + node->def.symbol->expr_type = type; + + node->expr_type = &default_types[TYPE_VOID]; + // TODO: type inference from right side when not annotated? + if (!resolve_type(scope, node->def.value)) { + return false; + } + Node *symbol = node->def.symbol; + Node *value = node->def.value; + if (!sv_equal(&symbol->expr_type->name, &value->expr_type->name)) { + push_error(ERR_TYPE_PARSER, ERR_TYPE_MISMATCH, + node->line, node->col); + return false; + } + } break; + case NODE_NUMBER: { + // TODO: Numbers are f64/s64 unless explicitely annotated. Annotated + // numbers must fit in the given range (e.g. no negative constants + // inside a u64, no numbers bigger than 255 in a u8, etc.). + if (node->number.fractional != 0) { + node->expr_type = &default_types[TYPE_F64]; + } else { + node->expr_type = &default_types[TYPE_S64]; + } + } break; + case NODE_BOOL: { + node->expr_type = &default_types[TYPE_BOOL]; + } break; + case NODE_STRING: { + node->expr_type = &default_types[TYPE_STR]; + } break; + case NODE_FUNCALL: { + Symbol *val = find_symbol(scope, node->funcall.name); + if (!resolve_type(scope, node->funcall.name)) { + return false; + } + if (val->type != SYMBOL_FUN) { + push_error(ERR_TYPE_PARSER, ERR_WRONG_TYPE_FUN, + node->funcall.name->line, node->funcall.name->col); + return false; + } + if (array_size(node->funcall.args) != array_size(val->fun.param_types)) { + push_error(ERR_TYPE_PARSER, ERR_BAD_ARGS, node->line, node->col); + return false; + } + node->expr_type = node->funcall.name->expr_type; + for (size_t i = 0; i < array_size(node->funcall.args); ++i) { + Node *arg = node->funcall.args[i]; + if (!resolve_type(scope, arg)) { + return false; + } + Node *expected = val->fun.param_types[i]; + if (!sv_equal(&arg->expr_type->name, &expected->string)) { + push_error(ERR_TYPE_PARSER, ERR_TYPE_MISMATCH, + arg->line, arg->col); + return false; + } + } + } break; + default: break; + } + return true; +} + +ParseTree * +semantic_analysis(Root *roots) { + ParseTree *parse_tree = malloc(sizeof(ParseTree)); + parse_tree->roots = roots; + parse_tree->global_scope = alloc_scope(NULL); + parse_tree->current_scope = parse_tree->global_scope; + + // Fill global scope with default types. + HashTable *types = parse_tree->global_scope->types; + for (size_t i = 0; i < sizeof(default_types)/sizeof(Type); ++i) { + Type *type = &default_types[i]; + ht_insert(types, &type->name, type); + } + + // Fill up global function symbols. + Scope *scope = parse_tree->global_scope; + for (size_t i = 0; i < array_size(parse_tree->roots); ++i) { + Node *root = parse_tree->roots[i]; + if (root->type == NODE_FUN) { + Node *name = root->fun.name; + Symbol *fun = alloc_symval(root->fun.name, SYMBOL_FUN); + fun->fun.param_types = root->fun.param_types; + fun->fun.return_type = root->fun.return_type; + if (!insert_symbol(scope, name, fun)) { + return NULL; + } + } + } + + for (size_t i = 0; i < array_size(parse_tree->roots); ++i) { + // Fill up symbol tables in proper scope and resolve type of expression + // for all elements. + if (!resolve_type(scope, parse_tree->roots[i])) { + return NULL; + } + } + + return parse_tree; +} diff --git a/src/viz.c b/src/viz.c index 81cc1ff..d519d2c 100644 --- a/src/viz.c +++ b/src/viz.c @@ -152,8 +152,8 @@ viz_node(Node *node) { } void -viz_ast(ParseTree *parse_tree) { - if (parse_tree == NULL) { +viz_ast(Root *roots) { + if (roots == NULL) { return; } printf("digraph ast {\n"); @@ -161,9 +161,9 @@ viz_ast(ParseTree *parse_tree) { printf("ranksep=\"0.95 equally\";\n"); printf("nodesep=\"0.5 equally\";\n"); printf("overlap=scale;\n"); - for (size_t i = 0; i < array_size(parse_tree->roots); ++i) { + for (size_t i = 0; i < array_size(roots); ++i) { printf("subgraph %zu {\n", i); - Node *root = parse_tree->roots[array_size(parse_tree->roots) - 1 - i]; + Node *root = roots[array_size(roots) - 1 - i]; viz_node(root); printf("}\n"); } -- cgit v1.2.1