diff --git a/apps/interpreter/source/main.cpp b/apps/interpreter/source/main.cpp index d78d17e..8406925 100644 --- a/apps/interpreter/source/main.cpp +++ b/apps/interpreter/source/main.cpp @@ -171,7 +171,7 @@ int main(int argc, char ** argv) // TODO: remove, testing-only code; should execute entry point instead auto offset = module.code.symbol_table.at("test"); auto fptr = (int(*)(int))(module.code.memory.data.get() + offset); - auto x = fptr(42); + auto x = fptr(10); std::cout << "Result: " << std::boolalpha << x << std::endl; } } diff --git a/examples/jit_test.psl b/examples/jit_test.psl index 11b367e..08a6fd1 100644 --- a/examples/jit_test.psl +++ b/examples/jit_test.psl @@ -1,2 +1,6 @@ -func test(x : i32) -> bool: - return x <= 42 +func factorial(x : i32) -> i32: + mut r = 1 + while x > 0: + r = r * x + x = x - 1 + return r diff --git a/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp b/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp index 8a58b9c..7bae70d 100644 --- a/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp +++ b/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp @@ -94,15 +94,32 @@ namespace pslang::jit::aarch64 // Take @bit_count lowest bits from @reg_src, copy them to @reg_dst, and zero-extend @reg_dst void ubfm(std::uint8_t reg_src, std::uint8_t reg_dst, std::uint8_t bit_count); + // Perform a bitwise and of @reg_src1 and @reg_src2 and set the flags + void tst(std::uint8_t reg_src1, std::uint8_t reg_src2); + + // If the value of @reg_src is non-zero, move the program counter to the value of + // 19-bit signed @offset multiplied by 4 + void cbnz(std::uint8_t reg_src, std::int32_t offset); + + // If the value of @reg_src is zero, move the program counter to the value of + // 19-bit signed @offset multiplied by 4 + void cbz(std::uint8_t reg_src, std::int32_t offset); + + // Inject the 19-bit signed @offset into the opcode of a cbz or cbnz instruction + // starting at @opcode + void cb_inject(std::uint8_t * opcode, std::int32_t offset); + + // Unconditionally move the program counter to the value of + // 26-bit signed @offset multiplied by 4 + void b(std::int32_t offset); + + // Inject the 26-bit signed @offset into the opcode of b instruction + // starting at @opcode + void b_inject(std::uint8_t * opcode, std::int32_t offset); + // Return from a subroutine, taking the return address from register @reg void ret(std::uint8_t reg = 30); - // Helper function: push a 64-bit register @reg to the stack - void push(std::uint8_t reg); - - // Helper function: pop a 64-bit register @reg from the stack - void pop(std::uint8_t reg); - private: void do_push(std::uint32_t opcode); }; diff --git a/libs/jit/source/arch/aarch64/compiler.cpp b/libs/jit/source/arch/aarch64/compiler.cpp index c145666..09eb3af 100644 --- a/libs/jit/source/arch/aarch64/compiler.cpp +++ b/libs/jit/source/arch/aarch64/compiler.cpp @@ -1,4 +1,3 @@ -#include "pslang/ast/statement_fwd.hpp" #include #include #include @@ -266,6 +265,7 @@ namespace pslang::jit::aarch64 push(0); apply(*node.arg2); pop(1); + builder.cmp_reg(1, 0); if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1))) builder.csetm(0, 0b1000); else @@ -323,6 +323,67 @@ namespace pslang::jit::aarch64 scopes.back().variables[node.name] = {.frame_offset = stack_offset}; } + void apply(ast::if_chain const & node) + { + std::vector branch_to_end; + for (std::size_t i = 0; i < node.blocks.size(); ++i) + { + auto const & block = node.blocks[i]; + + std::optional branch_skip; + if (block.condition) + { + apply(*block.condition); + branch_skip = context.code.size(); + builder.cbz(0, 0); + } + + scopes.emplace_back(); + apply(*block.statements); + scope_cleanup(); + scopes.pop_back(); + + if (i + 1 < node.blocks.size()) + { + branch_to_end.push_back(context.code.size()); + builder.b(0); + } + + if (branch_skip) + { + auto branch_offset = context.code.size() - *branch_skip; + builder.cb_inject(context.code.data() + *branch_skip, branch_offset / 4); + } + } + + auto end = context.code.size(); + for (auto instruction : branch_to_end) + { + auto delta = end - instruction; + builder.b_inject(context.code.data() + instruction, delta / 4); + } + } + + void apply(ast::while_block const & node) + { + std::int32_t start = context.code.size(); + apply(*node.condition); + std::int32_t skip = context.code.size(); + builder.cbz(0, 0); + + scopes.emplace_back(); + apply(*node.statements); + scope_cleanup(); + scopes.pop_back(); + + std::int32_t loop = context.code.size(); + builder.b(0); + std::int32_t end = context.code.size(); + + builder.cb_inject(context.code.data() + skip, (end - skip) / 4); + builder.b_inject(context.code.data() + loop, (start - loop) / 4); + } + void apply(ast::return_statement const & node) { apply(*node.value); @@ -340,7 +401,6 @@ namespace pslang::jit::aarch64 { for (auto const & statement : node.statements) apply(*statement); - builder.add_imm(31, 31, scopes.back().stack_offset); } void do_apply(ast::function_definition const & node) @@ -385,6 +445,12 @@ namespace pslang::jit::aarch64 { reg_extend_visitor{{}, builder, reg}.apply(*type); } + + void scope_cleanup() + { + if (scopes.back().stack_offset > 0) + builder.add_imm(31, 31, scopes.back().stack_offset); + } }; struct compile_visitor diff --git a/libs/jit/source/arch/aarch64/instruction_builder.cpp b/libs/jit/source/arch/aarch64/instruction_builder.cpp index ea7f4b0..e8302ba 100644 --- a/libs/jit/source/arch/aarch64/instruction_builder.cpp +++ b/libs/jit/source/arch/aarch64/instruction_builder.cpp @@ -125,21 +125,45 @@ namespace pslang::jit::aarch64 do_push(0xd3400000u | (reg_dst & REG_MASK) | ((reg_src & REG_MASK) << 5) | (((bit_count - 1) & 0x3fu) << 10)); } + void instruction_builder::tst(std::uint8_t reg_src1, std::uint8_t reg_src2) + { + do_push(0xea00001fu | ((reg_src1 & REG_MASK) << 5) | ((reg_src2 & REG_MASK) << 16)); + } + + void instruction_builder::cbnz(std::uint8_t reg_src, std::int32_t offset) + { + do_push(0xb5000000u | (reg_src & REG_MASK) | ((std::uint32_t(offset) & 0x7ffffu) << 5)); + } + + void instruction_builder::cbz(std::uint8_t reg_src, std::int32_t offset) + { + do_push(0xb4000000u | (reg_src & REG_MASK) | ((std::uint32_t(offset) & 0x7ffffu) << 5)); + } + + void instruction_builder::cb_inject(std::uint8_t * opcode, std::int32_t offset) + { + auto dst = (std::uint32_t *)opcode; + *dst &= ~(0x7ffffu << 5); + *dst |= (std::uint32_t(offset) & 0x7ffffu) << 5; + } + + void instruction_builder::b(std::int32_t offset) + { + do_push(0x14000000u | (std::uint32_t(offset) & 0x3ffffffu)); + } + + void instruction_builder::b_inject(std::uint8_t * opcode, std::int32_t offset) + { + auto dst = (std::uint32_t *)opcode; + *dst &= ~0x3ffffffu; + *dst |= (std::uint32_t(offset) & 0x3ffffffu); + } + void instruction_builder::ret(std::uint8_t reg) { do_push(0xd65f0000u | ((reg & REG_MASK) << 5)); } - void instruction_builder::push(std::uint8_t reg) - { - str_pre(reg, 31, -8); - } - - void instruction_builder::pop(std::uint8_t reg) - { - ldr_post(reg, 31, 8); - } - void instruction_builder::do_push(std::uint32_t opcode) { code.push_back((opcode >> 0) & 0xffu);