diff --git a/examples/jit_test.psl b/examples/jit_test.psl
index a34bc68..0cead51 100644
--- a/examples/jit_test.psl
+++ b/examples/jit_test.psl
@@ -1,2 +1,2 @@
-func test() -> f32:
-    return - 3.1415
+func test() -> f16:
+    return - 3.1415h
diff --git a/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp b/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp
index 064eb6a..0a6f3ee 100644
--- a/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp
+++ b/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp
@@ -124,14 +124,19 @@ namespace pslang::jit::aarch64
 
         // Load a floating-point value from the address stored in @reg_addr plus a
-        // 12-bit unsigned @offset multiplied by 4, and store it in floating-point
-        // register @reg_dst. @mode should be 0 for 32-bit values and 1 for 64-bit
+        // 12-bit unsigned @offset scaled by the access size, and store it in floating-point
+        // register @reg_dst. @mode should be 1 for 16-bit values, 2 for 32-bit, 3 for 64-bit
         void ldr_fp(std::uint8_t reg_dst, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset);
 
         // Store a floating-point value from @reg_src into the address stored in
-        // @reg_addr plus a 12-bit unsigned @offset multiplied by 4. @mode should
-        // be 0 for 32-bit values and 1 for 64-bit
+        // @reg_addr plus a 12-bit unsigned @offset scaled by the access size.
+        // @mode should be 1 for 16-bit values, 2 for 32-bit, 3 for 64-bit
         void str_fp(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset);
 
+        // Convert a floating-point value from @reg_src stored in format @mode_src
+        // into a floating-point value in @reg_dst in format @mode_dst
+        // Modes are the same as for ldr_fp/str_fp
+        void fcvt(std::uint8_t reg_src, std::uint8_t mode_src, std::uint8_t reg_dst, std::uint8_t mode_dst);
+
         // Negate the floating-point register @reg_src and store the result in @reg_dst
         void fneg(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_dst);
 
diff --git a/libs/jit/source/arch/aarch64/compiler.cpp b/libs/jit/source/arch/aarch64/compiler.cpp
index a633ace..f6195f9 100644
--- a/libs/jit/source/arch/aarch64/compiler.cpp
+++ b/libs/jit/source/arch/aarch64/compiler.cpp
@@ -21,15 +21,18 @@ namespace pslang::jit::aarch64
         std::vector code;
         std::unordered_map code_symbol_table;
 
+        std::unordered_map f16_constants;
         std::unordered_map f32_constants;
         std::unordered_map f64_constants;
     };
 
     std::uint8_t fp_mode_for(types::type const & type)
     {
+        if (types::equal(type, types::primitive_type(types::f16_type{})))
+            return 1;
         if (types::equal(type, types::primitive_type(types::f32_type{})))
-            return 0;
-        return 1;
+            return 2;
+        return 3;
     }
 
     struct populate_constants_visitor
@@ -46,6 +49,15 @@ namespace pslang::jit::aarch64
 
         void apply(ast::primitive_literal_base const &) {}
 
+        void apply(ast::f16_literal const & node)
+        {
+            if (!context.f16_constants.contains(node.value.repr))
+            {
+                context.f16_constants[node.value.repr] = context.code.size();
+                push_bytes(node.value.repr);
+            }
+        }
+
         void apply(ast::f32_literal const & node)
         {
             if (!context.f32_constants.contains(node.value))
@@ -290,6 +302,17 @@ namespace pslang::jit::aarch64
             }
         }
 
+        void apply(ast::f16_literal const & node)
+        {
+            auto offset = context.f16_constants.at(node.value.repr);
+            std::int32_t current = context.code.size();
+            // NOTE(review): ldr_fp_pc mode 0 presumably loads 32 bits; if the pool entry
+            // holds raw binary16 bits, FCVT single->half would reinterpret rather than
+            // preserve them, and a 2-byte entry may misalign later code - confirm.
+            builder.ldr_fp_pc(0, 0, (offset - current) / 4);
+            builder.fcvt(0, 0b10, 0, 0b01);
+        }
+
         void apply(ast::f32_literal const & node)
         {
             auto offset = context.f32_constants.at(node.value);
diff --git a/libs/jit/source/arch/aarch64/instruction_builder.cpp b/libs/jit/source/arch/aarch64/instruction_builder.cpp
index ddaec98..c4cb123 100644
--- a/libs/jit/source/arch/aarch64/instruction_builder.cpp
+++ b/libs/jit/source/arch/aarch64/instruction_builder.cpp
@@ -166,12 +166,22 @@ namespace pslang::jit::aarch64
 
     void instruction_builder::ldr_fp(std::uint8_t reg_dst, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset)
     {
-        do_push(0xbd400000u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((offset & 0xfffu) << 10) | ((mode & 0x1u) << 30));
+        // Base is LDR (immediate, unsigned offset, SIMD&FP) with the size field
+        // (bits 31:30) left clear, so @mode selects 01=H, 10=S, 11=D directly.
+        do_push(0x3d400000u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((offset & 0xfffu) << 10) | ((mode & 0x3u) << 30));
     }
 
     void instruction_builder::str_fp(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset)
     {
-        do_push(0xbd000000u | (reg_src & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((offset & 0xfffu) << 10) | ((mode & 0x1u) << 30));
+        // Same layout as ldr_fp with opc=00 (store); the size field comes from @mode.
+        do_push(0x3d000000u | (reg_src & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((offset & 0xfffu) << 10) | ((mode & 0x3u) << 30));
+    }
+
+    void instruction_builder::fcvt(std::uint8_t reg_src, std::uint8_t mode_src, std::uint8_t reg_dst, std::uint8_t mode_dst)
+    {
+        // ftype (bits 23:22) and opc (bits 16:15) encode 00=S, 01=D, 11=H,
+        // which is the 1/2/3 mode numbering XOR 0b10.
+        do_push(0x1e224000u | (reg_dst & REG_MASK) | ((reg_src & REG_MASK) << 5) | (((mode_dst & 0x3u) ^ 0x2u) << 15) | (((mode_src & 0x3u) ^ 0x2u) << 22));
     }
 
     void instruction_builder::fneg(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_dst)