diff options
author | Bad Diode <bd@badd10de.dev> | 2021-11-16 22:10:52 +0100 |
---|---|---|
committer | Bad Diode <bd@badd10de.dev> | 2021-11-16 22:10:52 +0100 |
commit | 4344d214a99321f81e2af6d075ef789a6a324c0e (patch) | |
tree | 68815790d06054dcf39a9cd6589a5233a2b1d132 | |
parent | 92f33efa0cb88654e1b182416dbac631abf2be49 (diff) | |
download | bdl-4344d214a99321f81e2af6d075ef789a6a324c0e.tar.gz bdl-4344d214a99321f81e2af6d075ef789a6a324c0e.zip |
Add tail call optimization for function calls
-rw-r--r-- | src/compiler.h | 82 |
1 files changed, 52 insertions, 30 deletions
diff --git a/src/compiler.h b/src/compiler.h index 1ef81e3..e0d071c 100644 --- a/src/compiler.h +++ b/src/compiler.h | |||
@@ -340,15 +340,8 @@ compile_cdr(Object *obj) { | |||
340 | context_printf(" ;; <-- compile_cdr\n"); | 340 | context_printf(" ;; <-- compile_cdr\n"); |
341 | } | 341 | } |
342 | 342 | ||
343 | void | 343 | size_t |
344 | compile_call(Object *obj) { | 344 | compile_call_body(Object *obj) { |
345 | context_printf(" ;; --> compile_call\n"); | ||
346 | |||
347 | // Prepare return pointer. | ||
348 | char *lab_end = generate_label("BDLL"); | ||
349 | context_printf(" lea rcx, [%s]\n", lab_end); | ||
350 | context_printf(" push rcx\n"); | ||
351 | |||
352 | // Compile operator. | 345 | // Compile operator. |
353 | compile_object(obj->head); | 346 | compile_object(obj->head); |
354 | context_printf(" pop rax\n"); | 347 | context_printf(" pop rax\n"); |
@@ -385,13 +378,29 @@ compile_call(Object *obj) { | |||
385 | obj = obj->tail; | 378 | obj = obj->tail; |
386 | compile_object(obj->head); | 379 | compile_object(obj->head); |
387 | } | 380 | } |
388 | |||
389 | context_printf(" mov rax, [rsp + %zu]\n", 8 * offset); | 381 | context_printf(" mov rax, [rsp + %zu]\n", 8 * offset); |
390 | context_printf(" jmp rax\n"); | 382 | context_printf(" jmp rax\n"); |
391 | context_printf("%s:\n", lab_end); | 383 | |
384 | return offset; | ||
385 | } | ||
386 | |||
387 | void | ||
388 | compile_call(Object *obj) { | ||
389 | context_printf(" ;; --> compile_call\n"); | ||
390 | context_printf(" push rbp\n"); | ||
391 | |||
392 | // Prepare return pointer. | ||
393 | char *lab_end = generate_label("BDLL"); | ||
394 | context_printf(" lea rcx, [%s]\n", lab_end); | ||
395 | context_printf(" push rcx\n"); | ||
396 | |||
397 | // Function call compilation without start/end. | ||
398 | size_t offset = compile_call_body(obj); | ||
392 | 399 | ||
393 | // Restore stack to previous location and store the result on top. | 400 | // Restore stack to previous location and store the result on top. |
401 | context_printf("%s:\n", lab_end); | ||
394 | context_printf(" add rsp, %zu\n", 8 * (offset + 2)); | 402 | context_printf(" add rsp, %zu\n", 8 * (offset + 2)); |
403 | context_printf(" pop rbp\n"); | ||
395 | context_printf(" push rax\n"); | 404 | context_printf(" push rax\n"); |
396 | context_printf(" ;; <-- compile_call\n"); | 405 | context_printf(" ;; <-- compile_call\n"); |
397 | } | 406 | } |
@@ -515,9 +524,14 @@ compile_lambda(Object *obj) { | |||
515 | context_printf("alignb 8\n"); | 524 | context_printf("alignb 8\n"); |
516 | context_printf("%s:\n", name); | 525 | context_printf("%s:\n", name); |
517 | 526 | ||
527 | // Prepare size vars. | ||
528 | size_t n_locals = array_size(current_env->locals); | ||
529 | size_t n_params = array_size(current_env->params); | ||
530 | size_t n_captured = array_size(current_env->captured); | ||
531 | size_t offset = 8 * (n_locals + n_params + n_captured + 1); | ||
532 | |||
518 | // Initialize function call frame. | 533 | // Initialize function call frame. |
519 | context_printf(" push rbp\n"); | 534 | context_printf(" sub rsp, %zu\n", 8 * n_locals); |
520 | context_printf(" sub rsp, %zu\n", 8 * array_size(current_env->locals)); | ||
521 | context_printf(" mov rbp, rsp\n"); | 535 | context_printf(" mov rbp, rsp\n"); |
522 | 536 | ||
523 | // Procedure body. | 537 | // Procedure body. |
@@ -533,21 +547,29 @@ compile_lambda(Object *obj) { | |||
533 | compile_object(obj->body[i]); | 547 | compile_object(obj->body[i]); |
534 | } | 548 | } |
535 | compile_nil(); | 549 | compile_nil(); |
536 | compile_object(obj->body[array_size(obj->body) - 1]); | 550 | Object *last_expr = obj->body[array_size(obj->body) - 1]; |
537 | 551 | ||
538 | // Return is stored in the `rax`. | 552 | // Tail Call Optimization. |
539 | context_printf(" pop rax\n"); | 553 | // TODO: also for if statements |
554 | if (IS_PAIR(last_expr)) { | ||
555 | // Discard the previous stack frame. | ||
556 | context_printf(" mov rsp, rbp\n"); | ||
557 | context_printf(" add rsp, %zu\n", offset); | ||
558 | |||
559 | compile_call_body(last_expr); | ||
560 | } else { | ||
561 | compile_object(last_expr); | ||
562 | |||
563 | // Return is stored in the `rax`. | ||
564 | context_printf(" pop rax\n"); | ||
565 | |||
566 | // Restore the previous call frame. | ||
567 | context_printf(" mov rdi, [rbp + %zu]\n", offset); | ||
568 | context_printf(" mov rsp, rbp\n"); | ||
569 | context_printf(" add rsp, %zu\n", 8 * n_locals); | ||
570 | context_printf(" jmp rdi\n"); | ||
571 | } | ||
540 | 572 | ||
541 | // Restore the previous call frame. | ||
542 | size_t n_locals = array_size(current_env->locals); | ||
543 | size_t n_params = array_size(current_env->params); | ||
544 | size_t n_captured = array_size(current_env->captured); | ||
545 | size_t offset = 8 * (n_locals + n_params + n_captured + 2); | ||
546 | context_printf(" mov rdi, [rbp + %zu]\n", offset); | ||
547 | context_printf(" mov rsp, rbp\n"); | ||
548 | context_printf(" add rsp, %zu\n", 8 * array_size(current_env->locals)); | ||
549 | context_printf(" pop rbp\n"); | ||
550 | context_printf(" jmp rdi\n"); | ||
551 | context_printf("\n"); | 573 | context_printf("\n"); |
552 | 574 | ||
553 | // Restore previous compilation context. | 575 | // Restore previous compilation context. |
@@ -560,7 +582,7 @@ compile_lambda(Object *obj) { | |||
560 | context_printf(" mov [r15], rax\n"); | 582 | context_printf(" mov [r15], rax\n"); |
561 | 583 | ||
562 | // Add captured variables to the heap. | 584 | // Add captured variables to the heap. |
563 | for (size_t i = 0; i < array_size(obj->env->captured); i++) { | 585 | for (size_t i = 0; i < n_captured; i++) { |
564 | ssize_t idx = find_var_index(current_env->locals, obj->env->captured[i]); | 586 | ssize_t idx = find_var_index(current_env->locals, obj->env->captured[i]); |
565 | context_printf(" mov rax, rbp\n"); | 587 | context_printf(" mov rax, rbp\n"); |
566 | context_printf(" add rax, %ld\n", 8 * idx); | 588 | context_printf(" add rax, %ld\n", 8 * idx); |
@@ -577,7 +599,7 @@ compile_lambda(Object *obj) { | |||
577 | context_printf(" push rax\n"); | 599 | context_printf(" push rax\n"); |
578 | 600 | ||
579 | // Adjust the heap pointer depending on the number of variables captured. | 601 | // Adjust the heap pointer depending on the number of variables captured. |
580 | context_printf(" add r15, %ld\n", 8 * (array_size(obj->env->captured) + 2)); | 602 | context_printf(" add r15, %ld\n", 8 * (n_captured + 1)); |
581 | 603 | ||
582 | context_printf(" ;; <-- compile_lambda\n"); | 604 | context_printf(" ;; <-- compile_lambda\n"); |
583 | } | 605 | } |
@@ -610,7 +632,7 @@ compile_symbol(Object *obj) { | |||
610 | size_t n_locals = array_size(current_env->locals); | 632 | size_t n_locals = array_size(current_env->locals); |
611 | size_t n_params = array_size(current_env->params); | 633 | size_t n_params = array_size(current_env->params); |
612 | size_t n_cap = array_size(current_env->captured); | 634 | size_t n_cap = array_size(current_env->captured); |
613 | size_t offset = 8 * (n_locals + n_params + n_cap - idx); | 635 | size_t offset = 8 * (n_locals + n_params + n_cap - idx - 1); |
614 | context_printf(" mov rcx, [rbp + %ld]\n", offset); | 636 | context_printf(" mov rcx, [rbp + %ld]\n", offset); |
615 | context_printf(" mov rax, [rcx]\n"); | 637 | context_printf(" mov rax, [rcx]\n"); |
616 | context_printf(" push rax\n"); | 638 | context_printf(" push rax\n"); |
@@ -632,7 +654,7 @@ compile_symbol(Object *obj) { | |||
632 | if (idx != -1) { | 654 | if (idx != -1) { |
633 | size_t n_locals = array_size(current_env->locals); | 655 | size_t n_locals = array_size(current_env->locals); |
634 | size_t n_params = array_size(current_env->params); | 656 | size_t n_params = array_size(current_env->params); |
635 | size_t offset = 8 * (n_locals + n_params - idx); | 657 | size_t offset = 8 * (n_locals + n_params - idx - 1); |
636 | context_printf(" mov rax, [rbp + %ld]\n", offset); | 658 | context_printf(" mov rax, [rbp + %ld]\n", offset); |
637 | context_printf(" push rax\n"); | 659 | context_printf(" push rax\n"); |
638 | context_printf(" ;; <-- compile_symbol\n"); | 660 | context_printf(" ;; <-- compile_symbol\n"); |