aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorBad Diode <bd@badd10de.dev>2024-07-05 16:50:41 +0200
committerBad Diode <bd@badd10de.dev>2024-07-05 16:50:41 +0200
commit9d548e7dd018d8f365d788f9a716ee6c04592c9d (patch)
treea6435e40974c23121816d2d02537db10189bb2d9 /src
parentf67c4050a21d0d7ec2c5936ca5ea87fd48872be7 (diff)
downloadbdl-9d548e7dd018d8f365d788f9a716ee6c04592c9d.tar.gz
bdl-9d548e7dd018d8f365d788f9a716ee6c04592c9d.zip
Add tail-call-optimization
Diffstat (limited to 'src')
-rw-r--r--src/compiler.c104
-rw-r--r--src/vm.c5
2 files changed, 107 insertions, 2 deletions
diff --git a/src/compiler.c b/src/compiler.c
index 11720d5..c2b48b3 100644
--- a/src/compiler.c
+++ b/src/compiler.c
@@ -112,7 +112,8 @@ typedef enum OpCode {
112 OP_LDLADDR, // ldladdr rx, va 112 OP_LDLADDR, // ldladdr rx, va
113 OP_LDSTR, // ldstr rx, sa ; Stores the address of the string sa into rx 113 OP_LDSTR, // ldstr rx, sa ; Stores the address of the string sa into rx
114 // Functions. 114 // Functions.
115 OP_CALL, // call fx ; Bumps the stack pointer by cx 115 OP_CALL, // call fx ; Call the function fx
116 OP_RECUR, // recur ; Jump to the beginning of the function.
116 OP_RET, // ret ; Returns from current function 117 OP_RET, // ret ; Returns from current function
117 OP_RESERVE, // reserve cx ; Increments the stack pointer by cx bytes 118 OP_RESERVE, // reserve cx ; Increments the stack pointer by cx bytes
118 OP_POP, // pop rx ; Pops the last value of the stack into rx. 119 OP_POP, // pop rx ; Pops the last value of the stack into rx.
@@ -241,6 +242,7 @@ Str op_str[] = {
241 [OP_PUTRETI] = cstr("PUTRETI "), 242 [OP_PUTRETI] = cstr("PUTRETI "),
242 // Functions. 243 // Functions.
243 [OP_CALL] = cstr("CALL "), 244 [OP_CALL] = cstr("CALL "),
245 [OP_RECUR] = cstr("RECUR "),
244 [OP_RET] = cstr("RET "), 246 [OP_RET] = cstr("RET "),
245 [OP_RESERVE] = cstr("RESERVE "), 247 [OP_RESERVE] = cstr("RESERVE "),
246 [OP_POP] = cstr("POP "), 248 [OP_POP] = cstr("POP "),
@@ -900,6 +902,44 @@ compile_while(Compiler *compiler, Chunk *chunk, Node *node) {
900} 902}
901 903
902CompResult 904CompResult
905compile_tail_call(Compiler *compiler, Chunk *chunk, Node *node) {
906 // Update the local parameters.
907 for (sz i = 0; i < array_size(node->elements); i++) {
908 Node *expr = node->elements[i];
909 CompResult result = compile_expr(compiler, chunk, expr);
910 switch (result.type) {
911 case COMP_CONST: {
912 emit_op(OP_STLVARI, i, result.idx, 0, node, chunk);
913 } break;
914 case COMP_REG: {
915 if (str_eq(expr->type, cstr("str"))) {
916 sz var_addr = chunk->reg_idx++;
917 sz str_addr = result.idx;
918 emit_op(OP_LDLADDR, var_addr, i, 0, node, chunk);
919 emit_fat_copy(chunk, node, var_addr, str_addr);
920 } else {
921 emit_op(OP_STLVAR, i, result.idx, 0, node, chunk);
922 }
923 } break;
924 case COMP_STRING: {
925 sz var_addr = chunk->reg_idx++;
926 sz str_addr = chunk->reg_idx++;
927 emit_op(OP_LDLADDR, var_addr, i, 0, node, chunk);
928 emit_op(OP_LDSTR, str_addr, result.idx, 0, node, chunk);
929 emit_fat_copy(chunk, node, var_addr, str_addr);
930 } break;
931 default: {
932 emit_compile_err(compiler, chunk, node);
933 return (CompResult){.type = COMP_ERR};
934 } break;
935 }
936 }
937
938 emit_op(OP_RECUR, 0, 0, 0, node, chunk);
939 return (CompResult){.type = COMP_NIL};
940}
941
942CompResult
903compile_funcall(Compiler *compiler, Chunk *chunk, Node *node) { 943compile_funcall(Compiler *compiler, Chunk *chunk, Node *node) {
904 Str name = node->value.str; 944 Str name = node->value.str;
905 945
@@ -977,6 +1017,67 @@ compile_funcall(Compiler *compiler, Chunk *chunk, Node *node) {
977 } 1017 }
978 Function fun = map->val; 1018 Function fun = map->val;
979 1019
1020 // Check for tail recursive opportunities.
1021 if (str_eq(fun.name, node->unique_name) &&
1022 str_eq(chunk->name, node->unique_name)) {
1023 Node *parent = node->parent;
1024 Node *current = node;
1025 bool tail_recursive = true;
1026 while (parent != NULL) {
1027 switch (parent->kind) {
1028 case NODE_BLOCK: {
1029 sz idx = array_size(parent->statements) - 1;
1030 if (parent->statements[idx] != node) {
1031 tail_recursive = false;
1032 break;
1033 }
1034 } break;
1035 case NODE_WHILE: {
1036 if (current == parent->loop.cond) {
1037 tail_recursive = false;
1038 break;
1039 }
1040 } break;
1041 case NODE_IF: {
1042 if (current == parent->ifelse.cond) {
1043 tail_recursive = false;
1044 break;
1045 }
1046 } break;
1047 case NODE_FUN: {
1048 sz idx = array_size(parent->func.body->statements) - 1;
1049 if (parent->func.body->statements[idx] != current) {
1050 tail_recursive = false;
1051 break;
1052 }
1053 break;
1054 } break;
1055 case NODE_MATCH: {
1056 if (current == parent->match.expr) {
1057 tail_recursive = false;
1058 break;
1059 }
1060 } break;
1061 case NODE_COND: break;
1062 case NODE_CASE_COND: {
1063 if (current == parent->case_entry.cond) {
1064 tail_recursive = false;
1065 break;
1066 }
1067 } break;
1068 default: {
1069 tail_recursive = false;
1070 break;
1071 } break;
1072 }
1073 parent = parent->parent;
1074 current = current->parent;
1075 }
1076 if (tail_recursive) {
1077 return compile_tail_call(compiler, chunk, node);
1078 }
1079 }
1080
980 // Reserve space for the return value if needed. 1081 // Reserve space for the return value if needed.
981 if (fun.return_arity > 0) { 1082 if (fun.return_arity > 0) {
982 // Put the return data into a register 1083 // Put the return data into a register
@@ -1655,6 +1756,7 @@ disassemble_instruction(Instruction instruction) {
1655 instruction.a, instruction.b); 1756 instruction.a, instruction.b);
1656 break; 1757 break;
1657 case OP_RET: 1758 case OP_RET:
1759 case OP_RECUR:
1658 case OP_HALT: println("%s", op_str[instruction.op]); break; 1760 case OP_HALT: println("%s", op_str[instruction.op]); break;
1659 default: println("Unknown opcode %d", instruction.op); break; 1761 default: println("Unknown opcode %d", instruction.op); break;
1660 } 1762 }
diff --git a/src/vm.c b/src/vm.c
index 0929fb6..0791706 100644
--- a/src/vm.c
+++ b/src/vm.c
@@ -5,7 +5,7 @@
5#include "compiler.c" 5#include "compiler.c"
6 6
7#define N_CONST 256 7#define N_CONST 256
8#define STACK_SIZE KB(64) 8#define STACK_SIZE MB(4)
9typedef struct VM { 9typedef struct VM {
10 Chunk *main; 10 Chunk *main;
11 Chunk *chunk; 11 Chunk *chunk;
@@ -382,6 +382,9 @@ vm_run(VM *vm) {
382 p[3] = old_fp; 382 p[3] = old_fp;
383 vm->sp += sizeof(ptrsize) * 4; 383 vm->sp += sizeof(ptrsize) * 4;
384 } break; 384 } break;
385 case OP_RECUR: {
386 vm->ip = vm->chunk->code;
387 } break;
385 case OP_RET: { 388 case OP_RET: {
386 u64 *p = (u64 *)vm->sp; 389 u64 *p = (u64 *)vm->sp;
387 ptrsize chunk_addr = p[-4]; 390 ptrsize chunk_addr = p[-4];