diff --git a/examples/jit_test.psl b/examples/jit_test.psl index 823cee9..79df4e7 100644 --- a/examples/jit_test.psl +++ b/examples/jit_test.psl @@ -1,3 +1,4 @@ func test() -> f32: - let pi = 3.14159265358979323846 - return pi / 2.0 + if 0.0 >= 1.0: + return 3.1415 + return 2.71828 diff --git a/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp b/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp index 7359439..45d8c99 100644 --- a/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp +++ b/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp @@ -145,6 +145,8 @@ namespace pslang::jit::aarch64 void fmul(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t mode, std::uint8_t reg_dst); void fdiv(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t mode, std::uint8_t reg_dst); + void fcmp(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t mode); + // Return from a subroutine, taking the return address from register @reg void ret(std::uint8_t reg = 30); diff --git a/libs/jit/source/arch/aarch64/compiler.cpp b/libs/jit/source/arch/aarch64/compiler.cpp index ec9936e..4e0ac14 100644 --- a/libs/jit/source/arch/aarch64/compiler.cpp +++ b/libs/jit/source/arch/aarch64/compiler.cpp @@ -147,7 +147,10 @@ namespace pslang::jit::aarch64 void apply(ast::if_chain const & node) { for (auto const & block : node.blocks) + { + apply(*block.condition); apply(*block.statements); + } } void apply(ast::while_block const & node) @@ -361,8 +364,9 @@ namespace pslang::jit::aarch64 void apply(ast::binary_operation const & node) { - bool const is_fp = types::is_floating_point_type(*node.inferred_type); - std::uint8_t const fp_mode = fp_mode_for(*node.inferred_type); + auto arg1_type = ast::get_type(*node.arg1); + bool const is_fp = types::is_floating_point_type(*arg1_type); + std::uint8_t const fp_mode = fp_mode_for(*arg1_type); if (is_fp) { @@ -433,40 +437,88 @@ namespace pslang::jit::aarch64 builder.xor_reg(1, 0, 0); break; case ast::binary_operation_type::equals: - builder.cmp_reg(1, 0); - builder.csetm(0, 0b0000); + if (is_fp) + { + builder.fcmp(1, 0, fp_mode); + builder.csetm(0, 0b0000); + } + else + { + builder.cmp_reg(1, 0); + builder.csetm(0, 0b0000); + } break; case ast::binary_operation_type::not_equals: - builder.cmp_reg(1, 0); - builder.csetm(0, 0b0001); + if (is_fp) + { + builder.fcmp(1, 0, fp_mode); + builder.csetm(0, 0b0001); + } + else + { + builder.cmp_reg(1, 0); + builder.csetm(0, 0b0001); + } break; case ast::binary_operation_type::less: - builder.cmp_reg(0, 1); - if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1))) - builder.csetm(0, 0b1000); + if (is_fp) + { + builder.fcmp(1, 0, fp_mode); + builder.csetm(0, 0b0100); + } else - builder.csetm(0, 0b1100); + { + builder.cmp_reg(0, 1); + if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1))) + builder.csetm(0, 0b1000); + else + builder.csetm(0, 0b1100); + } break; case ast::binary_operation_type::greater: - builder.cmp_reg(1, 0); - if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1))) - builder.csetm(0, 0b1000); + if (is_fp) + { + builder.fcmp(0, 1, fp_mode); + builder.csetm(0, 0b0100); + } else - builder.csetm(0, 0b1100); + { + builder.cmp_reg(1, 0); + if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1))) + builder.csetm(0, 0b1000); + else + builder.csetm(0, 0b1100); + } break; case ast::binary_operation_type::less_equals: - builder.cmp_reg(1, 0); - if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1))) + if (is_fp) + { + builder.fcmp(1, 0, fp_mode); builder.csetm(0, 0b1001); + } else - builder.csetm(0, 0b1101); + { + builder.cmp_reg(1, 0); + if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1))) + builder.csetm(0, 0b1001); + else + builder.csetm(0, 0b1101); + } break; case ast::binary_operation_type::greater_equals: - builder.cmp_reg(0, 1); - if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1))) + if (is_fp) + { + builder.fcmp(0, 1, fp_mode); builder.csetm(0, 0b1001); + } else - builder.csetm(0, 0b1101); + { + builder.cmp_reg(0, 1); + if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1))) + builder.csetm(0, 0b1001); + else + builder.csetm(0, 0b1101); + } break; default: throw std::runtime_error("Not implemented"); diff --git a/libs/jit/source/arch/aarch64/instruction_builder.cpp b/libs/jit/source/arch/aarch64/instruction_builder.cpp index f1be01b..9c200a3 100644 --- a/libs/jit/source/arch/aarch64/instruction_builder.cpp +++ b/libs/jit/source/arch/aarch64/instruction_builder.cpp @@ -204,6 +204,11 @@ namespace pslang::jit::aarch64 do_push(0x1e201800u | (reg_dst & REG_MASK) | ((reg_src1 & REG_MASK) << 5) | ((reg_src2 & REG_MASK) << 16) | (((mode & 0x3u) ^ 0x2u) << 22)); } + void instruction_builder::fcmp(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t mode) + { + do_push(0x1e202000u | ((reg_src1 & REG_MASK) << 5) | ((reg_src2 & REG_MASK) << 16) | (((mode & 0x3u) ^ 0x2u) << 22)); + } + void instruction_builder::ret(std::uint8_t reg) { do_push(0xd65f0000u | ((reg & REG_MASK) << 5));