#ifndef SEMANTIC_C
#define SEMANTIC_C
#include "badlib.h"
#include "parser.c"
// Kinds of entries stored in a Scope's symbol table (Symbol.kind).
typedef enum {
SYM_UNKNOWN,
SYM_BUILTIN_FUN, // builtins such as print/println, registered by symbolic_analysis
SYM_BUILTIN_TYPE, // builtin scalar types (u8 ... nil)
SYM_FUN,
SYM_VAR,
SYM_PARAM,
SYM_ENUM,
SYM_ENUM_FIELD,
SYM_STRUCT,
SYM_STRUCT_FIELD,
} SymbolKind;
// Display label for each SymbolKind, indexed by the kind value via
// designated initializers. Every label carries a trailing space.
// NOTE(review): no reader of this table is visible in this file; presumably
// used for debug/graph output — confirm before changing the labels.
Str sym_kind_str[] = {
[SYM_UNKNOWN] = cstr("UNKNOWN "),
[SYM_BUILTIN_FUN] = cstr("BUILTIN FUN "),
[SYM_BUILTIN_TYPE] = cstr("BUILTIN TYPE "),
[SYM_FUN] = cstr("FUNCTION "),
[SYM_VAR] = cstr("VARIABLE "),
[SYM_PARAM] = cstr("PARAMETER "),
[SYM_ENUM] = cstr("ENUM "),
[SYM_ENUM_FIELD] = cstr("ENUM FIELD "),
[SYM_STRUCT] = cstr("STRUCT "),
[SYM_STRUCT_FIELD] = cstr("STRUCT FIELD "),
};
// Symbol-table entry.
// For values (SYM_VAR, SYM_PARAM, SYM_STRUCT_FIELD) `name` holds the value's
// TYPE name, which is what find_type callers read back as the type.
// For types, functions and enum fields it holds the symbol's own name.
typedef struct Symbol {
Str name;
SymbolKind kind;
} Symbol;
// Function signature as registered in a Scope's funcs map.
// param_type / return_type are comma-joined type lists, "nil" when empty,
// or "..." for the variadic builtins (accepts any arguments).
typedef struct Fun {
Str name;
Str param_type;
Str return_type;
} Fun;
// Enum table entry. The enum itself is stored under its name; each field is
// stored under the qualified key "Enum.field" with only `name` set.
typedef struct Enum {
Str name;
Node *val;
} Enum;
// Struct table entry. The struct itself is stored under its name (only
// `name` set); each field is stored under "Struct.field" (nested compound
// fields as "Struct.outer.inner") with its type string and optional
// default-value node in `val`.
typedef struct Struct {
Str name;
Str type;
Node *val;
} Struct;
// Str-keyed hash maps for the tables above. MAPDEF comes from outside this
// file; judging by its uses here it generates <prefix>_insert/_lookup/
// _iterator/_next, with entries exposing `key` and `val` — confirm in badlib.h.
MAPDEF(SymbolMap, symmap, Str, Symbol, str_hash, str_eq)
MAPDEF(FunMap, funmap, Str, Fun, str_hash, str_eq)
MAPDEF(EnumMap, enummap, Str, Enum, str_hash, str_eq)
MAPDEF(StructMap, structmap, Str, Struct, str_hash, str_eq)
// A lexical scope. Lookups walk `parent` until NULL (the root scope).
// Scopes are arena_calloc'd, so every map pointer starts NULL and is passed
// by address to the *_insert helpers.
typedef struct Scope {
sz id; // unique id from Analyzer.typescope_gen, also used in mangled names
sz depth; // 0 for the root scope, parent->depth + 1 otherwise
Str name; // set to the function name for function scopes
SymbolMap *symbols;
FunMap *funcs;
EnumMap *enums;
StructMap *structs;
struct Scope *parent;
} Scope;
// State shared by the whole semantic-analysis pass.
typedef struct Analyzer {
Arena *storage; // arena for scopes, tables and concatenated strings
Str file_name; // used as the prefix of every diagnostic
sz typescope_gen; // next Scope id to hand out
Scope **scopes; // every scope allocated, in allocation order
StrSet *numeric_types; // types accepted by arithmetic operators
StrSet *integer_types; // types accepted by bitwise ops and as indices
bool err; // set once any semantic error has been reported
} Analyzer;
/*
 * Allocates a zeroed child scope of `parent` (or a root scope when `parent`
 * is NULL), assigns it the next scope id, and records it in a->scopes.
 */
Scope *
typescope_alloc(Analyzer *a, Scope *parent) {
    Scope *fresh = arena_calloc(sizeof(*fresh), a->storage);
    fresh->id = a->typescope_gen;
    a->typescope_gen += 1;
    fresh->parent = parent;
    if (parent != NULL) {
        fresh->depth = parent->depth + 1;
    } else {
        fresh->depth = 0;
    }
    array_push(a->scopes, fresh, a->storage);
    return fresh;
}
/*
 * Resolves `type` in the symbol tables of `scope` and its ancestors,
 * innermost first. Returns the matching entry or NULL.
 */
SymbolMap *
find_type(Scope *scope, Str type) {
    for (Scope *s = scope; s != NULL; s = s->parent) {
        SymbolMap *hit = symmap_lookup(&s->symbols, type);
        if (hit) {
            return hit;
        }
    }
    return NULL;
}
/*
 * Resolves function `type` in the function tables of `scope` and its
 * ancestors, innermost first. Returns the matching entry or NULL.
 */
FunMap *
find_fun(Scope *scope, Str type) {
    for (Scope *s = scope; s != NULL; s = s->parent) {
        FunMap *hit = funmap_lookup(&s->funcs, type);
        if (hit) {
            return hit;
        }
    }
    return NULL;
}
// Result of find_enum: the entry and the scope whose table holds it
// (both NULL when nothing matched).
typedef struct FindEnumResult {
    EnumMap *map;
    Scope *scope;
} FindEnumResult;

/*
 * Resolves `type` (an enum name or a qualified "Enum.field" key) in the enum
 * tables of `scope` and its ancestors, innermost first.
 */
