diff --git a/examples/jit_test.psl b/examples/jit_test.psl index 0cead51..823cee9 100644 --- a/examples/jit_test.psl +++ b/examples/jit_test.psl @@ -1,2 +1,3 @@ -func test() -> f16: - return - 3.1415h +func test() -> f32: + let pi = 3.14159265358979323846 + return pi / 2.0 diff --git a/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp b/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp index 0a6f3ee..7359439 100644 --- a/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp +++ b/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp @@ -140,6 +140,11 @@ namespace pslang::jit::aarch64 // Negate the floating-point register @reg_src and store the result in @reg_dst void fneg(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_dst); + void fadd(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t mode, std::uint8_t reg_dst); + void fsub(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t mode, std::uint8_t reg_dst); + void fmul(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t mode, std::uint8_t reg_dst); + void fdiv(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t mode, std::uint8_t reg_dst); + // Return from a subroutine, taking the return address from register @reg void ret(std::uint8_t reg = 30); diff --git a/libs/jit/source/arch/aarch64/compiler.cpp b/libs/jit/source/arch/aarch64/compiler.cpp index f6195f9..ec9936e 100644 --- a/libs/jit/source/arch/aarch64/compiler.cpp +++ b/libs/jit/source/arch/aarch64/compiler.cpp @@ -361,89 +361,86 @@ namespace pslang::jit::aarch64 void apply(ast::binary_operation const & node) { - // TODO: floating-point + bool const is_fp = types::is_floating_point_type(*node.inferred_type); + std::uint8_t const fp_mode = fp_mode_for(*node.inferred_type); + + if (is_fp) + { + apply(*node.arg1); + push_fp(0, fp_mode); + apply(*node.arg2); + pop_fp(1, fp_mode); + } + else + { + apply(*node.arg1); + push(0); + apply(*node.arg2); + pop(1); + } + switch (node.type) { case ast::binary_operation_type::addition: - apply(*node.arg1); - push(0); - apply(*node.arg2); - pop(1); - builder.add_reg(1, 0, 0); - extend(0, node.inferred_type); + if (is_fp) + builder.fadd(1, 0, fp_mode, 0); + else + { + builder.add_reg(1, 0, 0); + extend(0, node.inferred_type); + } break; case ast::binary_operation_type::subtraction: - apply(*node.arg1); - push(0); - apply(*node.arg2); - pop(1); - builder.sub_reg(1, 0, 0); - extend(0, node.inferred_type); + if (is_fp) + builder.fsub(1, 0, fp_mode, 0); + else + { + builder.sub_reg(1, 0, 0); + extend(0, node.inferred_type); + } break; case ast::binary_operation_type::multiplication: - apply(*node.arg1); - push(0); - apply(*node.arg2); - pop(1); - builder.mul_reg(1, 0, 0); - extend(0, node.inferred_type); + if (is_fp) + builder.fmul(1, 0, fp_mode, 0); + else + { + builder.mul_reg(1, 0, 0); + extend(0, node.inferred_type); + } break; case ast::binary_operation_type::division: - apply(*node.arg1); - push(0); - apply(*node.arg2); - pop(1); - if (types::is_signed_integer_type(*node.inferred_type)) - builder.sdiv_reg(1, 0, 0); + if (is_fp) + builder.fdiv(1, 0, fp_mode, 0); else - builder.udiv_reg(1, 0, 0); - extend(0, node.inferred_type); + { + if (types::is_signed_integer_type(*node.inferred_type)) + builder.sdiv_reg(1, 0, 0); + else + builder.udiv_reg(1, 0, 0); + extend(0, node.inferred_type); + } break; case ast::binary_operation_type::remainder: // TODO: implement via div & mul & sub throw std::runtime_error("Not implemented"); case ast::binary_operation_type::logical_and: - apply(*node.arg1); - push(0); - apply(*node.arg2); - pop(1); builder.and_reg(1, 0, 0); break; case ast::binary_operation_type::logical_or: - apply(*node.arg1); - push(0); - apply(*node.arg2); - pop(1); builder.or_reg(1, 0, 0); break; case ast::binary_operation_type::logical_xor: - apply(*node.arg1); - push(0); - apply(*node.arg2); - pop(1); builder.xor_reg(1, 0, 0); break; case ast::binary_operation_type::equals: - apply(*node.arg1); - push(0); - apply(*node.arg2); - pop(1); builder.cmp_reg(1, 0); builder.csetm(0, 0b0000); break; case ast::binary_operation_type::not_equals: - apply(*node.arg1); - push(0); - apply(*node.arg2); - pop(1); builder.cmp_reg(1, 0); builder.csetm(0, 0b0001); break; case ast::binary_operation_type::less: - apply(*node.arg1); - push(0); - apply(*node.arg2); - pop(1); builder.cmp_reg(0, 1); if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1))) builder.csetm(0, 0b1000); @@ -451,10 +448,6 @@ namespace pslang::jit::aarch64 builder.csetm(0, 0b1100); break; case ast::binary_operation_type::greater: - apply(*node.arg1); - push(0); - apply(*node.arg2); - pop(1); builder.cmp_reg(1, 0); if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1))) builder.csetm(0, 0b1000); @@ -462,10 +455,6 @@ namespace pslang::jit::aarch64 builder.csetm(0, 0b1100); break; case ast::binary_operation_type::less_equals: - apply(*node.arg1); - push(0); - apply(*node.arg2); - pop(1); builder.cmp_reg(1, 0); if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1))) builder.csetm(0, 0b1001); @@ -473,10 +462,6 @@ namespace pslang::jit::aarch64 builder.csetm(0, 0b1101); break; case ast::binary_operation_type::greater_equals: - apply(*node.arg1); - push(0); - apply(*node.arg2); - pop(1); builder.cmp_reg(0, 1); if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1))) builder.csetm(0, 0b1001); diff --git a/libs/jit/source/arch/aarch64/instruction_builder.cpp b/libs/jit/source/arch/aarch64/instruction_builder.cpp index c4cb123..f1be01b 100644 --- a/libs/jit/source/arch/aarch64/instruction_builder.cpp +++ b/libs/jit/source/arch/aarch64/instruction_builder.cpp @@ -184,6 +184,26 @@ namespace pslang::jit::aarch64 do_push(0x1e214000u | (reg_dst & REG_MASK) | ((reg_src & REG_MASK) << 5) | ((mode & 0x1u) << 22)); } + void instruction_builder::fadd(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t mode, std::uint8_t reg_dst) + { + do_push(0x1e202800u | (reg_dst & REG_MASK) | ((reg_src1 & REG_MASK) << 5) | ((reg_src2 & REG_MASK) << 16) | (((mode & 0x3u) ^ 0x2u) << 22)); + } + + void instruction_builder::fsub(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t mode, std::uint8_t reg_dst) + { + do_push(0x1e203800u | (reg_dst & REG_MASK) | ((reg_src1 & REG_MASK) << 5) | ((reg_src2 & REG_MASK) << 16) | (((mode & 0x3u) ^ 0x2u) << 22)); + } + + void instruction_builder::fmul(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t mode, std::uint8_t reg_dst) + { + do_push(0x1e200800u | (reg_dst & REG_MASK) | ((reg_src1 & REG_MASK) << 5) | ((reg_src2 & REG_MASK) << 16) | (((mode & 0x3u) ^ 0x2u) << 22)); + } + + void instruction_builder::fdiv(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t mode, std::uint8_t reg_dst) + { + do_push(0x1e201800u | (reg_dst & REG_MASK) | ((reg_src1 & REG_MASK) << 5) | ((reg_src2 & REG_MASK) << 16) | (((mode & 0x3u) ^ 0x2u) << 22)); + } + void instruction_builder::ret(std::uint8_t reg) { do_push(0xd65f0000u | ((reg & REG_MASK) << 5));