From 4344d214a99321f81e2af6d075ef789a6a324c0e Mon Sep 17 00:00:00 2001 From: Bad Diode Date: Tue, 16 Nov 2021 22:10:52 +0100 Subject: Add tail call optimization for function calls --- src/compiler.h | 82 +++++++++++++++++++++++++++++++++++++--------------------- 1 file changed, 52 insertions(+), 30 deletions(-) diff --git a/src/compiler.h b/src/compiler.h index 1ef81e3..e0d071c 100644 --- a/src/compiler.h +++ b/src/compiler.h @@ -340,15 +340,8 @@ compile_cdr(Object *obj) { context_printf(" ;; <-- compile_cdr\n"); } -void -compile_call(Object *obj) { - context_printf(" ;; --> compile_call\n"); - - // Prepare return pointer. - char *lab_end = generate_label("BDLL"); - context_printf(" lea rcx, [%s]\n", lab_end); - context_printf(" push rcx\n"); - +size_t +compile_call_body(Object *obj) { // Compile operator. compile_object(obj->head); context_printf(" pop rax\n"); @@ -385,13 +378,29 @@ compile_call(Object *obj) { obj = obj->tail; compile_object(obj->head); } - context_printf(" mov rax, [rsp + %zu]\n", 8 * offset); context_printf(" jmp rax\n"); - context_printf("%s:\n", lab_end); + + return offset; +} + +void +compile_call(Object *obj) { + context_printf(" ;; --> compile_call\n"); + context_printf(" push rbp\n"); + + // Prepare return pointer. + char *lab_end = generate_label("BDLL"); + context_printf(" lea rcx, [%s]\n", lab_end); + context_printf(" push rcx\n"); + + // Function call compilation without start/end. + size_t offset = compile_call_body(obj); // Restore stack to previous location and store the result on top. + context_printf("%s:\n", lab_end); context_printf(" add rsp, %zu\n", 8 * (offset + 2)); + context_printf(" pop rbp\n"); context_printf(" push rax\n"); context_printf(" ;; <-- compile_call\n"); } @@ -515,9 +524,14 @@ compile_lambda(Object *obj) { context_printf("alignb 8\n"); context_printf("%s:\n", name); + // Prepare size vars. + size_t n_locals = array_size(current_env->locals); + size_t n_params = array_size(current_env->params); + size_t n_captured = array_size(current_env->captured); + size_t offset = 8 * (n_locals + n_params + n_captured + 1); + // Initialize function call frame. - context_printf(" push rbp\n"); - context_printf(" sub rsp, %zu\n", 8 * array_size(current_env->locals)); + context_printf(" sub rsp, %zu\n", 8 * n_locals); context_printf(" mov rbp, rsp\n"); // Procedure body. @@ -533,21 +547,29 @@ compile_lambda(Object *obj) { compile_object(obj->body[i]); } compile_nil(); - compile_object(obj->body[array_size(obj->body) - 1]); + Object *last_expr = obj->body[array_size(obj->body) - 1]; - // Return is stored in the `rax`. - context_printf(" pop rax\n"); + // Tail Call Optimization. + // TODO: also for if statements + if (IS_PAIR(last_expr)) { + // Discard the previous stack frame. + context_printf(" mov rsp, rbp\n"); + context_printf(" add rsp, %zu\n", offset); + + compile_call_body(last_expr); + } else { + compile_object(last_expr); + + // Return is stored in the `rax`. + context_printf(" pop rax\n"); + + // Restore the previous call frame. + context_printf(" mov rdi, [rbp + %zu]\n", offset); + context_printf(" mov rsp, rbp\n"); + context_printf(" add rsp, %zu\n", 8 * n_locals); + context_printf(" jmp rdi\n"); + } - // Restore the previous call frame. - size_t n_locals = array_size(current_env->locals); - size_t n_params = array_size(current_env->params); - size_t n_captured = array_size(current_env->captured); - size_t offset = 8 * (n_locals + n_params + n_captured + 2); - context_printf(" mov rdi, [rbp + %zu]\n", offset); - context_printf(" mov rsp, rbp\n"); - context_printf(" add rsp, %zu\n", 8 * array_size(current_env->locals)); - context_printf(" pop rbp\n"); - context_printf(" jmp rdi\n"); context_printf("\n"); // Restore previous compilation context. @@ -560,7 +582,7 @@ compile_lambda(Object *obj) { context_printf(" mov [r15], rax\n"); // Add captured variables to the heap. - for (size_t i = 0; i < array_size(obj->env->captured); i++) { + for (size_t i = 0; i < n_captured; i++) { ssize_t idx = find_var_index(current_env->locals, obj->env->captured[i]); context_printf(" mov rax, rbp\n"); context_printf(" add rax, %ld\n", 8 * idx); @@ -577,7 +599,7 @@ compile_lambda(Object *obj) { context_printf(" push rax\n"); // Adjust the heap pointer depending on the number of variables captured. - context_printf(" add r15, %ld\n", 8 * (array_size(obj->env->captured) + 2)); + context_printf(" add r15, %ld\n", 8 * (n_captured + 1)); context_printf(" ;; <-- compile_lambda\n"); } @@ -610,7 +632,7 @@ compile_symbol(Object *obj) { size_t n_locals = array_size(current_env->locals); size_t n_params = array_size(current_env->params); size_t n_cap = array_size(current_env->captured); - size_t offset = 8 * (n_locals + n_params + n_cap - idx); + size_t offset = 8 * (n_locals + n_params + n_cap - idx - 1); context_printf(" mov rcx, [rbp + %ld]\n", offset); context_printf(" mov rax, [rcx]\n"); context_printf(" push rax\n"); @@ -632,7 +654,7 @@ compile_symbol(Object *obj) { if (idx != -1) { size_t n_locals = array_size(current_env->locals); size_t n_params = array_size(current_env->params); - size_t offset = 8 * (n_locals + n_params - idx); + size_t offset = 8 * (n_locals + n_params - idx - 1); context_printf(" mov rax, [rbp + %ld]\n", offset); context_printf(" push rax\n"); context_printf(" ;; <-- compile_symbol\n"); -- cgit v1.2.1