From 082764ad89c1b9613b6894d6593dee66041a5e54 Mon Sep 17 00:00:00 2001 From: Bad Diode Date: Thu, 30 Dec 2021 16:07:41 +0100 Subject: Add WIP compilation of lambdas --- src/ir.h | 173 ++++++++++++++++++++++++++++++++++++++++++--------------------- 1 file changed, 116 insertions(+), 57 deletions(-) diff --git a/src/ir.h b/src/ir.h index d9706ee..d92f912 100644 --- a/src/ir.h +++ b/src/ir.h @@ -124,15 +124,13 @@ typedef struct Procedure { // Procedure name. char *name; + struct Procedure *parent; + // Program code. Instruction *instructions; // Locals code. Object **locals; - - // Number of locals and parameters. - size_t n_params; - size_t n_locals; } Procedure; typedef struct ProgramIr { @@ -205,12 +203,13 @@ print_procedure(Procedure *proc) { } Procedure * -proc_alloc(ProgramIr *program, StringView name) { +proc_alloc(ProgramIr *program, StringView name, Procedure *parent) { Procedure *proc = calloc(1, sizeof(Procedure)); array_init(proc->name, name.n); array_insert(proc->name, name.start, name.n); array_init(proc->instructions, 0); array_init(proc->locals, 0); + proc->parent = parent; array_push(program->procedures, proc); return proc; } @@ -304,58 +303,63 @@ compile_or(ProgramIr *program, Procedure *proc, } void -compile_proc_call(ProgramIr *program, Procedure *proc, Object *obj) { +compile_builtin(ProgramIr *program, Procedure *proc, Object *obj) { size_t line = obj->line; size_t col = obj->col; - if (obj->head->type == OBJ_TYPE_BUILTIN) { - switch (obj->head->builtin) { - case BUILTIN_ADD: { - compile_arithmetic(program, proc, OP_ADD, line, col, obj->tail); - } break; - case BUILTIN_SUB: { - compile_arithmetic(program, proc, OP_SUB, line, col, obj->tail); - } break; - case BUILTIN_MUL: { - compile_arithmetic(program, proc, OP_MUL, line, col, obj->tail); - } break; - case BUILTIN_DIV: { - compile_arithmetic(program, proc, OP_DIV, line, col, obj->tail); - } break; - case BUILTIN_MOD: { - compile_arithmetic(program, proc, OP_MOD, line, col, obj->tail); - } break; - case BUILTIN_PRINT: { - compile_print(program, proc, line, col, obj->tail); - } break; - case BUILTIN_NOT: { - compile_not(program, proc, line, col, obj->tail); - } break; - case BUILTIN_AND: { - compile_and(program, proc, line, col, obj->tail); - } break; - case BUILTIN_OR: { - compile_or(program, proc, line, col, obj->tail); - } break; - case BUILTIN_EQ: { - compile_numeric_cmp(program, proc, OP_JUMP_IF_NEQ, line, col, obj->tail); - } break; - case BUILTIN_GT: { - compile_numeric_cmp(program, proc, OP_JUMP_IF_LE, line, col, obj->tail); - } break; - case BUILTIN_LT: { - compile_numeric_cmp(program, proc, OP_JUMP_IF_GE, line, col, obj->tail); - } break; - case BUILTIN_GE: { - compile_numeric_cmp(program, proc, OP_JUMP_IF_LT, line, col, obj->tail); - } break; - case BUILTIN_LE: { - compile_numeric_cmp(program, proc, OP_JUMP_IF_GT, line, col, obj->tail); - } break; - // TODO: cons, car, cdr, type checks (nil? zero? fixnum? bool? ...) - default: { - assert(false && "builtin not implemented"); - } break; - } + switch (obj->head->builtin) { + case BUILTIN_ADD: { + compile_arithmetic(program, proc, OP_ADD, line, col, obj->tail); + } break; + case BUILTIN_SUB: { + compile_arithmetic(program, proc, OP_SUB, line, col, obj->tail); + } break; + case BUILTIN_MUL: { + compile_arithmetic(program, proc, OP_MUL, line, col, obj->tail); + } break; + case BUILTIN_DIV: { + compile_arithmetic(program, proc, OP_DIV, line, col, obj->tail); + } break; + case BUILTIN_MOD: { + compile_arithmetic(program, proc, OP_MOD, line, col, obj->tail); + } break; + case BUILTIN_PRINT: { + compile_print(program, proc, line, col, obj->tail); + } break; + case BUILTIN_NOT: { + compile_not(program, proc, line, col, obj->tail); + } break; + case BUILTIN_AND: { + compile_and(program, proc, line, col, obj->tail); + } break; + case BUILTIN_OR: { + compile_or(program, proc, line, col, obj->tail); + } break; + case BUILTIN_EQ: { + compile_numeric_cmp(program, proc, OP_JUMP_IF_NEQ, line, col, obj->tail); + } break; + case BUILTIN_GT: { + compile_numeric_cmp(program, proc, OP_JUMP_IF_LE, line, col, obj->tail); + } break; + case BUILTIN_LT: { + compile_numeric_cmp(program, proc, OP_JUMP_IF_GE, line, col, obj->tail); + } break; + case BUILTIN_GE: { + compile_numeric_cmp(program, proc, OP_JUMP_IF_LT, line, col, obj->tail); + } break; + case BUILTIN_LE: { + compile_numeric_cmp(program, proc, OP_JUMP_IF_GT, line, col, obj->tail); + } break; + // TODO: cons, car, cdr, type checks (nil? zero? fixnum? bool? ...) + default: { + assert(false && "builtin not implemented"); + } break; + } +} + +void +compile_proc_call(ProgramIr *program, Procedure *proc, Object *obj) { + if (IS_BUILTIN(obj->head)) { + compile_builtin(program, proc, obj); } else { assert(false && "compile_proc_call: not implemented"); } @@ -389,6 +393,61 @@ compile_def(ProgramIr *program, Procedure *proc, Object *obj) { INST_VAR(proc, OP_STORE_LOCAL, idx, obj->line, obj->col); } +void +compile_lambda(ProgramIr *program, Procedure *proc, Object *obj) { + Procedure *lambda = proc_alloc(program, STRING("lambda"), proc); + for (size_t i = 0; i < array_size(obj->body) - 1; i++) { + compile_object(program, lambda, obj->body[i]); + } + Object *last_expr = obj->body[array_size(obj->body) - 1]; + + // Tail Call Optimization. + // TODO: also for if statements + if (IS_PAIR(last_expr)) { + if (IS_BUILTIN(last_expr->head)) { + compile_builtin(program, lambda, last_expr); + } else { + // Discard the previous stack frame. + // context_printf(" mov rsp, rbp\n"); + + // size_t old_offset = n_locals + n_captured + n_params; + // size_t new_offset = compile_call_body(last_expr); + // context_printf(" mov rdi, [rbp - 8]\n"); + // for (size_t i = 0; i < new_offset + 1; i++) { + // context_printf(" mov rax, [rbp - 8 * %zu]\n", i + 1); + // context_printf(" mov [rbp + 8 * %zu], rax\n", old_offset - i); + // } + + // // Set the stack pointer at the end of given parameters. + // context_printf(" mov rsp, rbp\n"); + // ssize_t offset_diff = old_offset - new_offset; + // if (offset_diff > 0) { + // context_printf(" add rsp, 8 * %zu\n", offset_diff); + // } else { + // context_printf(" sub rsp, 8 * %zu\n", offset_diff); + // } + + // context_printf(" jmp rdi\n"); + } + } else { + // compile_nil(); + compile_object(program, lambda, last_expr); + + // // Return is stored in the `rax`. + // context_printf(" pop rax\n"); + + // // Restore the previous call frame. + // size_t rp_offset = (n_locals + n_params + n_captured + 1); + // context_printf(" mov rdi, [rbp + %zu]\n", 8 * rp_offset); + // context_printf(" mov rsp, rbp\n"); + // context_printf(" add rsp, %zu\n", 8 * n_locals); + // context_printf(" jmp rdi\n"); + } + INST_SIMPLE(lambda, OP_RETURN, obj->line, obj->col); + + INST_ARG(proc, OP_PUSH, obj, obj->line, obj->col); +} + void compile_object(ProgramIr *program, Procedure *proc, Object *obj) { switch (obj->type) { @@ -399,7 +458,7 @@ compile_object(ProgramIr *program, Procedure *proc, Object *obj) { case OBJ_TYPE_FIXNUM: { INST_ARG(proc, OP_PUSH, obj, obj->line, obj->col); } break; case OBJ_TYPE_PAIR: { compile_proc_call(program, proc, obj); } break; case OBJ_TYPE_IF: { compile_if(program, proc, obj); } break; - // case OBJ_TYPE_LAMBDA: { compile_lambda(obj); } break; + case OBJ_TYPE_LAMBDA: { compile_lambda(program, proc, obj); } break; case OBJ_TYPE_DEF: { compile_def(program, proc, obj); } break; // case OBJ_TYPE_SYMBOL: { compile_symbol(obj); } break; default: { @@ -415,7 +474,7 @@ ProgramIr compile(Program program) { ProgramIr program_ir = {0}; array_init(program_ir.procedures, 0); - Procedure *main = proc_alloc(&program_ir, STRING("main")); + Procedure *main = proc_alloc(&program_ir, STRING("main"), NULL); for (size_t i = 0; i < array_size(program.roots); i++) { Object *root = program.roots[i]; -- cgit v1.2.1