FindEnumResult
find_enum(Scope *scope, Str type) {
    for (Scope *s = scope; s != NULL; s = s->parent) {
        EnumMap *hit = enummap_lookup(&s->enums, type);
        if (hit) {
            return (FindEnumResult){.map = hit, .scope = s};
        }
    }
    return (FindEnumResult){0};
}
// Result of find_struct: the entry and the scope whose table holds it
// (both NULL when nothing matched).
typedef struct FindStructResult {
    StructMap *map;
    Scope *scope;
} FindStructResult;

/*
 * Resolves `type` (a struct name or a qualified "Struct.field" key) in the
 * struct tables of `scope` and its ancestors, innermost first.
 */
FindStructResult
find_struct(Scope *scope, Str type) {
    for (Scope *s = scope; s != NULL; s = s->parent) {
        StructMap *hit = structmap_lookup(&s->structs, type);
        if (hit) {
            return (FindStructResult){.map = hit, .scope = s};
        }
    }
    return (FindStructResult){0};
}
// Prints one Graphviz node (named by the scope id) whose label tabulates the
// scope's symbol table as NAME / TYPE rows, followed by an edge to the
// nearest ancestor scope that has a symbol table. Scopes without a symbol
// table print nothing.
// `a` is taken by value and only used for the iterator, presumably so the
// caller's arena is not advanced — confirm against the map iterator API.
// NOTE(review): in this copy several string literals below span physical
// lines, which is not valid C, and the label markup seems to be missing —
// restore from version control before relying on this output.
void
graph_typescope(Scope *scope, Arena a) {
if (!scope->symbols) {
return;
}
SymbolMapIter iter = symmap_iterator(scope->symbols, &a);
SymbolMap *type = symmap_next(&iter, &a);
// Node header and table header row.
print(
"%d[shape=\"none\" label=<
",
scope->id);
print(
""
" NAME | "
" TYPE | "
"
");
// One row per symbol: its key and the stored Symbol.name.
while (type) {
print(
""
" %s | "
" %s | "
"
",
type->key, type->val.name);
type = symmap_next(&iter, &a);
}
println("
>];");
// Link to the closest ancestor that actually has a symbol table.
sz this_id = scope->id;
while (scope->parent) {
if (scope->parent->symbols) {
println("%d:e->%d:w;", this_id, scope->parent->id);
break;
} else {
scope = scope->parent;
}
}
}
// Prints a Graphviz node "fun_<scope id>" whose label tabulates the scope's
// function table (NAME / PARAMS / RETURN), followed by an edge to a
// "fun_<id>" node of an ancestor scope, using this scope's name as the port.
// Scopes without a function table print nothing.
// NOTE(review): in this copy several string literals below span physical
// lines (not valid C) and appear to have lost their markup; the row format
// shows three %s but four arguments are passed — the first was presumably
// consumed by stripped markup (e.g. a port attribute). Confirm before editing.
// NOTE(review): the edge loop tests `parent->symbols`, not `parent->funcs`,
// although the target node is the parent's fun_ table — confirm intent.
void
graph_functions(Scope *scope, Arena a) {
if (!scope->funcs) {
return;
}
FunMapIter iter = funmap_iterator(scope->funcs, &a);
FunMap *func = funmap_next(&iter, &a);
// Node header and table header row.
print(
"fun_%d[shape=\"none\" label=<",
scope->id);
print(
""
" NAME | "
" PARAMS | "
" RETURN | "
"
");
// One row per function signature.
while (func) {
print(
""
" %s | "
" %s | "
" %s | "
"
",
func->val.name, func->val.name, func->val.param_type,
func->val.return_type);
func = funmap_next(&iter, &a);
}
println("
>];");
// Edge to the first ancestor with a symbol table.
sz this_id = scope->id;
while (scope->parent) {
if (scope->parent->symbols) {
println("fun_%d:e->fun_%d:%s:w;", this_id, scope->parent->id,
scope->name);
break;
} else {
scope = scope->parent;
}
}
}
/*
 * Prints the scope tree as a Graphviz "digraph types" document: global
 * graph attributes, then one subgraph per non-NULL entry of `scopes` (in
 * array order) containing that scope's symbol and function tables.
 * Does nothing when `scopes` is NULL.
 */
void
graph_types(Scope **scopes, Arena a) {
    if (!scopes) {
        return;
    }
    println("digraph types {");
    println("rankdir=LR;");
    println("ranksep=\"0.95 equally\";");
    println("nodesep=\"0.5 equally\";");
    println("overlap=scale;");
    println("bgcolor=\"transparent\";");
    sz count = array_size(scopes);
    for (sz idx = 0; idx < count; idx++) {
        Scope *current = scopes[idx];
        if (current) {
            println("subgraph %d {", idx);
            graph_typescope(current, a);
            graph_functions(current, a);
            println("}");
        }
    }
    println("}");
}
/*
 * Prints "<file>:<line>:<col>: error: <msg>" for node `n` and marks the
 * analysis as failed.
 */
void
emit_semantic_error(Analyzer *a, Node *n, Str msg) {
    Str file = a->file_name;
    eprintln("%s:%d:%d: error: %s", file, n->line, n->col, msg);
    a->err = true;
}
Str type_inference(Analyzer *a, Node *node, Scope *scope);

/*
 * Type-checks one field declaration of struct `symbol` (a struct name or an
 * already-qualified "Struct.outer" prefix) and registers it in `scope`.
 *
 * Compound fields recurse into their members and get a type string of the
 * form "\{ t1 t2 ... \}". Scalar fields must name a known type, are
 * "@"-prefixed for pointers and arrays, have any default value checked
 * against that type, and are inserted into the struct and symbol tables
 * under "symbol.field". The resulting type is stored in node->type.
 */
void
typecheck_field(Analyzer *a, Node *node, Scope *scope, Str symbol) {
    Str qualified = str_concat(symbol, cstr("."), a->storage);
    qualified = str_concat(qualified, node->value.str, a->storage);
    Node *decl = node->field_type;

    if (decl->kind == NODE_COMPOUND_TYPE) {
        if (structmap_lookup(&scope->structs, qualified)) {
            eprintln("%s:%d:%d: error: struct field '%s' already exists",
                     a->file_name, node->line, node->col, qualified);
            a->err = true;
        }
        Str shape = cstr("\\{ ");
        for (sz k = 0; k < array_size(decl->elements); k++) {
            Node *member = decl->elements[k];
            typecheck_field(a, member, scope, qualified);
            shape = str_concat(shape, member->type, a->storage);
            shape = str_concat(shape, cstr(" "), a->storage);
        }
        node->type = str_concat(shape, cstr("\\}"), a->storage);
        return;
    }

    Str base = decl->value.str;
    if (!find_type(scope, base)) {
        eprintln("%s:%d:%d: error: unknown type '%s'", a->file_name,
                 decl->line, decl->col, base);
        a->err = true;
    }
    // Pointers and arrays are both spelled with a leading '@'.
    Str declared = base;
    if (decl->is_ptr) {
        declared = str_concat(cstr("@"), declared, a->storage);
    }
    if (decl->kind == NODE_ARR_TYPE) {
        declared = str_concat(cstr("@"), declared, a->storage);
    }
    if (structmap_lookup(&scope->structs, qualified)) {
        eprintln("%s:%d:%d: error: struct field '%s' already exists",
                 a->file_name, node->line, node->col, qualified);
        a->err = true;
    }
    if (node->field_val) {
        Str actual = type_inference(a, node->field_val, scope);
        if (!str_eq(actual, declared)) {
            eprintln("%s:%d:%d: error: mismatched types in struct "
                     "value "
                     "for '%s': %s expected %s",
                     a->file_name, node->line, node->col, qualified, actual,
                     declared);
            a->err = true;
        }
    }
    Struct entry = {
        .name = qualified,
        .type = declared,
        .val = node->field_val,
    };
    structmap_insert(&scope->structs, qualified, entry, a->storage);
    symmap_insert(&scope->symbols, qualified,
                  (Symbol){.name = declared, .kind = SYM_STRUCT_FIELD},
                  a->storage);
    node->type = declared;
}
/*
 * Type-checks one field of a struct literal. `symbol` is the field's
 * qualified name ("Struct.field", or deeper for nested compound values).
 *
 * Scalar values must match the type recorded for `symbol` in
 * scope->structs; compound values recurse into their members and produce a
 * "\{ t1 t2 ... \}" type string. The result is stored in node->type (left
 * untouched when the field is unknown).
 */
void
typecheck_lit_field(Analyzer *a, Node *node, Scope *scope, Str symbol) {
    Node *val = node->field_val;
    if (val->kind != NODE_COMPOUND_TYPE) {
        StructMap *entry = structmap_lookup(&scope->structs, symbol);
        if (entry == NULL) {
            eprintln("%s:%d:%d: error: unknown struct field '%s'",
                     a->file_name, node->line, node->col, symbol);
            a->err = true;
            return;
        }
        Str expected = entry->val.type;
        Str actual = type_inference(a, val, scope);
        if (!str_eq(actual, expected)) {
            eprintln("%s:%d:%d: error: mismatched types in struct "
                     "value "
                     "for '%s': %s expected %s",
                     a->file_name, node->line, node->col, symbol, actual,
                     expected);
            a->err = true;
        }
        node->type = expected;
        return;
    }

    Str shape = cstr("\\{ ");
    for (sz k = 0; k < array_size(val->elements); k++) {
        Node *member = val->elements[k];
        Str member_name = str_concat(symbol, cstr("."), a->storage);
        member_name = str_concat(member_name, member->value.str, a->storage);
        typecheck_lit_field(a, member, scope, member_name);
        shape = str_concat(shape, member->type, a->storage);
        shape = str_concat(shape, cstr(" "), a->storage);
    }
    node->type = str_concat(shape, cstr("\\}"), a->storage);
}
/*
 * Walks the already type-inferred tree rooted at `node` and reports every
 * `return` whose type (node->type, set by type_inference) differs from
 * `expected`.
 *
 * Control-flow constructs are descended into: blocks, if/else, cond/match
 * arms, while bodies, let/set initializers and operator operands. Nested
 * function definitions fall into the default case and are not entered, so
 * their returns are only checked against their own prototype.
 */
void
typecheck_returns(Analyzer *a, Node *node, Str expected) {
    if (!node) {
        return;
    }
    switch (node->kind) {
    case NODE_COND:
    case NODE_MATCH: {
        for (sz i = 0; i < array_size(node->match_cases); i++) {
            typecheck_returns(a, node->match_cases[i], expected);
        }
    } break;
    case NODE_CASE_COND:
    case NODE_CASE_MATCH: {
        // match_cases holds these case nodes; without this arm the walk
        // stopped at each case and returns inside cond/match arms were
        // never checked.
        typecheck_returns(a, node->case_expr, expected);
    } break;
    case NODE_WHILE: {
        // Returns inside loop bodies were previously never visited.
        typecheck_returns(a, node->while_expr, expected);
    } break;
    case NODE_RETURN: {
        if (!str_eq(node->type, expected)) {
            eprintln(
                "%s:%d:%d: error: mismatched return type %s, expected %s",
                a->file_name, node->line, node->col, node->type, expected);
            a->err = true;
        }
    } break;
    case NODE_BLOCK: {
        for (sz i = 0; i < array_size(node->elements); i++) {
            typecheck_returns(a, node->elements[i], expected);
        }
    } break;
    case NODE_IF: {
        typecheck_returns(a, node->cond_expr, expected);
        typecheck_returns(a, node->cond_else, expected);
    } break;
    case NODE_SET:
    case NODE_LET: {
        typecheck_returns(a, node->var_val, expected);
    } break;
    case NODE_ADD:
    case NODE_SUB:
    case NODE_DIV:
    case NODE_MUL:
    case NODE_MOD:
    case NODE_NOT:
    case NODE_AND:
    case NODE_OR:
    case NODE_EQ:
    case NODE_NEQ:
    case NODE_LT:
    case NODE_GT:
    case NODE_LE:
    case NODE_GE:
    case NODE_BITNOT:
    case NODE_BITAND:
    case NODE_BITOR:
    case NODE_BITLSHIFT:
    case NODE_BITRSHIFT: {
        typecheck_returns(a, node->left, expected);
        typecheck_returns(a, node->right, expected);
    } break;
    default: break;
    }
}
/*
 * Infers and checks the type of `node` within `scope`, stores it in
 * node->type, and returns it.
 *
 * Types are plain strings:
 *   - a builtin or user type name;
 *   - "nil" for statements (let, set, struct, enum, while, fun);
 *   - an "@"-prefixed name for pointer and array types;
 *   - a comma-separated list for argument and return tuples.
 *
 * Declarations register their names in `scope` (symbols, funcs, enums,
 * structs). Blocks, and non-block branch or loop bodies, get a fresh child
 * scope from typescope_alloc.
 *
 * On a semantic error a diagnostic is printed, a->err is set, and an empty
 * Str is returned. An empty Str is also returned for a NULL node.
 */
Str
type_inference(Analyzer *a, Node *node, Scope *scope) {
    assert(a);
    assert(scope);
    if (!node) {
        return cstr("");
    }
    // NOTE: For now we are not going to do implicit numeric conversions.
    switch (node->kind) {
    case NODE_LET: {
        node->type = cstr("nil");
        Str symbol = node->var_name->value.str;
        if (symmap_lookup(&scope->symbols, symbol)) {
            eprintln(
                "%s:%d:%d: error: symbol '%s' already exists in current "
                "scope ",
                a->file_name, node->var_name->line, node->var_name->col,
                symbol);
            a->err = true;
            return cstr("");
        }
        if (node->var_type) {
            // Explicitly typed binding: resolve the declared type first.
            Str type_name = node->var_type->value.str;
            SymbolMap *type = find_type(scope, type_name);
            if (type == NULL) {
                eprintln("%s:%d:%d: error: unknown type '%s'", a->file_name,
                         node->var_type->line, node->var_type->col,
                         type_name);
                a->err = true;
                return cstr("");
            }
            if (node->var_type->is_ptr) {
                type_name = str_concat(cstr("@"), type_name, a->storage);
            }
            if (node->var_type->kind == NODE_ARR_TYPE) {
                type_name = str_concat(cstr("@"), type_name, a->storage);
                // TODO: typecheck size
                // TODO: register array in scope
            }
            if (node->var_val) {
                Str val_type = type_inference(a, node->var_val, scope);
                if (!val_type.size) {
                    eprintln(
                        "%s:%d:%d: error: can't bind `nil` to variable "
                        "'%s'",
                        a->file_name, node->var_type->line,
                        node->var_type->col, symbol);
                    a->err = true;
                    return cstr("");
                }
                // TODO: Consider compatible types.
                if (!str_eq(val_type, type_name)) {
                    // Special case, enums can be treated as ints.
                    FindEnumResult res = find_enum(scope, type_name);
                    if (!(res.map && str_eq(val_type, cstr("int")))) {
                        eprintln(
                            "%s:%d:%d: error: type mismatch, trying to "
                            "assing "
                            "%s"
                            " to a variable of type %s",
                            a->file_name, node->var_type->line,
                            node->var_type->col, val_type, type_name);
                        a->err = true;
                        return cstr("");
                    }
                }
            }
            symmap_insert(&scope->symbols, symbol,
                          (Symbol){
                              .name = type_name,
                              .kind = SYM_VAR,
                          },
                          a->storage);
            node->var_name->type = type_name;
        } else {
            // We don't know the type for this symbol, perform inference.
            Str type = type_inference(a, node->var_val, scope);
            if (type.size) {
                symmap_insert(&scope->symbols, symbol,
                              (Symbol){.name = type, .kind = SYM_VAR},
                              a->storage);
                node->var_name->type = type;
            }
        }
        // Scope-qualified name ".<symbol><scope id>", presumably so that
        // shadowed bindings stay distinct downstream. It used to be set only
        // for untyped lets; typed lets now get one too.
        Str unique = str_concat(cstr("."), symbol, a->storage);
        unique = str_concat(unique, str_from_int(scope->id, a->storage),
                            a->storage);
        node->unique_name = unique;
        return node->type;
    } break;
    case NODE_SET: {
        Str name = type_inference(a, node->var_name, scope);
        Str val = type_inference(a, node->var_val, scope);
        if (!str_eq(name, val)) {
            eprintln(
                "%s:%d:%d: error: type mismatch, trying to assing "
                "%s"
                " to a variable of type %s",
                a->file_name, node->line, node->col, val, name);
            a->err = true;
            return cstr("");
        }
        node->type = cstr("nil");
        return node->type;
    } break;
    case NODE_STRUCT: {
        node->type = cstr("nil");
        Str symbol = node->value.str;
        if (symmap_lookup(&scope->symbols, symbol) != NULL) {
            eprintln(
                "%s:%d:%d: error: struct '%s' already exists in current "
                "scope",
                a->file_name, node->line, node->col, symbol);
            a->err = true;
            return cstr("");
        }
        structmap_insert(&scope->structs, symbol, (Struct){.name = symbol},
                         a->storage);
        for (sz i = 0; i < array_size(node->struct_field); i++) {
            Node *field = node->struct_field[i];
            typecheck_field(a, field, scope, symbol);
        }
        symmap_insert(&scope->symbols, symbol,
                      (Symbol){.name = symbol, .kind = SYM_STRUCT},
                      a->storage);
        return node->type;
    } break;
    case NODE_ENUM: {
        node->type = cstr("nil");
        Str symbol = node->value.str;
        if (symmap_lookup(&scope->symbols, symbol) != NULL) {
            eprintln(
                "%s:%d:%d: error: enum '%s' already exists in current "
                "scope",
                a->file_name, node->line, node->col, symbol);
            a->err = true;
            return cstr("");
        }
        enummap_insert(&scope->enums, symbol,
                       (Enum){
                           .name = symbol,
                           .val = node->field_val,
                       },
                       a->storage);
        for (sz i = 0; i < array_size(node->struct_field); i++) {
            Node *field = node->struct_field[i];
            Str field_name = str_concat(symbol, cstr("."), a->storage);
            field_name =
                str_concat(field_name, field->value.str, a->storage);
            if (enummap_lookup(&scope->enums, field_name)) {
                eprintln("%s:%d:%d: error: enum field '%s' already exists",
                         a->file_name, field->line, field->col, field_name);
                a->err = true;
            }
            if (field->field_val) {
                Str type = type_inference(a, field->field_val, scope);
                if (!str_eq(type, cstr("int"))) {
                    // field_name is already "Enum.field"; printing the enum
                    // name again produced "Enum.Enum.field".
                    eprintln("%s:%d:%d: error: non int enum value for '%s'",
                             a->file_name, field->line, field->col,
                             field_name);
                    a->err = true;
                }
            }
            enummap_insert(&scope->enums, field_name,
                           (Enum){.name = field_name}, a->storage);
            symmap_insert(
                &scope->symbols, field_name,
                (Symbol){.name = field_name, .kind = SYM_ENUM_FIELD},
                a->storage);
            field->type = symbol;
        }
        symmap_insert(&scope->symbols, symbol,
                      (Symbol){.name = symbol, .kind = SYM_ENUM},
                      a->storage);
        return node->type;
    } break;
    case NODE_IF: {
        Str cond_type = type_inference(a, node->cond_if, scope);
        if (!str_eq(cond_type, cstr("bool"))) {
            emit_semantic_error(
                a, node->cond_if,
                cstr("non boolean expression on if condition"));
            return cstr("");
        }
        // Blocks allocate their own scope; bare expressions get one here.
        if (node->cond_expr->kind == NODE_BLOCK) {
            node->type = type_inference(a, node->cond_expr, scope);
        } else {
            Scope *next = typescope_alloc(a, scope);
            node->type = type_inference(a, node->cond_expr, next);
        }
        if (node->cond_else) {
            Str else_type;
            if (node->cond_else->kind == NODE_BLOCK) {
                else_type = type_inference(a, node->cond_else, scope);
            } else {
                Scope *next = typescope_alloc(a, scope);
                else_type = type_inference(a, node->cond_else, next);
            }
            if (!str_eq(node->type, else_type)) {
                emit_semantic_error(
                    a, node, cstr("mismatch types for if/else branches"));
                return cstr("");
            }
        }
        return node->type;
    } break;
    case NODE_WHILE: {
        Str cond_type = type_inference(a, node->while_cond, scope);
        if (!str_eq(cond_type, cstr("bool"))) {
            // Report at the loop condition (was node->cond_if, a field of
            // if-nodes).
            emit_semantic_error(
                a, node->while_cond,
                cstr("non boolean expression on while condition"));
            return cstr("");
        }
        if (node->while_expr->kind != NODE_BLOCK) {
            scope = typescope_alloc(a, scope);
        }
        type_inference(a, node->while_expr, scope);
        node->type = cstr("nil");
        return node->type;
    } break;
    case NODE_COND: {
        // All arms must agree; the cond takes the arms' common type.
        Str previous = cstr("");
        for (sz i = 0; i < array_size(node->match_cases); i++) {
            Node *expr = node->match_cases[i];
            Str next = type_inference(a, expr, scope);
            if (i != 0 && !str_eq(next, previous)) {
                emit_semantic_error(
                    a, node,
                    cstr("non-matching types for cond expressions"));
                return cstr("");
            }
            previous = next;
        }
        node->type = previous;
        return node->type;
    } break;
    case NODE_MATCH: {
        Str e = type_inference(a, node->match_expr, scope);
        if (str_eq(e, cstr("int"))) {
            // Integer matching.
            for (sz i = 0; i < array_size(node->match_cases); i++) {
                Node *field = node->match_cases[i];
                if (field->case_value) {
                    if (field->case_value->kind != NODE_NUM_INT &&
                        field->case_value->kind != NODE_NUM_UINT) {
                        emit_semantic_error(
                            a, field->case_value,
                            cstr(
                                "non-integer or enum types on match case"));
                    }
                }
            }
        } else {
            // Get enum type and de-structure the match.
            FindEnumResult res = find_enum(scope, e);
            if (!res.map) {
                // Previously res.map was dereferenced unconditionally,
                // crashing on any non-int, non-enum scrutinee. An empty `e`
                // means the scrutinee itself already failed to type-check.
                if (e.size) {
                    emit_semantic_error(
                        a, node->match_expr,
                        cstr("match expression must be an int or an enum"));
                }
                a->err = true;
                return cstr("");
            }
            Str enum_prefix =
                str_concat(res.map->val.name, cstr("."), a->storage);
            for (sz i = 0; i < array_size(node->match_cases); i++) {
                Node *field = node->match_cases[i];
                if (field->case_value) {
                    Str field_name = str_concat(
                        enum_prefix, field->case_value->value.str,
                        a->storage);
                    if (!enummap_lookup(&res.scope->enums, field_name)) {
                        eprintln("%s:%d:%d: error: unknown enum field '%s'",
                                 a->file_name, field->case_value->line,
                                 field->case_value->col, field_name);
                        a->err = true;
                    }
                }
            }
        }
        Str previous = cstr("");
        for (sz i = 0; i < array_size(node->match_cases); i++) {
            Node *expr = node->match_cases[i];
            Str next = type_inference(a, expr, scope);
            if (i != 0 && !str_eq(next, previous)) {
                emit_semantic_error(
                    a, node,
                    cstr("non-matching types for match expressions"));
                return cstr("");
            }
            previous = next;
        }
        node->type = previous;
        return node->type;
    } break;
    case NODE_CASE_MATCH: {
        if (node->case_expr->kind != NODE_BLOCK) {
            scope = typescope_alloc(a, scope);
        }
        node->type = type_inference(a, node->case_expr, scope);
        return node->type;
    } break;
    case NODE_CASE_COND: {
        if (node->case_value) {
            Str cond = type_inference(a, node->case_value, scope);
            if (!str_eq(cond, cstr("bool"))) {
                emit_semantic_error(a, node,
                                    cstr("non-boolean case condition"));
            }
        }
        if (node->case_expr->kind != NODE_BLOCK) {
            scope = typescope_alloc(a, scope);
        }
        node->type = type_inference(a, node->case_expr, scope);
        return node->type;
    } break;
    case NODE_TRUE:
    case NODE_FALSE: {
        node->type = cstr("bool");
        return node->type;
    } break;
    case NODE_NIL: {
        node->type = cstr("nil");
        return node->type;
    } break;
    case NODE_NOT:
    case NODE_AND:
    case NODE_OR: {
        Str left = type_inference(a, node->left, scope);
        if (!str_eq(left, cstr("bool"))) {
            emit_semantic_error(a, node,
                                cstr("expected bool on logic expression"));
            return cstr("");
        }
        if (node->right) {
            Str right = type_inference(a, node->right, scope);
            if (!str_eq(right, cstr("bool"))) {
                emit_semantic_error(
                    a, node, cstr("expected bool on logic expression"));
                return cstr("");
            }
        }
        node->type = cstr("bool");
        return node->type;
    } break;
    case NODE_EQ:
    case NODE_NEQ:
    case NODE_LT:
    case NODE_GT:
    case NODE_LE:
    case NODE_GE: {
        Str left = type_inference(a, node->left, scope);
        Str right = type_inference(a, node->right, scope);
        if (!str_eq(left, right)) {
            emit_semantic_error(
                a, node, cstr("mismatched types on binary expression"));
            return cstr("");
        }
        node->type = cstr("bool");
        return node->type;
    } break;
    case NODE_BITNOT: {
        Str left = type_inference(a, node->left, scope);
        if (!strset_lookup(&a->integer_types, left)) {
            emit_semantic_error(
                a, node, cstr("non integer type on bit twiddling expr"));
            return cstr("");
        }
        node->type = left;
        return node->type;
    } break;
    case NODE_BITAND:
    case NODE_BITOR:
    case NODE_BITLSHIFT:
    case NODE_BITRSHIFT: {
        Str left = type_inference(a, node->left, scope);
        Str right = type_inference(a, node->right, scope);
        if (!strset_lookup(&a->integer_types, left) ||
            !strset_lookup(&a->integer_types, right)) {
            emit_semantic_error(
                a, node, cstr("non integer type on bit twiddling expr"));
            return cstr("");
        }
        node->type = left;
        return node->type;
    } break;
    case NODE_ADD:
    case NODE_SUB:
    case NODE_DIV:
    case NODE_MUL:
    case NODE_MOD: {
        Str left = type_inference(a, node->left, scope);
        Str right = type_inference(a, node->right, scope);
        if (!strset_lookup(&a->numeric_types, left) ||
            !strset_lookup(&a->numeric_types, right)) {
            emit_semantic_error(
                a, node, cstr("non numeric type on arithmetic expr"));
            return cstr("");
        }
        if (!str_eq(left, right)) {
            emit_semantic_error(
                a, node, cstr("mismatched types on binary expression"));
            return cstr("");
        }
        node->type = left;
        return node->type;
    } break;
    case NODE_NUM_UINT: {
        node->type = cstr("uint");
        return node->type;
    } break;
    case NODE_NUM_INT: {
        node->type = cstr("int");
        return node->type;
    } break;
    case NODE_NUM_FLOAT: {
        node->type = cstr("f64");
        return node->type;
    } break;
    case NODE_STRING: {
        node->type = cstr("str");
        return node->type;
    } break;
    case NODE_ARR_TYPE:
    case NODE_TYPE: {
        SymbolMap *type = find_type(scope, node->value.str);
        if (!type) {
            emit_semantic_error(a, node, cstr("unknown type"));
            return cstr("");
        }
        node->type = type->val.name;
        return node->type;
    } break;
    case NODE_SYMBOL_IDX:
    case NODE_SYMBOL: {
        Str symbol = node->value.str;
        SymbolMap *type = find_type(scope, symbol);
        if (!type) {
            eprintln("%s:%d:%d: error: couldn't resolve symbol '%s'",
                     a->file_name, node->line, node->col, symbol);
            a->err = true;
            return cstr("");
        }
        Str type_name = type->val.name;
        if (node->kind == NODE_SYMBOL_IDX) {
            // Indexing strips one level of '@' from the array type.
            Str idx_type = type_inference(a, node->arr_size, scope);
            if (!strset_lookup(&a->integer_types, idx_type)) {
                emit_semantic_error(
                    a, node, cstr("can't resolve non integer index"));
                return cstr("");
            }
            type_name = str_remove_prefix(type_name, cstr("@"));
        }
        if (node->is_ptr) {
            type_name = str_concat(cstr("@"), type_name, a->storage);
        }
        // `Enum.field` access: the symbol names the enum type itself.
        FindEnumResult e = find_enum(scope, type_name);
        if (e.map && str_eq(symbol, type_name)) {
            if (!node->next) {
                eprintln(
                    "%s:%d:%d: error: unspecified enum field for symbol "
                    "'%s'",
                    a->file_name, node->line, node->col, symbol);
                a->err = true;
                return cstr("");
            }
            // Check if there is a next and it matches the enum field.
            Str field = str_concat(type_name, cstr("."), a->storage);
            field = str_concat(field, node->next->value.str, a->storage);
            if (!enummap_lookup(&e.scope->enums, field)) {
                eprintln(
                    "%s:%d:%d: error: unknown enum field for "
                    "'%s': %s",
                    a->file_name, node->line, node->col, symbol,
                    node->next->value.str);
                a->err = true;
                return cstr("");
            }
            node->next->type = type_name;
            node->type = type_name;
            return node->next->type;
        }
        // Struct-typed symbol: resolve a `.a.b` field chain if present.
        FindStructResult s = find_struct(scope, type_name);
        if (s.map) {
            if (str_eq(symbol, type_name)) {
                eprintln(
                    "%s:%d:%d: error: struct incomplete struct literal "
                    "'%s', did you mean to use %s:{}?",
                    a->file_name, node->line, node->col, symbol, symbol);
                a->err = true;
                return cstr("");
            } else {
                if (node->next) {
                    Str chain = type_name;
                    Node *next = node;
                    while (next->next) {
                        next = next->next;
                        chain = str_concat(chain, cstr("."), a->storage);
                        chain =
                            str_concat(chain, next->value.str, a->storage);
                    }
                    StructMap *field =
                        structmap_lookup(&s.scope->structs, chain);
                    if (!field) {
                        eprintln(
                            "%s:%d:%d: error: unknown struct field '%s'",
                            a->file_name, node->line, node->col, chain);
                        a->err = true;
                        return cstr("");
                    }
                    Str field_type = field->val.type;
                    if (next->kind == NODE_SYMBOL_IDX) {
                        Str idx_type =
                            type_inference(a, next->arr_size, scope);
                        if (!strset_lookup(&a->integer_types, idx_type)) {
                            emit_semantic_error(
                                a, next,
                                cstr("can't resolve non integer index"));
                            return cstr("");
                        }
                        field_type =
                            str_remove_prefix(field_type, cstr("@"));
                    }
                    node->type = field_type;
                    return node->type;
                }
            }
        }
        node->type = type_name;
        return node->type;
    } break;
    case NODE_STRUCT_LIT: {
        Str name = node->value.str;
        FindStructResult s = find_struct(scope, name);
        if (!s.map) {
            eprintln("%s:%d:%d: error: unknown struct type '%s'",
                     a->file_name, node->line, node->col, name);
            a->err = true;
            return cstr("");
        }
        // Track fields already given to reject duplicates in the literal.
        StrSet *set = NULL;
        for (sz i = 0; i < array_size(node->elements); i++) {
            Node *next = node->elements[i];
            Str field_name = str_concat(name, cstr("."), a->storage);
            field_name =
                str_concat(field_name, next->value.str, a->storage);
            if (strset_lookup(&set, field_name)) {
                eprintln(
                    "%s:%d:%d: error: field '%s' already present in struct "
                    "literal",
                    a->file_name, next->line, next->col, field_name);
                a->err = true;
            } else {
                strset_insert(&set, field_name, a->storage);
            }
            typecheck_lit_field(a, next, s.scope, field_name);
        }
        node->type = name;
        return node->type;
    } break;
    case NODE_FUNCALL: {
        Str symbol = node->value.str;
        FunMap *fun = find_fun(scope, symbol);
        if (!fun) {
            eprintln(
                "%s:%d:%d: error: function '%s' doesn't exist in current "
                "scope ",
                a->file_name, node->line, node->col, symbol);
            a->err = true;
            return cstr("");
        }
        // Check that actual parameters typecheck
        Str args = cstr("");
        for (sz i = 0; i < array_size(node->elements); i++) {
            Node *expr = node->elements[i];
            Str type = type_inference(a, expr, scope);
            args = str_concat(args, type, a->storage);
            if (i != array_size(node->elements) - 1) {
                args = str_concat(args, cstr(","), a->storage);
            }
        }
        if (!args.size) {
            args = cstr("nil");
        }
        // "..." marks a variadic builtin that accepts any arguments.
        Str expected = fun->val.param_type;
        if (!str_eq(args, expected) && !str_eq(expected, cstr("..."))) {
            eprintln(
                "%s:%d:%d: error: mismatched parameter types: %s expected "
                "%s",
                a->file_name, node->line, node->col, args, expected);
            a->err = true;
        }
        node->type = fun->val.return_type;
        return node->type;
    } break;
    case NODE_BLOCK: {
        scope = typescope_alloc(a, scope);
        // A block has the type of its last expression; an empty block is
        // nil. Previously `type` was read uninitialized when empty.
        Str type = cstr("nil");
        for (sz i = 0; i < array_size(node->elements); i++) {
            Node *expr = node->elements[i];
            type = type_inference(a, expr, scope);
        }
        node->type = type;
        return node->type;
    } break;
    case NODE_RETURN: {
        Str ret_type = cstr("");
        for (sz i = 0; i < array_size(node->elements); i++) {
            Node *expr = node->elements[i];
            Str type = type_inference(a, expr, scope);
            ret_type = str_concat(ret_type, type, a->storage);
            if (i != array_size(node->elements) - 1) {
                ret_type = str_concat(ret_type, cstr(","), a->storage);
            }
        }
        if (!ret_type.size) {
            ret_type = cstr("nil");
        }
        node->type = ret_type;
        return node->type;
    } break;
    case NODE_FUN: {
        node->type = cstr("nil");
        Scope *prev_scope = scope;
        scope = typescope_alloc(a, scope);
        // Parameters: register each in the function scope and build the
        // comma-separated parameter signature ("nil" when there are none).
        Str param_type = cstr("");
        for (sz i = 0; i < array_size(node->func_params); i++) {
            Node *param = node->func_params[i];
            Str symbol = param->param_name->value.str;
            Str type = param->param_type->value.str;
            if (param->param_type->is_ptr) {
                type = str_concat(cstr("@"), type, a->storage);
            }
            if (param->param_type->kind == NODE_ARR_TYPE) {
                type = str_concat(cstr("@"), type, a->storage);
            }
            param->param_name->type =
                type_inference(a, param->param_type, scope);
            param->type = type;
            symmap_insert(&scope->symbols, symbol,
                          (Symbol){.name = type, .kind = SYM_PARAM},
                          a->storage);
            param_type = str_concat(param_type, type, a->storage);
            if (i != array_size(node->func_params) - 1) {
                param_type = str_concat(param_type, cstr(","), a->storage);
            }
        }
        if (!param_type.size) {
            param_type = cstr("nil");
        }
        node->fun_params = param_type;
        // Return types, joined the same way.
        Str ret_type = cstr("");
        for (sz i = 0; i < array_size(node->func_ret); i++) {
            Node *expr = node->func_ret[i];
            Str type = type_inference(a, expr, scope);
            if (expr->is_ptr) {
                type = str_concat(cstr("@"), type, a->storage);
            }
            if (expr->kind == NODE_ARR_TYPE) {
                type = str_concat(cstr("@"), type, a->storage);
            }
            ret_type = str_concat(ret_type, type, a->storage);
            if (i != array_size(node->func_ret) - 1) {
                ret_type = str_concat(ret_type, cstr(","), a->storage);
            }
        }
        if (!ret_type.size) {
            ret_type = cstr("nil");
        }
        node->fun_return = ret_type;
        Str symbol = node->func_name->value.str;
        if (prev_scope->parent != NULL) {
            // Nested function (top-level names are pre-registered by
            // symbolic_analysis). The name goes into the enclosing scope,
            // the same table the duplicate check reads; it used to be put in
            // the function's own scope, so duplicates were never detected.
            if (symmap_lookup(&prev_scope->symbols, symbol)) {
                eprintln(
                    "%s:%d:%d: error: function '%s' already defined in "
                    "current "
                    "scope ",
                    a->file_name, node->func_name->line,
                    node->func_name->col, symbol);
                a->err = true;
                return cstr("");
            }
            symmap_insert(&prev_scope->symbols, symbol,
                          (Symbol){.name = symbol, .kind = SYM_FUN},
                          a->storage);
        }
        scope->name = symbol;
        // Registered before the body is checked, so recursive calls resolve.
        funmap_insert(&prev_scope->funcs, symbol,
                      (Fun){.name = symbol,
                            .param_type = param_type,
                            .return_type = ret_type},
                      a->storage);
        if (node->func_body->kind == NODE_BLOCK) {
            // The body block shares the function scope (where the parameters
            // live). `type` is initialized so an empty body is well-defined
            // and reads as nil.
            Str type = cstr("");
            for (sz i = 0; i < array_size(node->func_body->elements); i++) {
                Node *expr = node->func_body->elements[i];
                type = type_inference(a, expr, scope);
            }
            if (!type.size) {
                type = cstr("nil");
            }
            node->func_body->type = type;
        } else {
            type_inference(a, node->func_body, scope);
        }
        // Ensure main body return matches the prototype.
        if (!str_eq(node->func_body->type, ret_type)) {
            eprintln(
                "%s:%d:%d: error: mismatched return type %s, expected %s",
                a->file_name, node->line, node->col, node->func_body->type,
                ret_type);
            a->err = true;
        }
        // Ensure ALL return statements match the function prototype.
        typecheck_returns(a, node->func_body, ret_type);
        // TODO: should return statements be allowed on let blocks?
        return node->type;
    } break;
    default: {
        emit_semantic_error(a, node,
                            cstr("type inference not implemented for this "
                                 "kind of expression"));
        println("KIND: %s", node_str[node->kind]);
    } break;
    }
    return cstr("");
}
/*
 * Entry point of the semantic pass over all top-level nodes of `parser`.
 *
 * 1. Creates the root scope and seeds it with the builtin functions
 *    (print, println: variadic, returning nil) and builtin types.
 * 2. Fills a->numeric_types and a->integer_types.
 * 3. Pre-registers every top-level function name in the root symbol table,
 *    reporting duplicates.
 * 4. Runs type_inference on each top-level node in order.
 *
 * Errors are printed and recorded in a->err.
 * NOTE(review): step 3 fills only the symbol table. FunMap entries, which
 * find_fun uses for calls, are created as each function is type-checked, so
 * calling a top-level function defined later in the file appears to be
 * rejected — confirm whether that is intended.
 */
void
symbolic_analysis(Analyzer *a, Parser *parser) {
    // Check preconditions before `a` is first dereferenced by
    // typescope_alloc.
    assert(a);
    assert(parser);
    Scope *scope = typescope_alloc(a, NULL);
    // Fill builtin tables.
    Str builtin_functions[] = {
        cstr("print"),
        cstr("println"),
    };
    for (sz i = 0; i < LEN(builtin_functions); i++) {
        Str symbol = builtin_functions[i];
        symmap_insert(&scope->symbols, symbol,
                      (Symbol){.name = symbol, .kind = SYM_BUILTIN_FUN},
                      a->storage);
        funmap_insert(&scope->funcs, symbol,
                      (Fun){.name = symbol,
                            .param_type = cstr("..."),
                            .return_type = cstr("nil")},
                      a->storage);
    }
    Str builtin_types[] = {
        cstr("u8"),  cstr("s8"),  cstr("u16"), cstr("s16"),
        cstr("u32"), cstr("s32"), cstr("u64"), cstr("s64"),
        cstr("f32"), cstr("f64"), cstr("ptr"), cstr("int"),
        cstr("uint"), cstr("str"), cstr("bool"), cstr("nil")};
    for (sz i = 0; i < LEN(builtin_types); i++) {
        Str type = builtin_types[i];
        symmap_insert(&scope->symbols, type,
                      (Symbol){.name = type, .kind = SYM_BUILTIN_TYPE},
                      a->storage);
    }
    // Types accepted by arithmetic operators.
    Str numeric_types[] = {
        cstr("u8"),  cstr("s8"),  cstr("u16"), cstr("s16"), cstr("u32"),
        cstr("s32"), cstr("u64"), cstr("s64"), cstr("f32"), cstr("f64"),
        cstr("ptr"), cstr("int"), cstr("uint"),
    };
    for (sz i = 0; i < LEN(numeric_types); i++) {
        Str type = numeric_types[i];
        strset_insert(&a->numeric_types, type, a->storage);
    }
    // Types accepted by bitwise operators and as array indices.
    Str integer_types[] = {
        cstr("u8"),  cstr("s8"),  cstr("u16"), cstr("s16"),
        cstr("u32"), cstr("s32"), cstr("u64"), cstr("s64"),
        cstr("ptr"), cstr("int"), cstr("uint"),
    };
    for (sz i = 0; i < LEN(integer_types); i++) {
        Str type = integer_types[i];
        strset_insert(&a->integer_types, type, a->storage);
    }
    // Find top level function declarations.
    for (sz i = 0; i < array_size(parser->nodes); i++) {
        Node *root = parser->nodes[i];
        if (root->kind == NODE_FUN) {
            Str symbol = root->func_name->value.str;
            if (symmap_lookup(&scope->symbols, symbol)) {
                // Report at the function name (was root->var_name, a field
                // of let-nodes).
                eprintln(
                    "%s:%d:%d: error: function '%s' already defined in "
                    "current "
                    "scope ",
                    a->file_name, root->func_name->line,
                    root->func_name->col, symbol);
                a->err = true;
            }
            symmap_insert(&scope->symbols, symbol,
                          (Symbol){.name = symbol, .kind = SYM_FUN},
                          a->storage);
        }
    }
    // Recursively fill symbol tables.
    for (sz i = 0; i < array_size(parser->nodes); i++) {
        Node *root = parser->nodes[i];
        type_inference(a, root, scope);
    }
}
#endif // SEMANTIC_C