diff --git a/apps/interpreter/source/main.cpp b/apps/interpreter/source/main.cpp index 8406925..84332d7 100644 --- a/apps/interpreter/source/main.cpp +++ b/apps/interpreter/source/main.cpp @@ -14,6 +14,8 @@ #include #include +#include + std::string extract_nth_line(std::filesystem::path const & path, std::size_t n) { std::ifstream file(path); @@ -170,8 +172,8 @@ int main(int argc, char ** argv) { // TODO: remove, testing-only code; should execute entry point instead auto offset = module.code.symbol_table.at("test"); - auto fptr = (int(*)(int))(module.code.memory.data.get() + offset); - auto x = fptr(10); + auto fptr = (float(*)())(module.code.memory.data.get() + offset); + auto x = fptr(); std::cout << "Result: " << std::boolalpha << x << std::endl; } } diff --git a/examples/jit_test.psl b/examples/jit_test.psl index 08a6fd1..a34bc68 100644 --- a/examples/jit_test.psl +++ b/examples/jit_test.psl @@ -1,6 +1,2 @@ -func factorial(x : i32) -> i32: - mut r = 1 - while x > 0: - r = r * x - x = x - 1 - return r +func test() -> f32: + return - 3.1415 diff --git a/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp b/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp index 7bae70d..064eb6a 100644 --- a/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp +++ b/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp @@ -117,6 +117,24 @@ namespace pslang::jit::aarch64 // starting at @opcode void b_inject(std::uint8_t * opcode, std::int32_t offset); + // Load a floating-point value from current program counter plus a + // 19-bit signed @offset multiplied by 4, and store it in floating-point + // register @reg_dst. @mode should be 0 for 32-bit values and 1 for 64-bit + void ldr_fp_pc(std::uint8_t reg_dst, std::uint8_t mode, std::int32_t offset); + + // Load a floating-point value from the address stored in @reg_addr plus a + // 12-bit unsigned @offset multiplied by 4, and store it in floating-point + // register @reg_dst. @mode should be 0 for 32-bit values and 1 for 64-bit + void ldr_fp(std::uint8_t reg_dst, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset); + + // Store a floating-point value from @reg_src into the address stored in + // @reg_addr plus a 12-bit unsigned @offset multiplied by 4. @mode should + // be 0 for 32-bit values and 1 for 64-bit + void str_fp(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset); + + // Negate the floating-point register @reg_src and store the result in @reg_dst + void fneg(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_dst); + // Return from a subroutine, taking the return address from register @reg void ret(std::uint8_t reg = 30); diff --git a/libs/jit/source/arch/aarch64/compiler.cpp b/libs/jit/source/arch/aarch64/compiler.cpp index 04bb29d..a633ace 100644 --- a/libs/jit/source/arch/aarch64/compiler.cpp +++ b/libs/jit/source/arch/aarch64/compiler.cpp @@ -1,4 +1,3 @@ -#include "pslang/types/type_fwd.hpp" #include #include #include @@ -21,6 +20,155 @@ namespace pslang::jit::aarch64 { std::vector code; std::unordered_map code_symbol_table; + + std::unordered_map f32_constants; + std::unordered_map f64_constants; + }; + + std::uint8_t fp_mode_for(types::type const & type) + { + if (types::equal(type, types::primitive_type(types::f32_type{}))) + return 0; + return 1; + } + + struct populate_constants_visitor + : ast::const_expression_visitor + , ast::const_statement_visitor + { + using const_expression_visitor::apply; + using const_statement_visitor::apply; + + context & context; + + template + requires(!std::is_floating_point_v) + void apply(ast::primitive_literal_base const &) + {} + + void apply(ast::f32_literal const & node) + { + if (!context.f32_constants.contains(node.value)) + { + context.f32_constants[node.value] = context.code.size(); + push_bytes(node.value); + } + } + + void apply(ast::f64_literal const & node) + { + if (!context.f64_constants.contains(node.value)) + { + context.f64_constants[node.value] = context.code.size(); + push_bytes(node.value); + } + } + + void apply(ast::identifier const &) + {} + + void apply(ast::unary_operation const & node) + { + apply(*node.arg1); + } + + void apply(ast::binary_operation const & node) + { + apply(*node.arg1); + apply(*node.arg2); + } + + void apply(ast::cast_operation const & node) + { + apply(*node.expression); + } + + void apply(ast::function_call const & node) + { + apply(*node.function); + for (auto const & argument : node.arguments) + apply(*argument); + } + + void apply(ast::array const & node) + { + for (auto const & element : node.elements) + apply(*element); + } + + void apply(ast::array_access const & node) + { + apply(*node.array); + apply(*node.index); + } + + void apply(ast::field_access const & node) + { + apply(*node.object); + } + + void apply(ast::expression_ptr const & node) + { + apply(*node); + } + + void apply(ast::assignment const & node) + { + apply(*node.lhs); + apply(*node.rhs); + } + + void apply(ast::variable_declaration const & node) + { + apply(*node.initializer); + } + + void apply(ast::if_block const &) + {} + + void apply(ast::else_block const &) + {} + + void apply(ast::else_if_block const &) + {} + + void apply(ast::if_chain const & node) + { + for (auto const & block : node.blocks) + apply(*block.statements); + } + + void apply(ast::while_block const & node) + { + apply(*node.condition); + apply(*node.statements); + } + + void apply(ast::function_definition const & node) + { + apply(*node.statements); + } + + void apply(ast::return_statement const & node) + { + apply(*node.value); + } + + void apply(ast::field_definition const &) + {} + + void apply(ast::struct_definition const &) + {} + + private: + + template + void push_bytes(T const & value) + { + auto begin = (std::uint8_t const *)(&value); + auto end = begin + sizeof(value); + context.code.insert(context.code.end(), begin, end); + } }; struct reg_extend_visitor @@ -142,6 +290,20 @@ namespace pslang::jit::aarch64 } } + void apply(ast::f32_literal const & node) + { + auto offset = context.f32_constants.at(node.value); + std::int32_t current = context.code.size(); + builder.ldr_fp_pc(0, 0, (offset - current) / 4); + } + + void apply(ast::f64_literal const & node) + { + auto offset = context.f64_constants.at(node.value); + std::int32_t current = context.code.size(); + builder.ldr_fp_pc(0, 1, (offset - current) / 4); + } + void apply(ast::identifier const & node) { for (auto it = scopes.rbegin(); it != scopes.rend(); ++it) @@ -156,13 +318,19 @@ namespace pslang::jit::aarch64 void apply(ast::unary_operation const & node) { - // TODO: floating-point switch (node.type) { case ast::unary_operation_type::negation: apply(*node.arg1); - builder.sub_reg(31, 0, 0); - extend(0, node.inferred_type); + if (types::is_integer_type(*node.inferred_type)) + { + builder.sub_reg(31, 0, 0); + extend(0, node.inferred_type); + } + else if (types::is_floating_point_type(*node.inferred_type)) + { + builder.fneg(0, fp_mode_for(*node.inferred_type), 0); + } break; case ast::unary_operation_type::logical_not: apply(*node.arg1); @@ -443,6 +611,24 @@ namespace pslang::jit::aarch64 scopes.back().stack_offset -= 16; } + // @mode = 0: 32-bit + // @mode = 1: 64-bit + void push_fp(std::uint8_t reg, std::uint8_t mode) + { + builder.sub_imm(31, 31, 16); + builder.str_fp(0, mode, 31, 0); + stack_offset += 16; + scopes.back().stack_offset += 16; + } + + void pop_fp(std::uint8_t reg, std::uint8_t mode) + { + builder.ldr_fp(reg, mode, 31, 0); + builder.add_imm(31, 31, 16); + stack_offset -= 16; + scopes.back().stack_offset -= 16; + } + // Set register @reg to -1 (all bits = 1) void set_m1(std::uint8_t reg) { @@ -489,8 +675,10 @@ namespace pslang::jit::aarch64 compiled_module compile(ast::statement_list_ptr const & statements) { context context; - compile_visitor visitor{{}, context}; - visitor.apply(*statements); + + populate_constants_visitor{{}, {}, context}.apply(*statements); + + compile_visitor{{}, context}.apply(*statements); auto code = allocate(context.code.size()); std::copy(context.code.data(), context.code.data() + context.code.size(), code.data.get()); diff --git a/libs/jit/source/arch/aarch64/instruction_builder.cpp b/libs/jit/source/arch/aarch64/instruction_builder.cpp index e8302ba..ddaec98 100644 --- a/libs/jit/source/arch/aarch64/instruction_builder.cpp +++ b/libs/jit/source/arch/aarch64/instruction_builder.cpp @@ -159,6 +159,26 @@ namespace pslang::jit::aarch64 *dst |= (std::uint32_t(offset) & 0x3ffffffu); } + void instruction_builder::ldr_fp_pc(std::uint8_t reg_dst, std::uint8_t mode, std::int32_t offset) + { + do_push(0x1c000000u | ((mode & 0x1u) << 30) | (reg_dst & REG_MASK) | ((std::uint32_t(offset) & 0x7ffffu) << 5)); + } + + void instruction_builder::ldr_fp(std::uint8_t reg_dst, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset) + { + do_push(0xbd400000u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((offset & 0xfffu) << 10) | ((mode & 0x1u) << 30)); + } + + void instruction_builder::str_fp(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset) + { + do_push(0xbd000000u | (reg_src & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((offset & 0xfffu) << 10) | ((mode & 0x1u) << 30)); + } + + void instruction_builder::fneg(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_dst) + { + do_push(0x1e214000u | (reg_dst & REG_MASK) | ((reg_src & REG_MASK) << 5) | ((mode & 0x1u) << 22)); + } + void instruction_builder::ret(std::uint8_t reg) { do_push(0xd65f0000u | ((reg & REG_MASK) << 5));