From c6fd7856bfe5dd0567f672d0d1a70a3b698feaa4 Mon Sep 17 00:00:00 2001 From: Bad Diode Date: Sun, 23 Jun 2024 15:55:34 +0200 Subject: Add more expressions to type inference method --- src/main.c | 126 +++++++++++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 101 insertions(+), 25 deletions(-) (limited to 'src') diff --git a/src/main.c b/src/main.c index 3614e56..077039a 100644 --- a/src/main.c +++ b/src/main.c @@ -79,6 +79,8 @@ typedef struct Analyzer { Str file_name; sz scope_gen; Scope **scopes; + StrSet *numeric_types; + StrSet *integer_types; } Analyzer; Scope * @@ -195,7 +197,83 @@ type_inference(Analyzer *a, Node *node, Scope *scope) { assert(a); assert(scope); assert(node); + // NOTE: For now we are not going to do implicit numeric conversions. switch (node->kind) { + case NODE_TRUE: + case NODE_FALSE: return cstr("bool"); + case NODE_NOT: + case NODE_AND: + case NODE_OR: { + Type left = type_inference(a, node->left, scope); + if (!str_eq(left, cstr("bool"))) { + emit_semantic_error(a, node, + cstr("expected bool on logic expression")); + return (Type){0}; + } + if (node->right) { + Type right = type_inference(a, node->right, scope); + if (!str_eq(right, cstr("bool"))) { + emit_semantic_error( + a, node, cstr("expected bool on logic expression")); + return (Type){0}; + } + } + return cstr("bool"); + } break; + case NODE_EQ: + case NODE_NEQ: + case NODE_LT: + case NODE_GT: + case NODE_LE: + case NODE_GE: { + Type left = type_inference(a, node->left, scope); + Type right = type_inference(a, node->right, scope); + if (!str_eq(left, right)) { + emit_semantic_error( + a, node, cstr("mismatched types on binary expression")); + return (Type){0}; + } + return cstr("bool"); + } break; + case NODE_BITNOT: + case NODE_BITAND: + case NODE_BITOR: + case NODE_BITLSHIFT: + case NODE_BITRSHIFT: { + Type left = type_inference(a, node->left, scope); + Type right = type_inference(a, node->right, scope); + if (!strset_lookup(&a->integer_types, left) || + !strset_lookup(&a->integer_types, right)) { + emit_semantic_error( + a, node, cstr("non integer type on bit twiddling expr")); + return (Type){0}; + } + if (!str_eq(left, right)) { + emit_semantic_error( + a, node, cstr("mismatched types on binary expression")); + return (Type){0}; + } + } break; + case NODE_ADD: + case NODE_SUB: + case NODE_DIV: + case NODE_MUL: + case NODE_MOD: { + Type left = type_inference(a, node->left, scope); + Type right = type_inference(a, node->right, scope); + if (!strset_lookup(&a->numeric_types, left) || + !strset_lookup(&a->numeric_types, right)) { + emit_semantic_error( + a, node, cstr("non numeric type on arithmetic expr")); + return (Type){0}; + } + if (!str_eq(left, right)) { + emit_semantic_error( + a, node, cstr("mismatched types on binary expression")); + return (Type){0}; + } + return left; + } break; case NODE_NUM_UINT: return cstr("uint"); break; case NODE_NUM_INT: return cstr("int"); break; case NODE_NUM_FLOAT: return cstr("f64"); break; @@ -471,7 +549,7 @@ symbolic_analysis(Analyzer *a, Parser *parser) { assert(parser); assert(scope); - // Fill builtin functions. + // Fill builtin tables. Str builtin_functions[] = { cstr("print"), cstr("println"), @@ -481,35 +559,33 @@ symbolic_analysis(Analyzer *a, Parser *parser) { Symbol sym = (Symbol){.kind = SYM_BUILTIN, .name = symbol}; symmap_insert(&scope->symbols, symbol, sym, a->storage); } - - // Fill builtin types. Type builtin_types[] = { - cstr("u8"), cstr("s8"), cstr("u16"), cstr("s16"), cstr("u32"), - cstr("s32"), cstr("u64"), cstr("s64"), cstr("f32"), cstr("f64"), - cstr("ptr"), cstr("int"), cstr("uint"), cstr("str"), - }; + cstr("u8"), cstr("s8"), cstr("u16"), cstr("s16"), + cstr("u32"), cstr("s32"), cstr("u64"), cstr("s64"), + cstr("f32"), cstr("f64"), cstr("ptr"), cstr("int"), + cstr("uint"), cstr("str"), cstr("bool"), cstr("nil")}; for (sz i = 0; i < LEN(builtin_types); i++) { Type type = builtin_types[i]; typemap_insert(&scope->types, type, type, a->storage); } - - // Str valid_int_types[] = { - // cstr("u8"), cstr("u16"), cstr("u32"), cstr("u64"), cstr("s8"), - // cstr("s16"), cstr("s32"), cstr("s64"), cstr("int"), cstr("uint"), - // }; - // for (sz i = 0; i < LEN(valid_int_types); i++) { - // Str type = valid_int_types[i]; - // strset_insert(&a->valid_int_types, type, a->storage); - // } - - // Str valid_float_types[] = { - // cstr("f32"), - // cstr("f64"), - // }; - // for (sz i = 0; i < LEN(valid_float_types); i++) { - // Str type = valid_float_types[i]; - // strset_insert(&a->valid_float_types, type, a->storage); - // } + Type numeric_types[] = { + cstr("u8"), cstr("s8"), cstr("u16"), cstr("s16"), cstr("u32"), + cstr("s32"), cstr("u64"), cstr("s64"), cstr("f32"), cstr("f64"), + cstr("ptr"), cstr("int"), cstr("uint"), + }; + for (sz i = 0; i < LEN(numeric_types); i++) { + Type type = numeric_types[i]; + strset_insert(&a->numeric_types, type, a->storage); + } + Type integer_types[] = { + cstr("u8"), cstr("s8"), cstr("u16"), cstr("s16"), + cstr("u32"), cstr("s32"), cstr("u64"), cstr("s64"), + cstr("ptr"), cstr("int"), cstr("uint"), + }; + for (sz i = 0; i < LEN(integer_types); i++) { + Type type = integer_types[i]; + strset_insert(&a->integer_types, type, a->storage); + } // Find top level function declarations. for (sz i = 0; i < array_size(parser->nodes); i++) { -- cgit v1.2.1