From dcd3192e50d7b4ea333ecf57a7e8b325af145547 Mon Sep 17 00:00:00 2001 From: Bad Diode Date: Mon, 18 Apr 2022 15:51:53 -0300 Subject: Add a more rich symbol table value and typecheck funcall args --- src/errors.c | 2 + src/errors.h | 2 + src/parser.c | 127 ++++++++++++++++++++++++++++++++++++++++++++--------------- 3 files changed, 100 insertions(+), 31 deletions(-) (limited to 'src') diff --git a/src/errors.c b/src/errors.c index e2dee00..6a69064 100644 --- a/src/errors.c +++ b/src/errors.c @@ -24,6 +24,8 @@ static const char* error_msgs[] = { [ERR_WRONG_TYPE_T_F] = "error: unmatched types between true and false expression", [ERR_WRONG_TYPE_NUM] = "error: non numeric argument types", [ERR_WRONG_TYPE_BOOL] = "error: non bool argument types", + [ERR_WRONG_TYPE_FUN] = "error: not a function", + [ERR_BAD_ARGS] = "error: arguments don't match expected types", [ERR_TYPE_MISMATCH] = "error: type mismatch", }; diff --git a/src/errors.h b/src/errors.h index a814549..f2737d0 100644 --- a/src/errors.h +++ b/src/errors.h @@ -32,7 +32,9 @@ typedef enum ErrorValue { ERR_WRONG_TYPE_T_F, ERR_WRONG_TYPE_NUM, ERR_WRONG_TYPE_BOOL, + ERR_WRONG_TYPE_FUN, ERR_TYPE_MISMATCH, + ERR_BAD_ARGS, ERR_OK, } ErrorValue; diff --git a/src/parser.c b/src/parser.c index 10d82d1..0cf70d7 100644 --- a/src/parser.c +++ b/src/parser.c @@ -33,6 +33,27 @@ static Type default_types[] = { [TYPE_F64] = {STRING("f64"), 8}, }; +typedef enum SymbolType { + SYMBOL_VAR, + SYMBOL_FUN, +} SymbolType; + +typedef struct SymbolValue { + Node *name; + SymbolType type; + + union { + struct { + Node *type; + } var; + + struct { + Node **param_types; + Node *return_type; + } fun; + }; +} SymbolValue; + Token next_token(Parser *parser) { return parser->tokens[parser->current_token++]; @@ -542,6 +563,14 @@ alloc_parsetree(void) { return parse_tree; } +SymbolValue * +alloc_symval(Node *name, SymbolType type) { + SymbolValue *val = malloc(sizeof(SymbolValue)); + val->name = name; + val->type = type; + return val; +} + bool parse_roots(Parser *parser) { while (has_next(parser)) { @@ -573,36 +602,22 @@ find_type(Scope *scope, Node *type) { return NULL; } -Type * -find_symbol(Scope *scope, Node *node) { - while (scope != NULL) { - Type *type = ht_lookup(scope->symbols, node); - if (type != NULL) { - return type; - } - scope = scope->parent; - } - push_error(ERR_TYPE_PARSER, ERR_UNKNOWN_SYMBOL, node->line, node->col); - return NULL; -} - -Type * -insert_symbol(Scope *scope, Node *symbol, Node *type) { +bool +insert_symbol(Scope *scope, Node *symbol, SymbolValue *val) { + // Check if symbol already exists. HashTable *symbols = scope->symbols; if (ht_lookup(symbols, symbol) != NULL) { push_error(ERR_TYPE_PARSER, ERR_SYMBOL_REDEF, symbol->line, symbol->col); - return NULL; - } - Type *t = find_type(scope, type); - if (t == NULL) { - return NULL; + return false; } - ht_insert(symbols, symbol, t); - return t; + ht_insert(symbols, symbol, val); + return true; } Type * coerce_numeric_types(Type *a, Type *b) { + // TODO: Decide what to do with mixed numeric types. What are the promotion + // rules, etc. if (a == &default_types[TYPE_U8]) { if (b == &default_types[TYPE_U16] || b == &default_types[TYPE_U32] || @@ -658,6 +673,19 @@ type_is_numeric(Type *t) { return false; } +SymbolValue * +find_symbol(Scope *scope, Node *node) { + while (scope != NULL) { + SymbolValue *val = ht_lookup(scope->symbols, node); + if (val != NULL) { + return val; + } + scope = scope->parent; + } + push_error(ERR_TYPE_PARSER, ERR_UNKNOWN_SYMBOL, node->line, node->col); + return NULL; +} + bool resolve_type(Scope *scope, Node *node) { if (node->expr_type != NULL) { @@ -732,7 +760,20 @@ resolve_type(Scope *scope, Node *node) { } } break; case NODE_SYMBOL: { - Type *type = find_symbol(scope, node); + SymbolValue *val = find_symbol(scope, node); + if (val == NULL) { + return false; + } + + Type *type = NULL; + switch (val->type) { + case SYMBOL_VAR: { + type = find_type(scope, val->var.type); + } break; + case SYMBOL_FUN: { + type = find_type(scope, val->fun.return_type); + } break; + } if (type == NULL) { return false; } @@ -746,7 +787,9 @@ resolve_type(Scope *scope, Node *node) { for (size_t i = 0; i < array_size(node->fun.param_names); ++i) { Node *param = node->fun.param_names[i]; Node *type = node->fun.param_types[i]; - if (insert_symbol(scope, param, type) == NULL) { + SymbolValue *var = alloc_symval(param, SYMBOL_VAR); + var->var.type = type; + if (!insert_symbol(scope, param, var)) { return false; } } @@ -840,7 +883,14 @@ resolve_type(Scope *scope, Node *node) { } } break; case NODE_DEF: { - Type *type = insert_symbol(scope, node->def.symbol, node->def.type); + // Prepare value for symbol table. + SymbolValue *var = alloc_symval(node->def.symbol, SYMBOL_VAR); + var->var.type = node->def.type; + if (!insert_symbol(scope, node->def.symbol, var)) { + return false; + } + + Type *type = find_type(scope, node->def.type); if (type == NULL) { return false; } @@ -862,7 +912,7 @@ resolve_type(Scope *scope, Node *node) { case NODE_NUMBER: { // TODO: Numbers are f64/s64 unless explicitely annotated. Annotated // numbers must fit in the given range (e.g. no negative constants - // inside a U64, no numbers bigger than 255 in a u8, etc.). + // inside a u64, no numbers bigger than 255 in a u8, etc.). if (node->number.fractional != 0) { node->expr_type = &default_types[TYPE_F64]; } else { @@ -876,16 +926,31 @@ resolve_type(Scope *scope, Node *node) { node->expr_type = &default_types[TYPE_STR]; } break; case NODE_FUNCALL: { - // TODO: Check that arguments type check with expectations. + SymbolValue *val = find_symbol(scope, node->funcall.name); if (!resolve_type(scope, node->funcall.name)) { return false; } + if (val->type != SYMBOL_FUN) { + push_error(ERR_TYPE_PARSER, ERR_WRONG_TYPE_FUN, + node->funcall.name->line, node->funcall.name->col); + return false; + } + if (array_size(node->funcall.args) != array_size(val->fun.param_types)) { + push_error(ERR_TYPE_PARSER, ERR_BAD_ARGS, node->line, node->col); + return false; + } node->expr_type = node->funcall.name->expr_type; for (size_t i = 0; i < array_size(node->funcall.args); ++i) { Node *arg = node->funcall.args[i]; if (!resolve_type(scope, arg)) { return false; } + Node *expected = val->fun.param_types[i]; + if (!sv_equal(&arg->expr_type->name, &expected->string)) { + push_error(ERR_TYPE_PARSER, ERR_TYPE_MISMATCH, + arg->line, arg->col); + return false; + } } } break; default: break; @@ -901,10 +966,10 @@ semantic_analysis(Parser *parser) { Node *root = parser->parse_tree->roots[i]; if (root->type == NODE_FUN) { Node *name = root->fun.name; - // TODO: make sure we store information in the symbol table that - // this is actually a function, not just a variable with - // return_type. - if (insert_symbol(scope, name, root->fun.return_type) == NULL) { + SymbolValue *fun = alloc_symval(root->fun.name, SYMBOL_FUN); + fun->fun.param_types = root->fun.param_types; + fun->fun.return_type = root->fun.return_type; + if (!insert_symbol(scope, name, fun)) { return false; } } -- cgit v1.2.1