From d5065ec38ee4f008350ed52be9f530540617b37c Mon Sep 17 00:00:00 2001 From: lisyarus Date: Sat, 3 Jan 2026 22:38:43 +0300 Subject: [PATCH] Aarch64 JIT-compiling wip: no-argument functions, arithmetic & logical operations, comparisons (integers only) --- apps/interpreter/CMakeLists.txt | 6 +- apps/interpreter/source/main.cpp | 30 +- examples/jit_test.psl | 2 + libs/jit/include/pslang/jit/abi.hpp | 2 + .../pslang/jit/arch/aarch64/compiler.hpp | 10 + .../jit/arch/aarch64/instruction_builder.hpp | 107 +++++ .../include/pslang/jit/compiled_module.hpp | 3 + libs/jit/include/pslang/jit/executable.hpp | 2 + libs/jit/source/abi.cpp | 19 + libs/jit/source/arch/aarch64/compiler.cpp | 367 ++++++++++++++++++ .../arch/aarch64/instruction_builder.cpp | 146 +++++++ libs/jit/source/executable.cpp | 16 +- libs/jit/source/jit.cpp | 20 +- 13 files changed, 716 insertions(+), 14 deletions(-) create mode 100644 examples/jit_test.psl create mode 100644 libs/jit/include/pslang/jit/arch/aarch64/compiler.hpp create mode 100644 libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp create mode 100644 libs/jit/source/abi.cpp create mode 100644 libs/jit/source/arch/aarch64/compiler.cpp create mode 100644 libs/jit/source/arch/aarch64/instruction_builder.cpp diff --git a/apps/interpreter/CMakeLists.txt b/apps/interpreter/CMakeLists.txt index a86a52d..5372255 100644 --- a/apps/interpreter/CMakeLists.txt +++ b/apps/interpreter/CMakeLists.txt @@ -3,4 +3,8 @@ file(GLOB_RECURSE PSLI_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/source/*.cpp") add_executable(psli ${PSLI_HEADERS} ${PSLI_SOURCES}) target_include_directories(psli PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") -target_link_libraries(psli PUBLIC pslang-parser pslang-interpreter) +target_link_libraries(psli PUBLIC + pslang-parser + pslang-interpreter + pslang-jit +) diff --git a/apps/interpreter/source/main.cpp b/apps/interpreter/source/main.cpp index 1920f2e..6da0da2 100644 --- a/apps/interpreter/source/main.cpp +++ 
b/apps/interpreter/source/main.cpp @@ -5,6 +5,8 @@ #include #include #include +#include +#include #include #include @@ -102,7 +104,6 @@ int main(int argc, char ** argv) if (std::strcmp(argv[arg], "-j") == 0 || std::strcmp(argv[arg], "--jit") == 0) { - std::cerr << "Warning: JIT-compilation not supported yet" << std::endl; jit = true; continue; } @@ -158,9 +159,28 @@ int main(int argc, char ** argv) std::cout << std::flush; } - for (auto const & ast : parsed) - interpreter::exec(context, ast); + if (jit) + { + auto abi = jit::host_abi(); + std::vector modules; + for (auto const & ast : parsed) + modules.push_back(jit::make_host_executable(jit::compile(ast, abi))); - if (dump) - interpreter::dump(std::cout, context); + for (auto const & module : modules) + { + // TODO: remove, testing-only code; should execute entry point instead + auto offset = module.code.symbol_table.at("test"); + auto fptr = (bool(*)())(module.code.memory.data.get() + offset); + auto x = fptr(); + std::cout << "Result: " << std::boolalpha << x << std::endl; + } + } + else + { + for (auto const & ast : parsed) + interpreter::exec(context, ast); + + if (dump) + interpreter::dump(std::cout, context); + } } diff --git a/examples/jit_test.psl b/examples/jit_test.psl new file mode 100644 index 0000000..00e3c4d --- /dev/null +++ b/examples/jit_test.psl @@ -0,0 +1,2 @@ +func test() -> i32: + return -10*9 diff --git a/libs/jit/include/pslang/jit/abi.hpp b/libs/jit/include/pslang/jit/abi.hpp index c55c3b6..1956c6e 100644 --- a/libs/jit/include/pslang/jit/abi.hpp +++ b/libs/jit/include/pslang/jit/abi.hpp @@ -10,4 +10,6 @@ namespace pslang::jit armv8, }; + abi host_abi(); + } diff --git a/libs/jit/include/pslang/jit/arch/aarch64/compiler.hpp b/libs/jit/include/pslang/jit/arch/aarch64/compiler.hpp new file mode 100644 index 0000000..0c75f35 --- /dev/null +++ b/libs/jit/include/pslang/jit/arch/aarch64/compiler.hpp @@ -0,0 +1,10 @@ +#pragma once + +#include + +namespace pslang::jit::aarch64 +{ + + 
compiled_module compile(ast::statement_list_ptr const & statements); + +} \ No newline at end of file diff --git a/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp b/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp new file mode 100644 index 0000000..7be2cbf --- /dev/null +++ b/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp @@ -0,0 +1,107 @@ +#pragma once + +#include +#include + +namespace pslang::jit::aarch64 +{ + + struct instruction_builder + { + std::vector & code; + + // NB: stack pointer is register 31 + + // Move @val shifted by 16*@shift bits into register @reg, zeroing out other bits + // @shift must be 0, 1, 2 or 3 + void movz(std::uint8_t reg, std::uint16_t val, std::uint8_t shift = 0); + + // Move @val shifted by 16*@shift bits into register @reg, keeping other bits intact + // @shift must be 0, 1, 2 or 3 + void movk(std::uint8_t reg, std::uint16_t val, std::uint8_t shift = 0); + + // Store the value of the register @reg_src at the address specified by the value of + // register @reg_addr plus an unsigned 12-bit offset multiplied by 8. + void str(std::uint8_t reg_src, std::uint8_t reg_addr, std::uint16_t offset); + + // Store the value of register @reg_src into an address specified by the value of register @reg_addr + // plus a signed 9-bit offset. Store the new address value in @reg_addr + void str_pre(std::uint8_t reg_src, std::uint8_t reg_addr, std::int16_t offset); + + // Load the value at address specified by the value of register @reg_addr + // plus an unsigned 12-bit offset multiplied by 8 and store it in the register @reg_dst. + void ldr(std::uint8_t reg_dst, std::uint8_t reg_addr, std::uint16_t offset); + + // Load the value at address specified by the value of register @reg_addr + // plus a signed 9-bit offset and store it in the register @reg_dst. 
+ // Store the new address value in @reg_addr + void ldr_pre(std::uint8_t reg_dst, std::uint8_t reg_addr, std::int16_t offset); + + // Load the value at address specified by the value of register @reg_addr + // and store it in the register @reg_dst. + // Store the address plus a signed 9-bit offset in @reg_addr + void ldr_post(std::uint8_t reg_dst, std::uint8_t reg_addr, std::int16_t offset); + + // Add a 12-bit @value to the register @reg_src and store the result in @reg_dst + void add_imm(std::uint8_t reg_src, std::uint8_t reg_dst, std::uint16_t value); + + // Subtract a 12-bit @value from the register @reg_src and store the result in @reg_dst + void sub_imm(std::uint8_t reg_src, std::uint8_t reg_dst, std::uint16_t value); + + // Compute the value of (@reg_src1 + @reg_src2) and store the result in @reg_dst + void add_reg(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t reg_dst); + + // Compute the value of (@reg_src1 - @reg_src2) and store the result in @reg_dst + void sub_reg(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t reg_dst); + + // Compute the value of (@reg_src1 and @reg_src2) and store the result in @reg_dst + void and_reg(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t reg_dst); + + // Compute the value of (@reg_src1 or @reg_src2) and store the result in @reg_dst + void or_reg(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t reg_dst); + + // Compute the value of (@reg_src1 xor @reg_src2) and store the result in @reg_dst + void xor_reg(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t reg_dst); + + // Compute the value of (@reg_src1 or not @reg_src2) and store the result in @reg_dst + void or_not_reg(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t reg_dst); + + // Compute the value of (@reg_src1 * @reg_src2) and store the result in @reg_dst + void mul_reg(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t reg_dst); + + // Compute the value of signed division (@reg_src1 / @reg_src2) and 
store the result in @reg_dst + void sdiv_reg(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t reg_dst); + + // Compute the value of unsigned division (@reg_src1 / @reg_src2) and store the result in @reg_dst + void udiv_reg(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t reg_dst); + + // Compare the values of @reg_src1 and @reg_src2 and set the flags + void cmp_reg(std::uint8_t reg_src1, std::uint8_t reg_src2); + + // Set the value of @reg_dst to 1 if condition @cond is true + void cset(std::uint8_t reg_dst, std::uint8_t cond); + + // Set the value of @reg_dst to @reg_src1 if condition @cond is true, otherwise + // to @reg_src2. + void csel(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t reg_dst, std::uint8_t cond); + + // Take @bit_count lowest bits from @reg_src, copy them to @reg_dst, and sign-extend @reg_dst + void sbfm(std::uint8_t reg_src, std::uint8_t reg_dst, std::uint8_t bit_count); + + // Take @bit_count lowest bits from @reg_src, copy them to @reg_dst, and zero-extend @reg_dst + void ubfm(std::uint8_t reg_src, std::uint8_t reg_dst, std::uint8_t bit_count); + + // Return from a subroutine, taking the return address from register @reg + void ret(std::uint8_t reg = 30); + + // Helper function: push a 64-bit register @reg to the stack + void push(std::uint8_t reg); + + // Helper function: pop a 64-bit register @reg from the stack + void pop(std::uint8_t reg); + + private: + void do_push(std::uint32_t opcode); + }; + +} \ No newline at end of file diff --git a/libs/jit/include/pslang/jit/compiled_module.hpp b/libs/jit/include/pslang/jit/compiled_module.hpp index 6f8ab6c..db76ec2 100644 --- a/libs/jit/include/pslang/jit/compiled_module.hpp +++ b/libs/jit/include/pslang/jit/compiled_module.hpp @@ -17,6 +17,9 @@ namespace pslang::jit std::unordered_map symbol_table; }; + // TODO: How do we reference data segment from the code segment? + // Maybe use the same relocation/dynamic linking mechanism + // as for imported modules? 
segment data; segment code; std::size_t entry_point; // in code segment diff --git a/libs/jit/include/pslang/jit/executable.hpp b/libs/jit/include/pslang/jit/executable.hpp index 4f0d372..907a495 100644 --- a/libs/jit/include/pslang/jit/executable.hpp +++ b/libs/jit/include/pslang/jit/executable.hpp @@ -5,6 +5,8 @@ namespace pslang::jit { + blob allocate(std::size_t size); + compiled_module make_host_executable(compiled_module module); } diff --git a/libs/jit/source/abi.cpp b/libs/jit/source/abi.cpp new file mode 100644 index 0000000..2a38f11 --- /dev/null +++ b/libs/jit/source/abi.cpp @@ -0,0 +1,19 @@ +#include + +namespace pslang::jit +{ + + abi host_abi() + { +#if defined(__linux__) and defined(__x86_64__) + return abi::itanium; +#elif defined(__APPLE__) and defined(__ARM64_ARCH_8__) + return abi::armv8; +#elif defined(_WIN64) and defined(_M_AMD64) + return abi::msvs; +#else + #error Unknown host ABI +#endif + } + +} \ No newline at end of file diff --git a/libs/jit/source/arch/aarch64/compiler.cpp b/libs/jit/source/arch/aarch64/compiler.cpp new file mode 100644 index 0000000..eeb549c --- /dev/null +++ b/libs/jit/source/arch/aarch64/compiler.cpp @@ -0,0 +1,367 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace pslang::jit::aarch64 +{ + + namespace + { + + struct context + { + std::vector code; + std::unordered_map code_symbol_table; + }; + + struct reg_extend_visitor + : types::const_visitor + { + using const_visitor::apply; + + instruction_builder & builder; + std::uint8_t reg; + + void apply(types::bool_type const &) + {} + + void apply(types::f32_type const &) + {} + + void apply(types::f64_type const &) + {} + + template + void apply(types::primitive_type_base const &) + { + if constexpr (sizeof(T) == 8) + { + return; + } + + if constexpr (std::is_signed_v) + { + builder.sbfm(reg, reg, sizeof(T) * 8); + } + + if constexpr (std::is_unsigned_v) + { + builder.ubfm(reg, reg, sizeof(T) * 8); + } + } + + 
template + void apply(T const &) + { + throw std::runtime_error("Not implemented"); + } + }; + + struct compile_function_visitor + : ast::const_statement_visitor + , ast::const_expression_visitor + { + using const_statement_visitor::apply; + using const_expression_visitor::apply; + + context & context; + instruction_builder builder{context.code}; + std::uint32_t stack_free_bytes = 0; // must be a multiple of 8 + + template + void apply(Node const &) + { + throw std::runtime_error("Not implemented"); + } + + void apply(ast::bool_literal const & node) + { + if (node.value) + set_m1(0); + else + builder.movz(0, 0); + } + + template + requires(std::is_integral_v && !std::is_same_v) + void apply(ast::primitive_literal_base const & node) + { + for (std::size_t i = 0; i < sizeof(T); i += 2) /* i is a byte offset; each step emits one 16-bit chunk into register 0 */ + { + if (i == 0) + { + builder.movz(0, std::uint64_t(node.value)); + } + else + { + auto val = std::uint16_t(std::uint64_t(node.value) >> (i * 8)); /* byte offset -> bit offset is i * 8; i * 16 read the wrong chunk and shifted by >= 64 for 8-byte types */ + if (val != 0) builder.movk(0, val, i / 2); + } + } + + if (sizeof(T) < 8) + { + if (std::is_signed_v) + builder.sbfm(0, 0, sizeof(T) * 8); + else + builder.ubfm(0, 0, sizeof(T) * 8); + } + } + + void apply(ast::unary_operation const & node) + { + // TODO: floating-point + switch (node.type) + { + case ast::unary_operation_type::negation: + apply(*node.arg1); + builder.sub_reg(31, 0, 0); + extend(0, node.inferred_type); + break; + case ast::unary_operation_type::logical_not: + apply(*node.arg1); + builder.or_not_reg(31, 0, 0); + break; + } + } + + void apply(ast::binary_operation const & node) + { + // TODO: floating-point + switch (node.type) + { + case ast::binary_operation_type::addition: + apply(*node.arg1); + push(0); + apply(*node.arg2); + pop(1); + builder.add_reg(1, 0, 0); + extend(0, node.inferred_type); + break; + case ast::binary_operation_type::subtraction: + apply(*node.arg1); + push(0); + apply(*node.arg2); + pop(1); + builder.sub_reg(1, 0, 0); + extend(0, node.inferred_type); + break; + case 
ast::binary_operation_type::multiplication: + apply(*node.arg1); + push(0); + apply(*node.arg2); + pop(1); + builder.mul_reg(1, 0, 0); + extend(0, node.inferred_type); + break; + case ast::binary_operation_type::division: + apply(*node.arg1); + push(0); + apply(*node.arg2); + pop(1); + if (types::is_signed_integer_type(*node.inferred_type)) + builder.sdiv_reg(1, 0, 0); + else + builder.udiv_reg(1, 0, 0); + extend(0, node.inferred_type); + break; + case ast::binary_operation_type::remainder: + // TODO: implement via div & mul & sub + throw std::runtime_error("Not implemented"); + case ast::binary_operation_type::logical_and: + apply(*node.arg1); + push(0); + apply(*node.arg2); + pop(1); + builder.and_reg(1, 0, 0); + break; + case ast::binary_operation_type::logical_or: + apply(*node.arg1); + push(0); + apply(*node.arg2); + pop(1); + builder.or_reg(1, 0, 0); + break; + case ast::binary_operation_type::logical_xor: + apply(*node.arg1); + push(0); + apply(*node.arg2); + pop(1); + builder.xor_reg(1, 0, 0); + break; + case ast::binary_operation_type::equals: + apply(*node.arg1); + push(0); + apply(*node.arg2); + pop(1); + builder.cmp_reg(1, 0); + set_m1(0); + builder.csel(0, 31, 0, 0b0000); + break; + case ast::binary_operation_type::not_equals: + apply(*node.arg1); + push(0); + apply(*node.arg2); + pop(1); + builder.cmp_reg(1, 0); + set_m1(0); + builder.csel(0, 31, 0, 0b0001); + break; + case ast::binary_operation_type::less: + apply(*node.arg1); + push(0); + apply(*node.arg2); + pop(1); + builder.cmp_reg(0, 1); + set_m1(0); + if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1))) + builder.csel(0, 31, 0, 0b1000); + else + builder.csel(0, 31, 0, 0b1100); + break; + case ast::binary_operation_type::greater: + apply(*node.arg1); + push(0); + apply(*node.arg2); + pop(1); + builder.cmp_reg(1, 0); + set_m1(0); + if (types::is_bool_type(*ast::get_type(*node.arg1)) || 
types::is_unsigned_integer_type(*ast::get_type(*node.arg1))) + builder.csel(0, 31, 0, 0b1000); + else + builder.csel(0, 31, 0, 0b1100); + break; + case ast::binary_operation_type::less_equals: + apply(*node.arg1); + push(0); + apply(*node.arg2); + pop(1); + builder.cmp_reg(1, 0); + set_m1(0); + if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1))) + builder.csel(0, 31, 0, 0b1001); + else + builder.csel(0, 31, 0, 0b1101); + break; + case ast::binary_operation_type::greater_equals: + apply(*node.arg1); + push(0); + apply(*node.arg2); + pop(1); + builder.cmp_reg(0, 1); + set_m1(0); + if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1))) + builder.csel(0, 31, 0, 0b1001); + else + builder.csel(0, 31, 0, 0b1101); + break; + default: + throw std::runtime_error("Not implemented"); + } + } + + void apply(ast::return_statement const & node) + { + apply(*node.value); + builder.ret(); + } + + void apply(ast::function_definition const & node) + { + // Don't handle internal functions + } + + void do_apply(ast::function_definition const & node) + { + // TODO: arguments + apply(*node.statements); + } + + private: + void push(std::uint8_t reg) + { + if (stack_free_bytes < 8) + { + builder.sub_imm(31, 31, 16); + stack_free_bytes += 16; + } + builder.str(reg, 31, (stack_free_bytes - 8) / 8); + stack_free_bytes -= 8; + } + + void pop(std::uint8_t reg) + { + builder.ldr(reg, 31, stack_free_bytes / 8); + stack_free_bytes += 8; + if (stack_free_bytes >= 16) + { + builder.add_imm(31, 31, 16); + stack_free_bytes -= 16; + } + } + + // Set register @reg to -1 (all bits = 1) + void set_m1(std::uint8_t reg) + { + builder.or_not_reg(31, 31, reg); + } + + // Sign- or zero-extend the register depending on the exact type + void extend(std::uint8_t reg, types::type_ptr const & type) + { + reg_extend_visitor{{}, builder, reg}.apply(*type); + } + }; + + struct compile_visitor + 
: ast::const_statement_visitor + { + using const_statement_visitor::apply; + + context & context; + instruction_builder builder{context.code}; + + template + void apply(Statement const &) + { + throw std::runtime_error("Not implemented"); + } + + void apply(ast::function_definition const & node) + { + context.code_symbol_table[node.name] = context.code.size(); /* symbol = byte offset in the code buffer at which this function's code starts */ + compile_function_visitor visitor{{}, {}, context}; + visitor.do_apply(node); + } + }; + + } + + compiled_module compile(ast::statement_list_ptr const & statements) + { + context context; + compile_visitor visitor{{}, context}; + visitor.apply(*statements); + + auto code = allocate(context.code.size()); + std::copy(context.code.data(), context.code.data() + context.code.size(), code.data.get()); /* copy the generated bytes into the block obtained from allocate() */ + + return compiled_module { + .data = {}, + .code = { + .memory = std::move(code), + .symbol_table = std::move(context.code_symbol_table), + }, + .entry_point = 0, + .abi = abi::armv8, + }; + } + +} \ No newline at end of file diff --git a/libs/jit/source/arch/aarch64/instruction_builder.cpp b/libs/jit/source/arch/aarch64/instruction_builder.cpp new file mode 100644 index 0000000..06d1fe3 --- /dev/null +++ b/libs/jit/source/arch/aarch64/instruction_builder.cpp @@ -0,0 +1,146 @@ +#include + +namespace pslang::jit::aarch64 +{ + + static constexpr std::uint32_t REG_MASK = 0x1fu; /* register operands are masked to 5 bits before being placed in an opcode */ + + void instruction_builder::movz(std::uint8_t reg, std::uint16_t val, std::uint8_t shift) + { + do_push(0xd2800000u | (reg & REG_MASK) | (val << 5) | ((shift & 0x3u) << 21)); + } + + void instruction_builder::movk(std::uint8_t reg, std::uint16_t val, std::uint8_t shift) + { + do_push(0xf2800000u | (reg & REG_MASK) | (val << 5) | ((shift & 0x3u) << 21)); + } + + void instruction_builder::str(std::uint8_t reg_src, std::uint8_t reg_addr, std::uint16_t offset) + { + do_push(0xf9000000u | (reg_src & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0xfffu) << 10)); + } + + void instruction_builder::str_pre(std::uint8_t reg_src, 
std::uint8_t reg_addr, std::int16_t offset) + { + do_push(0xf8000c00u | (reg_src & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0x1ffu) << 12)); + } + + void instruction_builder::ldr(std::uint8_t reg_dst, std::uint8_t reg_addr, std::uint16_t offset) + { + do_push(0xf9400000u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0xfffu) << 10)); + } + + void instruction_builder::ldr_pre(std::uint8_t reg_dst, std::uint8_t reg_addr, std::int16_t offset) + { + do_push(0xf8400c00u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0x1ffu) << 12)); + } + + void instruction_builder::ldr_post(std::uint8_t reg_dst, std::uint8_t reg_addr, std::int16_t offset) + { + do_push(0xf8400400u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0x1ffu) << 12)); + } + + void instruction_builder::add_imm(std::uint8_t reg_src, std::uint8_t reg_dst, std::uint16_t value) + { + do_push(0x91000000u | (reg_dst & REG_MASK) | ((reg_src & REG_MASK) << 5) | ((value & 0xfffu) << 10)); /* only the low 12 bits of @value are encoded */ + } + + void instruction_builder::sub_imm(std::uint8_t reg_src, std::uint8_t reg_dst, std::uint16_t value) + { + do_push(0xd1000000u | (reg_dst & REG_MASK) | ((reg_src & REG_MASK) << 5) | ((value & 0xfffu) << 10)); /* only the low 12 bits of @value are encoded */ + } + + void instruction_builder::add_reg(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t reg_dst) + { + do_push(0x8b000000u | (reg_dst & REG_MASK) | ((reg_src1 & REG_MASK) << 5) | ((reg_src2 & REG_MASK) << 16)); + } + + void instruction_builder::sub_reg(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t reg_dst) + { + do_push(0xcb000000u | (reg_dst & REG_MASK) | ((reg_src1 & REG_MASK) << 5) | ((reg_src2 & REG_MASK) << 16)); + } + + void instruction_builder::and_reg(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t reg_dst) + { + do_push(0x8a000000u | (reg_dst & REG_MASK) | ((reg_src1 & REG_MASK) << 5) | ((reg_src2 & REG_MASK) << 16)); + } + + void 
instruction_builder::or_reg(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t reg_dst) + { + do_push(0xaa000000u | (reg_dst & REG_MASK) | ((reg_src1 & REG_MASK) << 5) | ((reg_src2 & REG_MASK) << 16)); + } + + void instruction_builder::xor_reg(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t reg_dst) + { + do_push(0xca000000u | (reg_dst & REG_MASK) | ((reg_src1 & REG_MASK) << 5) | ((reg_src2 & REG_MASK) << 16)); + } + + void instruction_builder::or_not_reg(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t reg_dst) + { + do_push(0xaa200000u | (reg_dst & REG_MASK) | ((reg_src1 & REG_MASK) << 5) | ((reg_src2 & REG_MASK) << 16)); + } + + void instruction_builder::mul_reg(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t reg_dst) + { + do_push(0x9b007c00u | (reg_dst & REG_MASK) | ((reg_src1 & REG_MASK) << 5) | ((reg_src2 & REG_MASK) << 16)); + } + + void instruction_builder::sdiv_reg(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t reg_dst) + { + do_push(0x9ac00c00u | (reg_dst & REG_MASK) | ((reg_src1 & REG_MASK) << 5) | ((reg_src2 & REG_MASK) << 16)); + } + + void instruction_builder::udiv_reg(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t reg_dst) + { + do_push(0x9ac00800u | (reg_dst & REG_MASK) | ((reg_src1 & REG_MASK) << 5) | ((reg_src2 & REG_MASK) << 16)); + } + + void instruction_builder::cmp_reg(std::uint8_t reg_src1, std::uint8_t reg_src2) + { + do_push(0xeb20601fu | ((reg_src1 & REG_MASK) << 5) | ((reg_src2 & REG_MASK) << 16)); + } + + void instruction_builder::cset(std::uint8_t reg_dst, std::uint8_t cond) + { + do_push(0x9a9f07e0u | (reg_dst & REG_MASK) | (((cond & 0xfu) ^ 1) << 12)); /* the lowest bit of @cond is flipped before it is encoded */ + } + + void instruction_builder::csel(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t reg_dst, std::uint8_t cond) + { + do_push(0x9a800000u | (reg_dst & REG_MASK) | ((reg_src1 & REG_MASK) << 5) | ((reg_src2 & REG_MASK) << 16) | ((cond & 0xfu) << 12)); + } + + void instruction_builder::sbfm(std::uint8_t reg_src, 
std::uint8_t reg_dst, std::uint8_t bit_count) + { + do_push(0x93400000u | (reg_dst & REG_MASK) | ((reg_src & REG_MASK) << 5) | (((bit_count - 1) & 0x3fu) << 10)); + } + + void instruction_builder::ubfm(std::uint8_t reg_src, std::uint8_t reg_dst, std::uint8_t bit_count) + { + do_push(0xd3400000u | (reg_dst & REG_MASK) | ((reg_src & REG_MASK) << 5) | (((bit_count - 1) & 0x3fu) << 10)); + } + + void instruction_builder::ret(std::uint8_t reg) + { + do_push(0xd65f0000u | ((reg & REG_MASK) << 5)); + } + + void instruction_builder::push(std::uint8_t reg) + { + str_pre(reg, 31, -8); + } + + void instruction_builder::pop(std::uint8_t reg) + { + ldr_post(reg, 31, 8); + } + + void instruction_builder::do_push(std::uint32_t opcode) + { + code.push_back((opcode >> 0) & 0xffu); + code.push_back((opcode >> 8) & 0xffu); + code.push_back((opcode >> 16) & 0xffu); + code.push_back((opcode >> 24) & 0xffu); + } + +} \ No newline at end of file diff --git a/libs/jit/source/executable.cpp b/libs/jit/source/executable.cpp index 911bb4d..86fc0f0 100644 --- a/libs/jit/source/executable.cpp +++ b/libs/jit/source/executable.cpp @@ -1,6 +1,8 @@ +#include #include #include +#include #ifdef __linux__ #include @@ -13,6 +15,19 @@ namespace pslang::jit { + blob allocate(std::size_t size) /* maps @size bytes of private anonymous read/write memory; the blob's deleter unmaps it */ + { +#if defined(__linux__) || defined(__APPLE__) + auto ptr = (std::uint8_t *)mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); /* anonymous mappings take fd -1 */ + if (ptr == (std::uint8_t *)MAP_FAILED) /* never wrap a failed mapping: it would later be handed to munmap */ + throw std::system_error(errno, std::generic_category()); + auto shared_ptr = std::shared_ptr(ptr, [size](void * ptr){ munmap(ptr, size); }); + return blob(shared_ptr, size); +#else + throw std::runtime_error("Allocate not supported for this platform"); +#endif + } + compiled_module make_host_executable(compiled_module module) { #if defined(__linux__) || defined(__APPLE__) @@ -20,7 +35,8 @@ namespace pslang::jit throw std::runtime_error("Abi mismatch"); // Assume the module code memory was obtained via mmap - mprotect(module.code.memory.data.get(), module.code.memory.size, PROT_READ | PROT_EXEC); + if 
(mprotect(module.code.memory.data.get(), module.code.memory.size, PROT_READ | PROT_EXEC) != 0) + throw std::system_error(errno, std::generic_category()); return module; #else diff --git a/libs/jit/source/jit.cpp b/libs/jit/source/jit.cpp index 5992cac..322405d 100644 --- a/libs/jit/source/jit.cpp +++ b/libs/jit/source/jit.cpp @@ -1,16 +1,23 @@ #include +#include + +#include namespace pslang::jit { - compiled_module compile(ast::statement_list_ptr const & /* statements */, jit::abi abi) + compiled_module compile(ast::statement_list_ptr const & statements, jit::abi abi) /* dispatches to the backend for @abi; only abi::armv8 is implemented */ { - return { - .data = {}, - .code = {}, - .entry_point = 0, - .abi = abi, - }; + switch (abi) + { + case abi::itanium: + throw std::runtime_error("Itanium ABI JIT not implemented"); + case abi::msvc: + throw std::runtime_error("MSVC ABI JIT not implemented"); + case abi::armv8: + return aarch64::compile(statements); + } + throw std::runtime_error("Unknown ABI"); /* reached only if @abi matches none of the cases above; avoids running off the end of a non-void function */ } }