aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorBad Diode <bd@badd10de.dev>2021-11-16 22:10:52 +0100
committerBad Diode <bd@badd10de.dev>2021-11-16 22:10:52 +0100
commit4344d214a99321f81e2af6d075ef789a6a324c0e (patch)
tree68815790d06054dcf39a9cd6589a5233a2b1d132
parent92f33efa0cb88654e1b182416dbac631abf2be49 (diff)
downloadbdl-4344d214a99321f81e2af6d075ef789a6a324c0e.tar.gz
bdl-4344d214a99321f81e2af6d075ef789a6a324c0e.zip
Add tail call optimization for function calls
-rw-r--r--src/compiler.h82
1 files changed, 52 insertions, 30 deletions
diff --git a/src/compiler.h b/src/compiler.h
index 1ef81e3..e0d071c 100644
--- a/src/compiler.h
+++ b/src/compiler.h
@@ -340,15 +340,8 @@ compile_cdr(Object *obj) {
340 context_printf(" ;; <-- compile_cdr\n"); 340 context_printf(" ;; <-- compile_cdr\n");
341} 341}
342 342
343void 343size_t
344compile_call(Object *obj) { 344compile_call_body(Object *obj) {
345 context_printf(" ;; --> compile_call\n");
346
347 // Prepare return pointer.
348 char *lab_end = generate_label("BDLL");
349 context_printf(" lea rcx, [%s]\n", lab_end);
350 context_printf(" push rcx\n");
351
352 // Compile operator. 345 // Compile operator.
353 compile_object(obj->head); 346 compile_object(obj->head);
354 context_printf(" pop rax\n"); 347 context_printf(" pop rax\n");
@@ -385,13 +378,29 @@ compile_call(Object *obj) {
385 obj = obj->tail; 378 obj = obj->tail;
386 compile_object(obj->head); 379 compile_object(obj->head);
387 } 380 }
388
389 context_printf(" mov rax, [rsp + %zu]\n", 8 * offset); 381 context_printf(" mov rax, [rsp + %zu]\n", 8 * offset);
390 context_printf(" jmp rax\n"); 382 context_printf(" jmp rax\n");
391 context_printf("%s:\n", lab_end); 383
384 return offset;
385}
386
387void
388compile_call(Object *obj) {
389 context_printf(" ;; --> compile_call\n");
390 context_printf(" push rbp\n");
391
392 // Prepare return pointer.
393 char *lab_end = generate_label("BDLL");
394 context_printf(" lea rcx, [%s]\n", lab_end);
395 context_printf(" push rcx\n");
396
397 // Function call compilation without start/end.
398 size_t offset = compile_call_body(obj);
392 399
393 // Restore stack to previous location and store the result on top. 400 // Restore stack to previous location and store the result on top.
401 context_printf("%s:\n", lab_end);
394 context_printf(" add rsp, %zu\n", 8 * (offset + 2)); 402 context_printf(" add rsp, %zu\n", 8 * (offset + 2));
403 context_printf(" pop rbp\n");
395 context_printf(" push rax\n"); 404 context_printf(" push rax\n");
396 context_printf(" ;; <-- compile_call\n"); 405 context_printf(" ;; <-- compile_call\n");
397} 406}
@@ -515,9 +524,14 @@ compile_lambda(Object *obj) {
515 context_printf("alignb 8\n"); 524 context_printf("alignb 8\n");
516 context_printf("%s:\n", name); 525 context_printf("%s:\n", name);
517 526
527 // Prepare size vars.
528 size_t n_locals = array_size(current_env->locals);
529 size_t n_params = array_size(current_env->params);
530 size_t n_captured = array_size(current_env->captured);
531 size_t offset = 8 * (n_locals + n_params + n_captured + 1);
532
518 // Initialize function call frame. 533 // Initialize function call frame.
519 context_printf(" push rbp\n"); 534 context_printf(" sub rsp, %zu\n", 8 * n_locals);
520 context_printf(" sub rsp, %zu\n", 8 * array_size(current_env->locals));
521 context_printf(" mov rbp, rsp\n"); 535 context_printf(" mov rbp, rsp\n");
522 536
523 // Procedure body. 537 // Procedure body.
@@ -533,21 +547,29 @@ compile_lambda(Object *obj) {
533 compile_object(obj->body[i]); 547 compile_object(obj->body[i]);
534 } 548 }
535 compile_nil(); 549 compile_nil();
536 compile_object(obj->body[array_size(obj->body) - 1]); 550 Object *last_expr = obj->body[array_size(obj->body) - 1];
537 551
538 // Return is stored in the `rax`. 552 // Tail Call Optimization.
539 context_printf(" pop rax\n"); 553 // TODO: also for if statements
554 if (IS_PAIR(last_expr)) {
555 // Discard the previous stack frame.
556 context_printf(" mov rsp, rbp\n");
557 context_printf(" add rsp, %zu\n", offset);
558
559 compile_call_body(last_expr);
560 } else {
561 compile_object(last_expr);
562
563 // Return is stored in the `rax`.
564 context_printf(" pop rax\n");
565
566 // Restore the previous call frame.
567 context_printf(" mov rdi, [rbp + %zu]\n", offset);
568 context_printf(" mov rsp, rbp\n");
569 context_printf(" add rsp, %zu\n", 8 * n_locals);
570 context_printf(" jmp rdi\n");
571 }
540 572
541 // Restore the previous call frame.
542 size_t n_locals = array_size(current_env->locals);
543 size_t n_params = array_size(current_env->params);
544 size_t n_captured = array_size(current_env->captured);
545 size_t offset = 8 * (n_locals + n_params + n_captured + 2);
546 context_printf(" mov rdi, [rbp + %zu]\n", offset);
547 context_printf(" mov rsp, rbp\n");
548 context_printf(" add rsp, %zu\n", 8 * array_size(current_env->locals));
549 context_printf(" pop rbp\n");
550 context_printf(" jmp rdi\n");
551 context_printf("\n"); 573 context_printf("\n");
552 574
553 // Restore previous compilation context. 575 // Restore previous compilation context.
@@ -560,7 +582,7 @@ compile_lambda(Object *obj) {
560 context_printf(" mov [r15], rax\n"); 582 context_printf(" mov [r15], rax\n");
561 583
562 // Add captured variables to the heap. 584 // Add captured variables to the heap.
563 for (size_t i = 0; i < array_size(obj->env->captured); i++) { 585 for (size_t i = 0; i < n_captured; i++) {
564 ssize_t idx = find_var_index(current_env->locals, obj->env->captured[i]); 586 ssize_t idx = find_var_index(current_env->locals, obj->env->captured[i]);
565 context_printf(" mov rax, rbp\n"); 587 context_printf(" mov rax, rbp\n");
566 context_printf(" add rax, %ld\n", 8 * idx); 588 context_printf(" add rax, %ld\n", 8 * idx);
@@ -577,7 +599,7 @@ compile_lambda(Object *obj) {
577 context_printf(" push rax\n"); 599 context_printf(" push rax\n");
578 600
579 // Adjust the heap pointer depending on the number of variables captured. 601 // Adjust the heap pointer depending on the number of variables captured.
580 context_printf(" add r15, %ld\n", 8 * (array_size(obj->env->captured) + 2)); 602 context_printf(" add r15, %ld\n", 8 * (n_captured + 1));
581 603
582 context_printf(" ;; <-- compile_lambda\n"); 604 context_printf(" ;; <-- compile_lambda\n");
583} 605}
@@ -610,7 +632,7 @@ compile_symbol(Object *obj) {
610 size_t n_locals = array_size(current_env->locals); 632 size_t n_locals = array_size(current_env->locals);
611 size_t n_params = array_size(current_env->params); 633 size_t n_params = array_size(current_env->params);
612 size_t n_cap = array_size(current_env->captured); 634 size_t n_cap = array_size(current_env->captured);
613 size_t offset = 8 * (n_locals + n_params + n_cap - idx); 635 size_t offset = 8 * (n_locals + n_params + n_cap - idx - 1);
614 context_printf(" mov rcx, [rbp + %ld]\n", offset); 636 context_printf(" mov rcx, [rbp + %ld]\n", offset);
615 context_printf(" mov rax, [rcx]\n"); 637 context_printf(" mov rax, [rcx]\n");
616 context_printf(" push rax\n"); 638 context_printf(" push rax\n");
@@ -632,7 +654,7 @@ compile_symbol(Object *obj) {
632 if (idx != -1) { 654 if (idx != -1) {
633 size_t n_locals = array_size(current_env->locals); 655 size_t n_locals = array_size(current_env->locals);
634 size_t n_params = array_size(current_env->params); 656 size_t n_params = array_size(current_env->params);
635 size_t offset = 8 * (n_locals + n_params - idx); 657 size_t offset = 8 * (n_locals + n_params - idx - 1);
636 context_printf(" mov rax, [rbp + %ld]\n", offset); 658 context_printf(" mov rax, [rbp + %ld]\n", offset);
637 context_printf(" push rax\n"); 659 context_printf(" push rax\n");
638 context_printf(" ;; <-- compile_symbol\n"); 660 context_printf(" ;; <-- compile_symbol\n");