aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorBad Diode <bd@badd10de.dev>2022-04-18 16:27:21 -0300
committerBad Diode <bd@badd10de.dev>2022-04-18 16:27:21 -0300
commit3da041f2e17fdeb69bf345aadf89c5fcc1814260 (patch)
treec1979ffee13f45f757712a61304a3edba89a80f5 /src
parentdcd3192e50d7b4ea333ecf57a7e8b325af145547 (diff)
downloadbdl-3da041f2e17fdeb69bf345aadf89c5fcc1814260.tar.gz
bdl-3da041f2e17fdeb69bf345aadf89c5fcc1814260.zip
Move semantic analysis to separate file
Diffstat (limited to 'src')
-rw-r--r--src/main.c9
-rw-r--r--src/nodes.c1
-rw-r--r--src/nodes.h1
-rw-r--r--src/parser.c525
-rw-r--r--src/parser.h22
-rw-r--r--src/semantic.c527
-rw-r--r--src/viz.c8
7 files changed, 545 insertions, 548 deletions
diff --git a/src/main.c b/src/main.c
index 863d9d1..cdf4167 100644
--- a/src/main.c
+++ b/src/main.c
@@ -9,6 +9,7 @@
9#include "lexer.c" 9#include "lexer.c"
10#include "nodes.c" 10#include "nodes.c"
11#include "parser.c" 11#include "parser.c"
12#include "semantic.c"
12#include "viz.c" 13#include "viz.c"
13 14
14void 15void
@@ -29,9 +30,13 @@ process_source(const StringView *source, const char *file_name) {
29 // print_tokens(tokens); 30 // print_tokens(tokens);
30 31
31 // Parser. 32 // Parser.
32 ParseTree *parse_tree = parse(tokens); 33 Root *roots = parse(tokens);
33 check_errors(file_name); 34 check_errors(file_name);
34 viz_ast(parse_tree); 35 // viz_ast(roots);
36
37 // Symbol table generation and type checking.
38 ParseTree *parse_tree = semantic_analysis(roots);
39 viz_ast(parse_tree->roots);
35} 40}
36 41
37void 42void
diff --git a/src/nodes.c b/src/nodes.c
index b8a5f09..6978acc 100644
--- a/src/nodes.c
+++ b/src/nodes.c
@@ -11,7 +11,6 @@ alloc_node(NodeType type) {
11 node->type = type; 11 node->type = type;
12 node->line = 0; 12 node->line = 0;
13 node->col = 0; 13 node->col = 0;
14 node->scope = NULL;
15 node->expr_type = NULL; 14 node->expr_type = NULL;
16 return node; 15 return node;
17} 16}
diff --git a/src/nodes.h b/src/nodes.h
index 11b30dd..af10573 100644
--- a/src/nodes.h
+++ b/src/nodes.h
@@ -21,7 +21,6 @@ typedef struct Node {
21 NodeType type; 21 NodeType type;
22 size_t line; 22 size_t line;
23 size_t col; 23 size_t col;
24 struct Scope *scope;
25 struct Type *expr_type; 24 struct Type *expr_type;
26 25
27 union { 26 union {
diff --git a/src/parser.c b/src/parser.c
index 0cf70d7..3f15b47 100644
--- a/src/parser.c
+++ b/src/parser.c
@@ -1,59 +1,6 @@
1#include "parser.h" 1#include "parser.h"
2#include "darray.h" 2#include "darray.h"
3 3
4typedef enum DefaultType {
5 TYPE_VOID,
6 TYPE_BOOL,
7 TYPE_STR,
8 TYPE_U8,
9 TYPE_U16,
10 TYPE_U32,
11 TYPE_U64,
12 TYPE_S8,
13 TYPE_S16,
14 TYPE_S32,
15 TYPE_S64,
16 TYPE_F32,
17 TYPE_F64,
18} DefaultType;
19
20static Type default_types[] = {
21 [TYPE_VOID] = {STRING("void"), 0},
22 [TYPE_BOOL] = {STRING("bool"), 1},
23 [TYPE_STR] = {STRING("str"), 16}, // size (8) + pointer to data (8).
24 [TYPE_U8] = {STRING("u8"), 1},
25 [TYPE_U16] = {STRING("u16"), 2},
26 [TYPE_U32] = {STRING("u32"), 4},
27 [TYPE_U64] = {STRING("u64"), 8},
28 [TYPE_S8] = {STRING("s8"), 1},
29 [TYPE_S16] = {STRING("s16"), 2},
30 [TYPE_S32] = {STRING("s32"), 4},
31 [TYPE_S64] = {STRING("s64"), 8},
32 [TYPE_F32] = {STRING("f32"), 4},
33 [TYPE_F64] = {STRING("f64"), 8},
34};
35
36typedef enum SymbolType {
37 SYMBOL_VAR,
38 SYMBOL_FUN,
39} SymbolType;
40
41typedef struct SymbolValue {
42 Node *name;
43 SymbolType type;
44
45 union {
46 struct {
47 Node *type;
48 } var;
49
50 struct {
51 Node **param_types;
52 Node *return_type;
53 } fun;
54 };
55} SymbolValue;
56
57Token 4Token
58next_token(Parser *parser) { 5next_token(Parser *parser) {
59 return parser->tokens[parser->current_token++]; 6 return parser->tokens[parser->current_token++];
@@ -510,67 +457,6 @@ parse_next(Parser *parser) {
510 } 457 }
511} 458}
512 459
513u64 sym_hash(const struct HashTable *table, void *bytes) {
514 Node *symbol = bytes;
515 u64 hash = _xor_shift_hash(symbol->string.start, symbol->string.n);
516 hash = _fibonacci_hash(hash, table->shift_amount);
517 return hash;
518}
519
520bool sym_eq(void *a, void *b) {
521 Node *a_node = a;
522 Node *b_node = b;
523 assert(a_node->type == NODE_SYMBOL);
524 assert(b_node->type == NODE_SYMBOL);
525 return sv_equal(&a_node->string, &b_node->string);
526}
527
528u64 type_hash(const struct HashTable *table, void *bytes) {
529 StringView *type = bytes;
530 u64 hash = _xor_shift_hash(type->start, type->n);
531 hash = _fibonacci_hash(hash, table->shift_amount);
532 return hash;
533}
534
535bool type_eq(void *a, void *b) {
536 StringView *a_type = a;
537 StringView *b_type = b;
538 return sv_equal(a_type, b_type);
539}
540
541Scope *
542alloc_scope(Scope *parent) {
543 Scope *scope = malloc(sizeof(Scope));
544 scope->parent = parent;
545 scope->symbols = ht_init(sym_hash, sym_eq);
546 scope->types = ht_init(type_hash, type_eq);
547 return scope;
548}
549
550ParseTree *
551alloc_parsetree(void) {
552 ParseTree *parse_tree = malloc(sizeof(ParseTree));
553 array_init(parse_tree->roots, 0);
554 parse_tree->global_scope = alloc_scope(NULL);
555 parse_tree->current_scope = parse_tree->global_scope;
556
557 // Fill global scope with default types.
558 HashTable *types = parse_tree->global_scope->types;
559 for (size_t i = 0; i < sizeof(default_types)/sizeof(Type); ++i) {
560 Type *type = &default_types[i];
561 ht_insert(types, &type->name, type);
562 }
563 return parse_tree;
564}
565
566SymbolValue *
567alloc_symval(Node *name, SymbolType type) {
568 SymbolValue *val = malloc(sizeof(SymbolValue));
569 val->name = name;
570 val->type = type;
571 return val;
572}
573
574bool 460bool
575parse_roots(Parser *parser) { 461parse_roots(Parser *parser) {
576 while (has_next(parser)) { 462 while (has_next(parser)) {
@@ -582,424 +468,21 @@ parse_roots(Parser *parser) {
582 if (node == NULL) { 468 if (node == NULL) {
583 return false; 469 return false;
584 } 470 }
585 array_push(parser->parse_tree->roots, node); 471 array_push(parser->roots, node);
586 }
587 return true;
588}
589
590Type *
591find_type(Scope *scope, Node *type) {
592 // TODO: Normally default types will be used more often. Since we don't
593 // allow type shadowing, we should search first on the global scope.
594 while (scope != NULL) {
595 Type *ret = ht_lookup(scope->types, &type->string);
596 if (ret != NULL) {
597 return ret;
598 }
599 scope = scope->parent;
600 }
601 push_error(ERR_TYPE_PARSER, ERR_UNKNOWN_TYPE, type->line, type->col);
602 return NULL;
603}
604
605bool
606insert_symbol(Scope *scope, Node *symbol, SymbolValue *val) {
607 // Check if symbol already exists.
608 HashTable *symbols = scope->symbols;
609 if (ht_lookup(symbols, symbol) != NULL) {
610 push_error(ERR_TYPE_PARSER, ERR_SYMBOL_REDEF, symbol->line, symbol->col);
611 return false;
612 }
613 ht_insert(symbols, symbol, val);
614 return true;
615}
616
617Type *
618coerce_numeric_types(Type *a, Type *b) {
619 // TODO: Decide what to do with mixed numeric types. What are the promotion
620 // rules, etc.
621 if (a == &default_types[TYPE_U8]) {
622 if (b == &default_types[TYPE_U16] ||
623 b == &default_types[TYPE_U32] ||
624 b == &default_types[TYPE_U64]) {
625 return b;
626 }
627 } else if (a == &default_types[TYPE_U16]) {
628 if (b == &default_types[TYPE_U32] ||
629 b == &default_types[TYPE_U64]) {
630 return b;
631 }
632 } else if (a == &default_types[TYPE_U32]) {
633 if (b == &default_types[TYPE_U64]) {
634 return b;
635 }
636 } else if (a == &default_types[TYPE_S8]) {
637 if (b == &default_types[TYPE_S16] ||
638 b == &default_types[TYPE_S32] ||
639 b == &default_types[TYPE_S64]) {
640 return b;
641 }
642 } else if (a == &default_types[TYPE_S16]) {
643 if (b == &default_types[TYPE_S32] ||
644 b == &default_types[TYPE_S64]) {
645 return b;
646 }
647 } else if (a == &default_types[TYPE_S32]) {
648 if (b == &default_types[TYPE_S64]) {
649 return b;
650 }
651 } else if (a == &default_types[TYPE_F32]) {
652 if (b == &default_types[TYPE_F64]) {
653 return b;
654 }
655 }
656 return a;
657}
658
659bool
660type_is_numeric(Type *t) {
661 if (t == &default_types[TYPE_U8] ||
662 t == &default_types[TYPE_U16] ||
663 t == &default_types[TYPE_U32] ||
664 t == &default_types[TYPE_U64] ||
665 t == &default_types[TYPE_S8] ||
666 t == &default_types[TYPE_S16] ||
667 t == &default_types[TYPE_S32] ||
668 t == &default_types[TYPE_S64] ||
669 t == &default_types[TYPE_F32] ||
670 t == &default_types[TYPE_F64]) {
671 return true;
672 }
673 return false;
674}
675
676SymbolValue *
677find_symbol(Scope *scope, Node *node) {
678 while (scope != NULL) {
679 SymbolValue *val = ht_lookup(scope->symbols, node);
680 if (val != NULL) {
681 return val;
682 }
683 scope = scope->parent;
684 }
685 push_error(ERR_TYPE_PARSER, ERR_UNKNOWN_SYMBOL, node->line, node->col);
686 return NULL;
687}
688
689bool
690resolve_type(Scope *scope, Node *node) {
691 if (node->expr_type != NULL) {
692 return true;
693 }
694 switch (node->type) {
695 case NODE_BUILTIN: {
696 for (size_t i = 0; i < array_size(node->builtin.args); ++i) {
697 Node *arg = node->builtin.args[i];
698 if (!resolve_type(scope, arg)) {
699 return false;
700 }
701 }
702 switch (node->builtin.type) {
703 // Numbers.
704 case TOKEN_ADD:
705 case TOKEN_SUB:
706 case TOKEN_MUL:
707 case TOKEN_DIV:
708 case TOKEN_MOD: {
709 Type *type = NULL;
710 for (size_t i = 0; i < array_size(node->builtin.args); ++i) {
711 Node *arg = node->builtin.args[i];
712
713 // Check that all arguments are numbers.
714 if (!type_is_numeric(arg->expr_type)) {
715 push_error(ERR_TYPE_PARSER, ERR_WRONG_TYPE_NUM,
716 arg->line, arg->col);
717 return false;
718 }
719
720 if (type == NULL) {
721 type = arg->expr_type;
722 } else if (type != arg->expr_type) {
723 type = coerce_numeric_types(type, arg->expr_type);
724 }
725 }
726 node->expr_type = type;
727 } break;
728 // Bools.
729 case TOKEN_NOT:
730 case TOKEN_AND:
731 case TOKEN_OR: {
732 // Check that all arguments are boolean.
733 for (size_t i = 0; i < array_size(node->builtin.args); ++i) {
734 Node *arg = node->builtin.args[i];
735 if (arg->expr_type != &default_types[TYPE_BOOL]) {
736 push_error(ERR_TYPE_PARSER, ERR_WRONG_TYPE_BOOL,
737 arg->line, arg->col);
738 return false;
739 }
740 }
741 node->expr_type = &default_types[TYPE_BOOL];
742 } break;
743 case TOKEN_EQ:
744 case TOKEN_LT:
745 case TOKEN_GT:
746 case TOKEN_LE:
747 case TOKEN_GE: {
748 // Check that all arguments are nums.
749 for (size_t i = 0; i < array_size(node->builtin.args); ++i) {
750 Node *arg = node->builtin.args[i];
751 if (!type_is_numeric(arg->expr_type)) {
752 push_error(ERR_TYPE_PARSER, ERR_WRONG_TYPE_NUM,
753 arg->line, arg->col);
754 return false;
755 }
756 }
757 node->expr_type = &default_types[TYPE_BOOL];
758 } break;
759 default: break;
760 }
761 } break;
762 case NODE_SYMBOL: {
763 SymbolValue *val = find_symbol(scope, node);
764 if (val == NULL) {
765 return false;
766 }
767
768 Type *type = NULL;
769 switch (val->type) {
770 case SYMBOL_VAR: {
771 type = find_type(scope, val->var.type);
772 } break;
773 case SYMBOL_FUN: {
774 type = find_type(scope, val->fun.return_type);
775 } break;
776 }
777 if (type == NULL) {
778 return false;
779 }
780 node->expr_type = type;
781 } break;
782 case NODE_FUN: {
783 // Fill up new scope with parameters
784 scope = alloc_scope(scope);
785
786 // Parameters.
787 for (size_t i = 0; i < array_size(node->fun.param_names); ++i) {
788 Node *param = node->fun.param_names[i];
789 Node *type = node->fun.param_types[i];
790 SymbolValue *var = alloc_symval(param, SYMBOL_VAR);
791 var->var.type = type;
792 if (!insert_symbol(scope, param, var)) {
793 return false;
794 }
795 }
796
797 // Body.
798 Node *body = node->fun.body;
799 if (body->type == NODE_BLOCK) {
800 body->scope = scope;
801 for (size_t i = 0; i < array_size(body->block.expr); ++i) {
802 Node *expr = body->block.expr[i];
803 if (!resolve_type(scope, expr)) {
804 return false;
805 }
806 }
807 Node *last_expr = body->block.expr[array_size(body->block.expr) - 1];
808 node->expr_type = last_expr->expr_type;
809 } else {
810 if (!resolve_type(scope, body)) {
811 return false;
812 }
813 }
814
815 // Check that the type of body matches the return type.
816 StringView *type_body = &node->fun.body->expr_type->name;
817 StringView *return_type = &node->fun.return_type->string;
818 if (!sv_equal(type_body, return_type)) {
819 push_error(ERR_TYPE_PARSER, ERR_WRONG_RET_TYPE, node->line, node->col);
820 return false;
821 }
822 } break;
823 case NODE_BLOCK: {
824 scope = alloc_scope(scope);
825 for (size_t i = 0; i < array_size(node->block.expr); ++i) {
826 Node *expr = node->block.expr[i];
827 if (!resolve_type(scope, expr)) {
828 return false;
829 }
830 }
831 Node *last_expr = node->block.expr[array_size(node->block.expr) - 1];
832 node->expr_type = last_expr->expr_type;
833 } break;
834 case NODE_IF: {
835 if (!resolve_type(scope, node->ifexpr.cond)) {
836 return false;
837 }
838 if (!resolve_type(scope, node->ifexpr.expr_true)) {
839 return false;
840 }
841 Type *type_true = node->ifexpr.expr_true->expr_type;
842 node->expr_type = type_true;
843 if (node->ifexpr.expr_false != NULL) {
844 if (!resolve_type(scope, node->ifexpr.expr_false)) {
845 return false;
846 }
847 }
848
849 // Check ifexpr.cond is a bool.
850 Type *type_cond = node->ifexpr.cond->expr_type;
851 if (!sv_equal(&type_cond->name, &default_types[TYPE_BOOL].name)) {
852 push_error(ERR_TYPE_PARSER, ERR_WRONG_COND_TYPE,
853 node->line, node->col);
854 return false;
855 }
856
857 // Check if types of expr_true and expr_false match
858 if (node->ifexpr.expr_false != NULL) {
859 Type *type_false = node->ifexpr.expr_false->expr_type;
860 if (type_is_numeric(type_true) && type_is_numeric(type_false)) {
861 node->expr_type = coerce_numeric_types(type_true, type_false);
862 } else if (!sv_equal(&type_true->name, &type_false->name)) {
863 push_error(ERR_TYPE_PARSER, ERR_WRONG_TYPE_T_F,
864 node->line, node->col);
865 return false;
866 }
867 }
868 } break;
869 case NODE_SET: {
870 node->expr_type = &default_types[TYPE_VOID];
871 if (!resolve_type(scope, node->set.symbol)) {
872 return false;
873 }
874 if (!resolve_type(scope, node->set.value)) {
875 return false;
876 }
877 Node *symbol = node->set.symbol;
878 Node *value = node->set.value;
879 if (!sv_equal(&symbol->expr_type->name, &value->expr_type->name)) {
880 push_error(ERR_TYPE_PARSER, ERR_TYPE_MISMATCH,
881 node->line, node->col);
882 return false;
883 }
884 } break;
885 case NODE_DEF: {
886 // Prepare value for symbol table.
887 SymbolValue *var = alloc_symval(node->def.symbol, SYMBOL_VAR);
888 var->var.type = node->def.type;
889 if (!insert_symbol(scope, node->def.symbol, var)) {
890 return false;
891 }
892
893 Type *type = find_type(scope, node->def.type);
894 if (type == NULL) {
895 return false;
896 }
897 node->def.symbol->expr_type = type;
898
899 node->expr_type = &default_types[TYPE_VOID];
900 // TODO: type inference from right side when not annotated?
901 if (!resolve_type(scope, node->def.value)) {
902 return false;
903 }
904 Node *symbol = node->def.symbol;
905 Node *value = node->def.value;
906 if (!sv_equal(&symbol->expr_type->name, &value->expr_type->name)) {
907 push_error(ERR_TYPE_PARSER, ERR_TYPE_MISMATCH,
908 node->line, node->col);
909 return false;
910 }
911 } break;
912 case NODE_NUMBER: {
913 // TODO: Numbers are f64/s64 unless explicitely annotated. Annotated
914 // numbers must fit in the given range (e.g. no negative constants
915 // inside a u64, no numbers bigger than 255 in a u8, etc.).
916 if (node->number.fractional != 0) {
917 node->expr_type = &default_types[TYPE_F64];
918 } else {
919 node->expr_type = &default_types[TYPE_S64];
920 }
921 } break;
922 case NODE_BOOL: {
923 node->expr_type = &default_types[TYPE_BOOL];
924 } break;
925 case NODE_STRING: {
926 node->expr_type = &default_types[TYPE_STR];
927 } break;
928 case NODE_FUNCALL: {
929 SymbolValue *val = find_symbol(scope, node->funcall.name);
930 if (!resolve_type(scope, node->funcall.name)) {
931 return false;
932 }
933 if (val->type != SYMBOL_FUN) {
934 push_error(ERR_TYPE_PARSER, ERR_WRONG_TYPE_FUN,
935 node->funcall.name->line, node->funcall.name->col);
936 return false;
937 }
938 if (array_size(node->funcall.args) != array_size(val->fun.param_types)) {
939 push_error(ERR_TYPE_PARSER, ERR_BAD_ARGS, node->line, node->col);
940 return false;
941 }
942 node->expr_type = node->funcall.name->expr_type;
943 for (size_t i = 0; i < array_size(node->funcall.args); ++i) {
944 Node *arg = node->funcall.args[i];
945 if (!resolve_type(scope, arg)) {
946 return false;
947 }
948 Node *expected = val->fun.param_types[i];
949 if (!sv_equal(&arg->expr_type->name, &expected->string)) {
950 push_error(ERR_TYPE_PARSER, ERR_TYPE_MISMATCH,
951 arg->line, arg->col);
952 return false;
953 }
954 }
955 } break;
956 default: break;
957 }
958 return true;
959}
960
961bool
962semantic_analysis(Parser *parser) {
963 // Fill up global function symbols.
964 Scope *scope = parser->parse_tree->global_scope;
965 for (size_t i = 0; i < array_size(parser->parse_tree->roots); ++i) {
966 Node *root = parser->parse_tree->roots[i];
967 if (root->type == NODE_FUN) {
968 Node *name = root->fun.name;
969 SymbolValue *fun = alloc_symval(root->fun.name, SYMBOL_FUN);
970 fun->fun.param_types = root->fun.param_types;
971 fun->fun.return_type = root->fun.return_type;
972 if (!insert_symbol(scope, name, fun)) {
973 return false;
974 }
975 }
976 }
977
978 for (size_t i = 0; i < array_size(parser->parse_tree->roots); ++i) {
979 // Fill up symbol tables in proper scope and resolve type of expression
980 // for all elements.
981 if (!resolve_type(scope, parser->parse_tree->roots[i])) {
982 return false;
983 }
984 } 472 }
985
986 return true; 473 return true;
987} 474}
988 475
989ParseTree * 476Root *
990parse(Token *tokens) { 477parse(Token *tokens) {
991 Parser parser = { 478 Parser parser = {
992 .tokens = tokens, 479 .tokens = tokens,
993 .current_token = 0, 480 .current_token = 0,
994 }; 481 };
995 parser.parse_tree = alloc_parsetree(); 482 array_init(parser.roots, 0);
996 483
997 if (!parse_roots(&parser)) { 484 if (!parse_roots(&parser)) {
998 return NULL; 485 return NULL;
999 } 486 }
1000 if (!semantic_analysis(&parser)) { 487 return parser.roots;
1001 return NULL;
1002 }
1003
1004 return parser.parse_tree;
1005} 488}
diff --git a/src/parser.h b/src/parser.h
index cc3ba92..206ca4c 100644
--- a/src/parser.h
+++ b/src/parser.h
@@ -3,32 +3,16 @@
3 3
4#include "lexer.h" 4#include "lexer.h"
5#include "nodes.h" 5#include "nodes.h"
6#include "hashtable.h"
7 6
8typedef struct Type { 7typedef Node* Root;
9 StringView name;
10 size_t size; // (bytes)
11} Type;
12
13typedef struct Scope {
14 struct Scope *parent;
15 HashTable *symbols;
16 HashTable *types;
17} Scope;
18
19typedef struct ParseTree {
20 Node **roots;
21 Scope *global_scope;
22 Scope *current_scope;
23} ParseTree;
24 8
25typedef struct Parser { 9typedef struct Parser {
26 Token *tokens; 10 Token *tokens;
27 size_t current_token; 11 size_t current_token;
28 ParseTree *parse_tree; 12 Root *roots;
29} Parser; 13} Parser;
30 14
31ParseTree * parse(Token *tokens); 15Root * parse(Token *tokens);
32Node * parse_next(Parser *parser); 16Node * parse_next(Parser *parser);
33 17
34#endif // BDL_PARSER_H 18#endif // BDL_PARSER_H
diff --git a/src/semantic.c b/src/semantic.c
new file mode 100644
index 0000000..06958b9
--- /dev/null
+++ b/src/semantic.c
@@ -0,0 +1,527 @@
1#include "hashtable.h"
2
3typedef struct Scope {
4 struct Scope *parent;
5 HashTable *symbols;
6 HashTable *types;
7} Scope;
8
9typedef struct ParseTree {
10 Root *roots;
11 Scope *global_scope;
12 Scope *current_scope;
13} ParseTree;
14
15typedef struct Type {
16 StringView name;
17 size_t size; // (bytes)
18} Type;
19
20typedef enum DefaultType {
21 TYPE_VOID,
22 TYPE_BOOL,
23 TYPE_STR,
24 TYPE_U8,
25 TYPE_U16,
26 TYPE_U32,
27 TYPE_U64,
28 TYPE_S8,
29 TYPE_S16,
30 TYPE_S32,
31 TYPE_S64,
32 TYPE_F32,
33 TYPE_F64,
34} DefaultType;
35
36static Type default_types[] = {
37 [TYPE_VOID] = {STRING("void"), 0},
38 [TYPE_BOOL] = {STRING("bool"), 1},
39 [TYPE_STR] = {STRING("str"), 16}, // size (8) + pointer to data (8).
40 [TYPE_U8] = {STRING("u8"), 1},
41 [TYPE_U16] = {STRING("u16"), 2},
42 [TYPE_U32] = {STRING("u32"), 4},
43 [TYPE_U64] = {STRING("u64"), 8},
44 [TYPE_S8] = {STRING("s8"), 1},
45 [TYPE_S16] = {STRING("s16"), 2},
46 [TYPE_S32] = {STRING("s32"), 4},
47 [TYPE_S64] = {STRING("s64"), 8},
48 [TYPE_F32] = {STRING("f32"), 4},
49 [TYPE_F64] = {STRING("f64"), 8},
50};
51
52typedef enum SymbolType {
53 SYMBOL_VAR,
54 SYMBOL_FUN,
55} SymbolType;
56
57typedef struct Symbol {
58 Node *name;
59 SymbolType type;
60
61 union {
62 struct {
63 Node *type;
64 } var;
65
66 struct {
67 Node **param_types;
68 Node *return_type;
69 } fun;
70 };
71} Symbol;
72
73
74Symbol *
75alloc_symval(Node *name, SymbolType type) {
76 Symbol *val = malloc(sizeof(Symbol));
77 val->name = name;
78 val->type = type;
79 return val;
80}
81
82u64 sym_hash(const struct HashTable *table, void *bytes) {
83 Node *symbol = bytes;
84 u64 hash = _xor_shift_hash(symbol->string.start, symbol->string.n);
85 hash = _fibonacci_hash(hash, table->shift_amount);
86 return hash;
87}
88
89bool sym_eq(void *a, void *b) {
90 Node *a_node = a;
91 Node *b_node = b;
92 assert(a_node->type == NODE_SYMBOL);
93 assert(b_node->type == NODE_SYMBOL);
94 return sv_equal(&a_node->string, &b_node->string);
95}
96
97u64 type_hash(const struct HashTable *table, void *bytes) {
98 StringView *type = bytes;
99 u64 hash = _xor_shift_hash(type->start, type->n);
100 hash = _fibonacci_hash(hash, table->shift_amount);
101 return hash;
102}
103
104bool type_eq(void *a, void *b) {
105 StringView *a_type = a;
106 StringView *b_type = b;
107 return sv_equal(a_type, b_type);
108}
109
110Scope *
111alloc_scope(Scope *parent) {
112 Scope *scope = malloc(sizeof(Scope));
113 scope->parent = parent;
114 scope->symbols = ht_init(sym_hash, sym_eq);
115 scope->types = ht_init(type_hash, type_eq);
116 return scope;
117}
118
119Type *
120find_type(Scope *scope, Node *type) {
121 // TODO: Normally default types will be used more often. Since we don't
122 // allow type shadowing, we should search first on the global scope.
123 while (scope != NULL) {
124 Type *ret = ht_lookup(scope->types, &type->string);
125 if (ret != NULL) {
126 return ret;
127 }
128 scope = scope->parent;
129 }
130 push_error(ERR_TYPE_PARSER, ERR_UNKNOWN_TYPE, type->line, type->col);
131 return NULL;
132}
133
134bool
135insert_symbol(Scope *scope, Node *symbol, Symbol *val) {
136 // Check if symbol already exists.
137 HashTable *symbols = scope->symbols;
138 if (ht_lookup(symbols, symbol) != NULL) {
139 push_error(ERR_TYPE_PARSER, ERR_SYMBOL_REDEF, symbol->line, symbol->col);
140 return false;
141 }
142 ht_insert(symbols, symbol, val);
143 return true;
144}
145
146Type *
147coerce_numeric_types(Type *a, Type *b) {
148 // TODO: Decide what to do with mixed numeric types. What are the promotion
149 // rules, etc.
150 if (a == &default_types[TYPE_U8]) {
151 if (b == &default_types[TYPE_U16] ||
152 b == &default_types[TYPE_U32] ||
153 b == &default_types[TYPE_U64]) {
154 return b;
155 }
156 } else if (a == &default_types[TYPE_U16]) {
157 if (b == &default_types[TYPE_U32] ||
158 b == &default_types[TYPE_U64]) {
159 return b;
160 }
161 } else if (a == &default_types[TYPE_U32]) {
162 if (b == &default_types[TYPE_U64]) {
163 return b;
164 }
165 } else if (a == &default_types[TYPE_S8]) {
166 if (b == &default_types[TYPE_S16] ||
167 b == &default_types[TYPE_S32] ||
168 b == &default_types[TYPE_S64]) {
169 return b;
170 }
171 } else if (a == &default_types[TYPE_S16]) {
172 if (b == &default_types[TYPE_S32] ||
173 b == &default_types[TYPE_S64]) {
174 return b;
175 }
176 } else if (a == &default_types[TYPE_S32]) {
177 if (b == &default_types[TYPE_S64]) {
178 return b;
179 }
180 } else if (a == &default_types[TYPE_F32]) {
181 if (b == &default_types[TYPE_F64]) {
182 return b;
183 }
184 }
185 return a;
186}
187
188bool
189type_is_numeric(Type *t) {
190 if (t == &default_types[TYPE_U8] ||
191 t == &default_types[TYPE_U16] ||
192 t == &default_types[TYPE_U32] ||
193 t == &default_types[TYPE_U64] ||
194 t == &default_types[TYPE_S8] ||
195 t == &default_types[TYPE_S16] ||
196 t == &default_types[TYPE_S32] ||
197 t == &default_types[TYPE_S64] ||
198 t == &default_types[TYPE_F32] ||
199 t == &default_types[TYPE_F64]) {
200 return true;
201 }
202 return false;
203}
204
205Symbol *
206find_symbol(Scope *scope, Node *node) {
207 while (scope != NULL) {
208 Symbol *val = ht_lookup(scope->symbols, node);
209 if (val != NULL) {
210 return val;
211 }
212 scope = scope->parent;
213 }
214 push_error(ERR_TYPE_PARSER, ERR_UNKNOWN_SYMBOL, node->line, node->col);
215 return NULL;
216}
217
218bool
219resolve_type(Scope *scope, Node *node) {
220 if (node->expr_type != NULL) {
221 return true;
222 }
223 switch (node->type) {
224 case NODE_BUILTIN: {
225 for (size_t i = 0; i < array_size(node->builtin.args); ++i) {
226 Node *arg = node->builtin.args[i];
227 if (!resolve_type(scope, arg)) {
228 return false;
229 }
230 }
231 switch (node->builtin.type) {
232 // Numbers.
233 case TOKEN_ADD:
234 case TOKEN_SUB:
235 case TOKEN_MUL:
236 case TOKEN_DIV:
237 case TOKEN_MOD: {
238 Type *type = NULL;
239 for (size_t i = 0; i < array_size(node->builtin.args); ++i) {
240 Node *arg = node->builtin.args[i];
241
242 // Check that all arguments are numbers.
243 if (!type_is_numeric(arg->expr_type)) {
244 push_error(ERR_TYPE_PARSER, ERR_WRONG_TYPE_NUM,
245 arg->line, arg->col);
246 return false;
247 }
248
249 if (type == NULL) {
250 type = arg->expr_type;
251 } else if (type != arg->expr_type) {
252 type = coerce_numeric_types(type, arg->expr_type);
253 }
254 }
255 node->expr_type = type;
256 } break;
257 // Bools.
258 case TOKEN_NOT:
259 case TOKEN_AND:
260 case TOKEN_OR: {
261 // Check that all arguments are boolean.
262 for (size_t i = 0; i < array_size(node->builtin.args); ++i) {
263 Node *arg = node->builtin.args[i];
264 if (arg->expr_type != &default_types[TYPE_BOOL]) {
265 push_error(ERR_TYPE_PARSER, ERR_WRONG_TYPE_BOOL,
266 arg->line, arg->col);
267 return false;
268 }
269 }
270 node->expr_type = &default_types[TYPE_BOOL];
271 } break;
272 case TOKEN_EQ:
273 case TOKEN_LT:
274 case TOKEN_GT:
275 case TOKEN_LE:
276 case TOKEN_GE: {
277 // Check that all arguments are nums.
278 for (size_t i = 0; i < array_size(node->builtin.args); ++i) {
279 Node *arg = node->builtin.args[i];
280 if (!type_is_numeric(arg->expr_type)) {
281 push_error(ERR_TYPE_PARSER, ERR_WRONG_TYPE_NUM,
282 arg->line, arg->col);
283 return false;
284 }
285 }
286 node->expr_type = &default_types[TYPE_BOOL];
287 } break;
288 default: break;
289 }
290 } break;
291 case NODE_SYMBOL: {
292 Symbol *val = find_symbol(scope, node);
293 if (val == NULL) {
294 return false;
295 }
296
297 Type *type = NULL;
298 switch (val->type) {
299 case SYMBOL_VAR: {
300 type = find_type(scope, val->var.type);
301 } break;
302 case SYMBOL_FUN: {
303 type = find_type(scope, val->fun.return_type);
304 } break;
305 }
306 if (type == NULL) {
307 return false;
308 }
309 node->expr_type = type;
310 } break;
311 case NODE_FUN: {
312 // Fill up new scope with parameters
313 scope = alloc_scope(scope);
314
315 // Parameters.
316 for (size_t i = 0; i < array_size(node->fun.param_names); ++i) {
317 Node *param = node->fun.param_names[i];
318 Node *type = node->fun.param_types[i];
319 Symbol *var = alloc_symval(param, SYMBOL_VAR);
320 var->var.type = type;
321 if (!insert_symbol(scope, param, var)) {
322 return false;
323 }
324 }
325
326 // Body.
327 Node *body = node->fun.body;
328 if (body->type == NODE_BLOCK) {
329 for (size_t i = 0; i < array_size(body->block.expr); ++i) {
330 Node *expr = body->block.expr[i];
331 if (!resolve_type(scope, expr)) {
332 return false;
333 }
334 }
335 Node *last_expr = body->block.expr[array_size(body->block.expr) - 1];
336 node->expr_type = last_expr->expr_type;
337 } else {
338 if (!resolve_type(scope, body)) {
339 return false;
340 }
341 }
342
343 // Check that the type of body matches the return type.
344 StringView *type_body = &node->fun.body->expr_type->name;
345 StringView *return_type = &node->fun.return_type->string;
346 if (!sv_equal(type_body, return_type)) {
347 push_error(ERR_TYPE_PARSER, ERR_WRONG_RET_TYPE, node->line, node->col);
348 return false;
349 }
350 } break;
351 case NODE_BLOCK: {
352 scope = alloc_scope(scope);
353 for (size_t i = 0; i < array_size(node->block.expr); ++i) {
354 Node *expr = node->block.expr[i];
355 if (!resolve_type(scope, expr)) {
356 return false;
357 }
358 }
359 Node *last_expr = node->block.expr[array_size(node->block.expr) - 1];
360 node->expr_type = last_expr->expr_type;
361 } break;
362 case NODE_IF: {
363 if (!resolve_type(scope, node->ifexpr.cond)) {
364 return false;
365 }
366 if (!resolve_type(scope, node->ifexpr.expr_true)) {
367 return false;
368 }
369 Type *type_true = node->ifexpr.expr_true->expr_type;
370 node->expr_type = type_true;
371 if (node->ifexpr.expr_false != NULL) {
372 if (!resolve_type(scope, node->ifexpr.expr_false)) {
373 return false;
374 }
375 }
376
377 // Check ifexpr.cond is a bool.
378 Type *type_cond = node->ifexpr.cond->expr_type;
379 if (!sv_equal(&type_cond->name, &default_types[TYPE_BOOL].name)) {
380 push_error(ERR_TYPE_PARSER, ERR_WRONG_COND_TYPE,
381 node->line, node->col);
382 return false;
383 }
384
385 // Check if types of expr_true and expr_false match
386 if (node->ifexpr.expr_false != NULL) {
387 Type *type_false = node->ifexpr.expr_false->expr_type;
388 if (type_is_numeric(type_true) && type_is_numeric(type_false)) {
389 node->expr_type = coerce_numeric_types(type_true, type_false);
390 } else if (!sv_equal(&type_true->name, &type_false->name)) {
391 push_error(ERR_TYPE_PARSER, ERR_WRONG_TYPE_T_F,
392 node->line, node->col);
393 return false;
394 }
395 }
396 } break;
397 case NODE_SET: {
398 node->expr_type = &default_types[TYPE_VOID];
399 if (!resolve_type(scope, node->set.symbol)) {
400 return false;
401 }
402 if (!resolve_type(scope, node->set.value)) {
403 return false;
404 }
405 Node *symbol = node->set.symbol;
406 Node *value = node->set.value;
407 if (!sv_equal(&symbol->expr_type->name, &value->expr_type->name)) {
408 push_error(ERR_TYPE_PARSER, ERR_TYPE_MISMATCH,
409 node->line, node->col);
410 return false;
411 }
412 } break;
413 case NODE_DEF: {
414 // Prepare value for symbol table.
415 Symbol *var = alloc_symval(node->def.symbol, SYMBOL_VAR);
416 var->var.type = node->def.type;
417 if (!insert_symbol(scope, node->def.symbol, var)) {
418 return false;
419 }
420
421 Type *type = find_type(scope, node->def.type);
422 if (type == NULL) {
423 return false;
424 }
425 node->def.symbol->expr_type = type;
426
427 node->expr_type = &default_types[TYPE_VOID];
428 // TODO: type inference from right side when not annotated?
429 if (!resolve_type(scope, node->def.value)) {
430 return false;
431 }
432 Node *symbol = node->def.symbol;
433 Node *value = node->def.value;
434 if (!sv_equal(&symbol->expr_type->name, &value->expr_type->name)) {
435 push_error(ERR_TYPE_PARSER, ERR_TYPE_MISMATCH,
436 node->line, node->col);
437 return false;
438 }
439 } break;
440 case NODE_NUMBER: {
441 // TODO: Numbers are f64/s64 unless explicitely annotated. Annotated
442 // numbers must fit in the given range (e.g. no negative constants
443 // inside a u64, no numbers bigger than 255 in a u8, etc.).
444 if (node->number.fractional != 0) {
445 node->expr_type = &default_types[TYPE_F64];
446 } else {
447 node->expr_type = &default_types[TYPE_S64];
448 }
449 } break;
450 case NODE_BOOL: {
451 node->expr_type = &default_types[TYPE_BOOL];
452 } break;
453 case NODE_STRING: {
454 node->expr_type = &default_types[TYPE_STR];
455 } break;
456 case NODE_FUNCALL: {
457 Symbol *val = find_symbol(scope, node->funcall.name);
458 if (!resolve_type(scope, node->funcall.name)) {
459 return false;
460 }
461 if (val->type != SYMBOL_FUN) {
462 push_error(ERR_TYPE_PARSER, ERR_WRONG_TYPE_FUN,
463 node->funcall.name->line, node->funcall.name->col);
464 return false;
465 }
466 if (array_size(node->funcall.args) != array_size(val->fun.param_types)) {
467 push_error(ERR_TYPE_PARSER, ERR_BAD_ARGS, node->line, node->col);
468 return false;
469 }
470 node->expr_type = node->funcall.name->expr_type;
471 for (size_t i = 0; i < array_size(node->funcall.args); ++i) {
472 Node *arg = node->funcall.args[i];
473 if (!resolve_type(scope, arg)) {
474 return false;
475 }
476 Node *expected = val->fun.param_types[i];
477 if (!sv_equal(&arg->expr_type->name, &expected->string)) {
478 push_error(ERR_TYPE_PARSER, ERR_TYPE_MISMATCH,
479 arg->line, arg->col);
480 return false;
481 }
482 }
483 } break;
484 default: break;
485 }
486 return true;
487}
488
489ParseTree *
490semantic_analysis(Root *roots) {
491 ParseTree *parse_tree = malloc(sizeof(ParseTree));
492 parse_tree->roots = roots;
493 parse_tree->global_scope = alloc_scope(NULL);
494 parse_tree->current_scope = parse_tree->global_scope;
495
496 // Fill global scope with default types.
497 HashTable *types = parse_tree->global_scope->types;
498 for (size_t i = 0; i < sizeof(default_types)/sizeof(Type); ++i) {
499 Type *type = &default_types[i];
500 ht_insert(types, &type->name, type);
501 }
502
503 // Fill up global function symbols.
504 Scope *scope = parse_tree->global_scope;
505 for (size_t i = 0; i < array_size(parse_tree->roots); ++i) {
506 Node *root = parse_tree->roots[i];
507 if (root->type == NODE_FUN) {
508 Node *name = root->fun.name;
509 Symbol *fun = alloc_symval(root->fun.name, SYMBOL_FUN);
510 fun->fun.param_types = root->fun.param_types;
511 fun->fun.return_type = root->fun.return_type;
512 if (!insert_symbol(scope, name, fun)) {
513 return NULL;
514 }
515 }
516 }
517
518 for (size_t i = 0; i < array_size(parse_tree->roots); ++i) {
519 // Fill up symbol tables in proper scope and resolve type of expression
520 // for all elements.
521 if (!resolve_type(scope, parse_tree->roots[i])) {
522 return NULL;
523 }
524 }
525
526 return parse_tree;
527}
diff --git a/src/viz.c b/src/viz.c
index 81cc1ff..d519d2c 100644
--- a/src/viz.c
+++ b/src/viz.c
@@ -152,8 +152,8 @@ viz_node(Node *node) {
152} 152}
153 153
154void 154void
155viz_ast(ParseTree *parse_tree) { 155viz_ast(Root *roots) {
156 if (parse_tree == NULL) { 156 if (roots == NULL) {
157 return; 157 return;
158 } 158 }
159 printf("digraph ast {\n"); 159 printf("digraph ast {\n");
@@ -161,9 +161,9 @@ viz_ast(ParseTree *parse_tree) {
161 printf("ranksep=\"0.95 equally\";\n"); 161 printf("ranksep=\"0.95 equally\";\n");
162 printf("nodesep=\"0.5 equally\";\n"); 162 printf("nodesep=\"0.5 equally\";\n");
163 printf("overlap=scale;\n"); 163 printf("overlap=scale;\n");
164 for (size_t i = 0; i < array_size(parse_tree->roots); ++i) { 164 for (size_t i = 0; i < array_size(roots); ++i) {
165 printf("subgraph %zu {\n", i); 165 printf("subgraph %zu {\n", i);
166 Node *root = parse_tree->roots[array_size(parse_tree->roots) - 1 - i]; 166 Node *root = roots[array_size(roots) - 1 - i];
167 viz_node(root); 167 viz_node(root);
168 printf("}\n"); 168 printf("}\n");
169 } 169 }