From 233dd92768a54060df9096558aa58c1f598cce7d Mon Sep 17 00:00:00 2001 From: Bad Diode Date: Fri, 8 Apr 2022 18:49:40 -0300 Subject: Add rudimentary type checking --- src/errors.c | 5 ++ src/errors.h | 5 ++ src/nodes.c | 2 +- src/nodes.h | 10 +-- src/parser.c | 233 +++++++++++++++++++++++++++++++++++++++++++++++------------ src/parser.h | 7 +- src/viz.c | 7 ++ 7 files changed, 212 insertions(+), 57 deletions(-) (limited to 'src') diff --git a/src/errors.c b/src/errors.c index 2781bf5..b987b34 100644 --- a/src/errors.c +++ b/src/errors.c @@ -17,6 +17,11 @@ static const char* error_msgs[] = { [ERR_NOT_A_RPAREN] = "error: expected closing parentheses (rparen)", [ERR_SYMBOL_REDEF] = "error: symbol redefinition", [ERR_UNKNOWN_SYMBOL] = "error: unknown symbol", + [ERR_TYPE_REDEF] = "error: type redefinition", + [ERR_UNKNOWN_TYPE] = "error: unknown type", + [ERR_WRONG_RET_TYPE] = "error: return type don't match type signature", + [ERR_WRONG_COND_TYPE] = "error: conditional expression is not boolean", + [ERR_WRONG_TYPE_T_F] = "error: unmatched types between true and false expression", }; static Error current_error = {.value = ERR_OK}; diff --git a/src/errors.h b/src/errors.h index eb83e52..a84305b 100644 --- a/src/errors.h +++ b/src/errors.h @@ -25,6 +25,11 @@ typedef enum ErrorValue { ERR_NOT_A_RPAREN, ERR_SYMBOL_REDEF, ERR_UNKNOWN_SYMBOL, + ERR_TYPE_REDEF, + ERR_UNKNOWN_TYPE, + ERR_WRONG_RET_TYPE, + ERR_WRONG_COND_TYPE, + ERR_WRONG_TYPE_T_F, ERR_OK, } ErrorValue; diff --git a/src/nodes.c b/src/nodes.c index 35d123a..51cc9ef 100644 --- a/src/nodes.c +++ b/src/nodes.c @@ -12,7 +12,7 @@ alloc_node(NodeType type) { node->line = 0; node->col = 0; node->scope = NULL; - node->type_class = TYPE_UNK; + node->expr_type = NULL; return node; } diff --git a/src/nodes.h b/src/nodes.h index be6f7df..52022cc 100644 --- a/src/nodes.h +++ b/src/nodes.h @@ -15,21 +15,13 @@ typedef enum NodeType { NODE_IF, } NodeType; -typedef enum TypeClass { - TYPE_UNK, - TYPE_NONE, - TYPE_NUM, - TYPE_BOOL, - TYPE_STRING, -} TypeClass; - typedef struct Node { size_t id; NodeType type; size_t line; size_t col; struct Scope *scope; - TypeClass type_class; + struct Type *expr_type; union { // Numbers. diff --git a/src/parser.c b/src/parser.c index 0594d2c..ddcee56 100644 --- a/src/parser.c +++ b/src/parser.c @@ -1,6 +1,38 @@ #include "parser.h" #include "darray.h" +typedef enum DefaultType { + TYPE_VOID, + TYPE_BOOL, + TYPE_STR, + TYPE_U8, + TYPE_U16, + TYPE_U32, + TYPE_U64, + TYPE_S8, + TYPE_S16, + TYPE_S32, + TYPE_S64, + TYPE_F32, + TYPE_F64, +} DefaultType; + +static Type default_types[] = { + [TYPE_VOID] = {STRING("void"), 0}, + [TYPE_BOOL] = {STRING("bool"), 1}, + [TYPE_STR] = {STRING("str"), 16}, // size (8) + pointer to data (8). + [TYPE_U8] = {STRING("u8"), 1}, + [TYPE_U16] = {STRING("u16"), 2}, + [TYPE_U32] = {STRING("u32"), 4}, + [TYPE_U64] = {STRING("u64"), 8}, + [TYPE_S8] = {STRING("s8"), 1}, + [TYPE_S16] = {STRING("s16"), 2}, + [TYPE_S32] = {STRING("s32"), 4}, + [TYPE_S64] = {STRING("s64"), 8}, + [TYPE_F32] = {STRING("f32"), 4}, + [TYPE_F64] = {STRING("f64"), 8}, +}; + Token next_token(Parser *parser) { return parser->tokens[parser->current_token++]; @@ -445,11 +477,25 @@ bool sym_eq(void *a, void *b) { return sv_equal(&a_node->string, &b_node->string); } +u64 type_hash(const struct HashTable *table, void *bytes) { + StringView *type = bytes; + u64 hash = _xor_shift_hash(type->start, type->n); + hash = _fibonacci_hash(hash, table->shift_amount); + return hash; +} + +bool type_eq(void *a, void *b) { + StringView *a_type = a; + StringView *b_type = b; + return sv_equal(a_type, b_type); +} + Scope * alloc_scope(Scope *parent) { Scope *scope = malloc(sizeof(Scope)); scope->parent = parent; scope->symbols = ht_init(sym_hash, sym_eq); + scope->types = ht_init(type_hash, type_eq); return scope; } @@ -459,19 +505,14 @@ alloc_parsetree(void) { array_init(parse_tree->roots, 0); parse_tree->global_scope = alloc_scope(NULL); parse_tree->current_scope = parse_tree->global_scope; - // TODO: Fill global scope with default types/symbols. - return parse_tree; -} -bool -insert_symbol(Parser *parser, Node *symbol) { - HashTable *symbols = parser->parse_tree->current_scope->symbols; - if (ht_lookup(symbols, symbol) != NULL) { - push_error(ERR_TYPE_PARSER, ERR_SYMBOL_REDEF, symbol->line, symbol->col); - return false; + // Fill global scope with default types. + HashTable *types = parse_tree->global_scope->types; + for (size_t i = 0; i < sizeof(default_types)/sizeof(Type); ++i) { + Type *type = &default_types[i]; + ht_insert(types, &type->name, type); } - ht_insert(symbols, symbol, symbol); - return true; + return parse_tree; } bool @@ -486,17 +527,67 @@ parse_roots(Parser *parser) { return true; } -bool +Type * +find_type(Parser *parser, Node *type) { + // Normally default types will be used more often. Since we don't + // allow type shadowing, we search first on the global scope. + Scope *scope = parser->parse_tree->global_scope; + Type *ret = ht_lookup(scope->types, &type->string); + if (ret != NULL) { + return ret; + } + scope = parser->parse_tree->current_scope; + while (scope->parent != NULL) { + Type *ret = ht_lookup(scope->types, &type->string); + if (ret != NULL) { + return ret; + } + scope = scope->parent; + } + push_error(ERR_TYPE_PARSER, ERR_UNKNOWN_TYPE, type->line, type->col); + return NULL; +} + +// TODO: Review this function when needed +// bool +// insert_type(Parser *parser, Node *type, size_t size) { +// HashTable *types = parser->parse_tree->current_scope->types; +// if (ht_lookup(types, type) != NULL) { +// push_error(ERR_TYPE_PARSER, ERR_TYPE_REDEF, type->line, type->col); +// return false; +// } +// // TODO: alloc_type. +// ht_insert(types, &type->string, type); +// return true; +// } + +Type * find_symbol(Parser *parser, Node *node) { Scope *scope = parser->parse_tree->current_scope; while (scope != NULL) { - if (ht_lookup(scope->symbols, node) != NULL) { - return true; + Type *type = ht_lookup(scope->symbols, node); + if (type != NULL) { + return type; } scope = scope->parent; } push_error(ERR_TYPE_PARSER, ERR_UNKNOWN_SYMBOL, node->line, node->col); - return false; + return NULL; +} + +bool +insert_symbol(Parser *parser, Node *symbol, Node *type) { + HashTable *symbols = parser->parse_tree->current_scope->symbols; + if (ht_lookup(symbols, symbol) != NULL) { + push_error(ERR_TYPE_PARSER, ERR_SYMBOL_REDEF, symbol->line, symbol->col); + return false; + } + Type *t = find_type(parser, type); + if (t == NULL) { + return NULL; + } + ht_insert(symbols, symbol, t); + return true; } bool @@ -511,13 +602,13 @@ symbol_check(Parser *parser, Node *node) { } } break; case NODE_SYMBOL: { - if (!find_symbol(parser, node)) { + if (find_symbol(parser, node) == NULL) { return false; } } break; case NODE_SET: { Node *symbol = node->set.symbol; - if (!find_symbol(parser, symbol)) { + if (find_symbol(parser, symbol) == NULL) { return false; } if (!symbol_check(parser, node->set.value)) { @@ -532,7 +623,8 @@ symbol_check(Parser *parser, Node *node) { // Parameters. for (size_t i = 0; i < array_size(node->fun.param_names); ++i) { Node *param = node->fun.param_names[i]; - if (!insert_symbol(parser, param)) { + Node *type = node->fun.param_types[i]; + if (!insert_symbol(parser, param, type)) { return false; } } @@ -583,7 +675,7 @@ symbol_check(Parser *parser, Node *node) { if (!symbol_check(parser, node->def.value)) { return false; } - if (!insert_symbol(parser, node->def.symbol)) { + if (!insert_symbol(parser, node->def.symbol, node->def.type)) { return false; } } break; @@ -593,7 +685,15 @@ symbol_check(Parser *parser, Node *node) { } bool -resolve_typeclass(Parser *parser, Node *node) { +resolve_type(Parser *parser, Node *node) { + if (node->expr_type != NULL) { + return true; + } + Scope *prev_scope = NULL; + if (node->scope != NULL) { + prev_scope = parser->parse_tree->current_scope; + parser->parse_tree->current_scope = node->scope; + } switch (node->type) { case NODE_BUILTIN: { switch (node->builtin.type) { @@ -602,7 +702,10 @@ resolve_typeclass(Parser *parser, Node *node) { case TOKEN_SUB: case TOKEN_MUL: case TOKEN_DIV: - case TOKEN_MOD: { node->type_class = TYPE_NUM; } break; + case TOKEN_MOD: { + // TODO: Properly resolve this + node->expr_type = &default_types[TYPE_U64]; + } break; // Bools. case TOKEN_NOT: case TOKEN_AND: @@ -611,49 +714,87 @@ resolve_typeclass(Parser *parser, Node *node) { case TOKEN_LT: case TOKEN_GT: case TOKEN_LE: - case TOKEN_GE: { node->type_class = TYPE_BOOL; } break; + case TOKEN_GE: { + node->expr_type = &default_types[TYPE_BOOL]; + } break; default: break; } for (size_t i = 0; i < array_size(node->builtin.args); ++i) { Node *arg = node->builtin.args[i]; - resolve_typeclass(parser, arg); + resolve_type(parser, arg); } } break; case NODE_SYMBOL: { - // TODO: Resolve symbol type? + node->expr_type = find_symbol(parser, node); } break; case NODE_FUN: { - // TODO: Resolve `node->type_class` based on the return value. - resolve_typeclass(parser, node->fun.body); + resolve_type(parser, node->fun.body); + StringView *type_body = &node->fun.body->expr_type->name; + StringView *return_type = &node->fun.return_type->string; + // Check that the type of body matches the return type. + if (!sv_equal(type_body, return_type)) { + push_error(ERR_TYPE_PARSER, ERR_WRONG_RET_TYPE, node->line, node->col); + return false; + } } break; case NODE_BLOCK: { for (size_t i = 0; i < array_size(node->block.expr); ++i) { Node *expr = node->block.expr[i]; - resolve_typeclass(parser, expr); + resolve_type(parser, expr); } Node *last_expr = node->block.expr[array_size(node->block.expr) - 1]; - node->type_class = last_expr->type_class; + node->expr_type = last_expr->expr_type; } break; case NODE_IF: { - node->type_class = TYPE_NONE; - resolve_typeclass(parser, node->ifexpr.cond); - resolve_typeclass(parser, node->ifexpr.expr_true); + resolve_type(parser, node->ifexpr.cond); + // Check ifexpr.cond is a bool. + if (!sv_equal(&node->ifexpr.cond->expr_type->name, &default_types[TYPE_BOOL].name)) { + push_error(ERR_TYPE_PARSER, ERR_WRONG_COND_TYPE, node->line, node->col); + return false; + } + + resolve_type(parser, node->ifexpr.expr_true); if (node->ifexpr.expr_false != NULL) { - resolve_typeclass(parser, node->ifexpr.expr_false); + resolve_type(parser, node->ifexpr.expr_false); + // Check if types of expr_true and expr_false match + if (!sv_equal(&node->ifexpr.expr_true->expr_type->name, &node->ifexpr.expr_false->expr_type->name)) { + push_error(ERR_TYPE_PARSER, ERR_WRONG_TYPE_T_F, node->line, node->col); + return false; + } } + node->expr_type = node->ifexpr.expr_true->expr_type; } break; case NODE_SET: { - node->type_class = TYPE_NONE; - resolve_typeclass(parser, node->set.value); + node->expr_type = &default_types[TYPE_VOID]; + resolve_type(parser, node->set.value); } break; case NODE_DEF: { - node->type_class = TYPE_NONE; - resolve_typeclass(parser, node->def.value); + node->expr_type = &default_types[TYPE_VOID]; + resolve_type(parser, node->def.value); } break; - case NODE_NUMBER: { node->type_class = TYPE_NUM; } break; - case NODE_BOOL: { node->type_class = TYPE_BOOL; } break; - case NODE_STRING: { node->type_class = TYPE_STRING; } break; - case NODE_TYPE: { node->type_class = TYPE_NONE; } break; + case NODE_NUMBER: { + // TODO: Is this the best way of doing it? We probably need a more + // sophisticated way of approaching this. For example: + // `(if (< 1 2) 1 -2)` will currently fail with this approach, since + // 1:u64 and -2:s64 + if (node->number.fractional != 0) { + node->expr_type = &default_types[TYPE_F64]; + } else if (node->number.negative) { + node->expr_type = &default_types[TYPE_S64]; + } else { + node->expr_type = &default_types[TYPE_U64]; + } + } break; + case NODE_BOOL: { + node->expr_type = &default_types[TYPE_BOOL]; + } break; + case NODE_STRING: { + node->expr_type = &default_types[TYPE_STR]; + } break; + default: break; + } + if (node->scope != NULL) { + parser->parse_tree->current_scope = prev_scope; } return true; } @@ -665,7 +806,10 @@ semantic_analysis(Parser *parser) { Node *root = parser->parse_tree->roots[i]; if (root->type == NODE_FUN) { Node *name = root->fun.name; - if (!insert_symbol(parser, name)) { + // TODO: make sure we store information in the symbol table that + // this is actually a function, not just a variable with + // return_type. + if (!insert_symbol(parser, name, root->fun.return_type)) { return false; } } @@ -674,13 +818,10 @@ semantic_analysis(Parser *parser) { for (size_t i = 0; i < array_size(parser->parse_tree->roots); ++i) { // Fill up symbol tables in proper scope and check existance. symbol_check(parser, parser->parse_tree->roots[i]); - // Resolve TypeClass for all elements. - resolve_typeclass(parser, parser->parse_tree->roots[i]); + // Resolve type of expression for all elements. + resolve_type(parser, parser->parse_tree->roots[i]); } - // TODO: Resolve concrete types. - // TODO: Type check. - return true; } diff --git a/src/parser.h b/src/parser.h index 3c2dc2b..cc3ba92 100644 --- a/src/parser.h +++ b/src/parser.h @@ -5,10 +5,15 @@ #include "nodes.h" #include "hashtable.h" +typedef struct Type { + StringView name; + size_t size; // (bytes) +} Type; + typedef struct Scope { struct Scope *parent; HashTable *symbols; - // HashTable types; + HashTable *types; } Scope; typedef struct ParseTree { diff --git a/src/viz.c b/src/viz.c index d409472..8b5d4cf 100644 --- a/src/viz.c +++ b/src/viz.c @@ -19,6 +19,10 @@ viz_node(Node *node) { } printf("%zu [width=2.5,shape=Mrecord,label=\"", node->id); printf(" %s -- [%4ld:%-4ld] ", node_str[node->type], node->line, node->col); + if (node->expr_type != NULL) { + printf("| T: "); + sv_write(&node->expr_type->name); + } switch (node->type) { case NODE_NUMBER: { printf("| Value: "); @@ -136,6 +140,9 @@ viz_node(Node *node) { void viz_ast(ParseTree *parse_tree) { + if (parse_tree == NULL) { + return; + } printf("digraph ast {\n"); printf("rankdir=LR;\n"); printf("ranksep=\"0.95 equally\";\n"); -- cgit v1.2.1