Aarch64 jit compiler wip: floating-point comparisons

This commit is contained in:
Nikita Lisitsa 2026-01-05 10:23:33 +03:00
parent 45004ce812
commit 428283dbf8
4 changed files with 82 additions and 22 deletions

View file

@ -1,3 +1,4 @@
func test() -> f32: func test() -> f32:
let pi = 3.14159265358979323846 if 0.0 >= 1.0:
return pi / 2.0 return 3.1415
return 2.71828

View file

@ -145,6 +145,8 @@ namespace pslang::jit::aarch64
void fmul(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t mode, std::uint8_t reg_dst); void fmul(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t mode, std::uint8_t reg_dst);
void fdiv(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t mode, std::uint8_t reg_dst); void fdiv(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t mode, std::uint8_t reg_dst);
void fcmp(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t mode);
// Return from a subroutine, taking the return address from register @reg // Return from a subroutine, taking the return address from register @reg
void ret(std::uint8_t reg = 30); void ret(std::uint8_t reg = 30);

View file

@ -147,7 +147,10 @@ namespace pslang::jit::aarch64
void apply(ast::if_chain const & node) void apply(ast::if_chain const & node)
{ {
for (auto const & block : node.blocks) for (auto const & block : node.blocks)
{
apply(*block.condition);
apply(*block.statements); apply(*block.statements);
}
} }
void apply(ast::while_block const & node) void apply(ast::while_block const & node)
@ -361,8 +364,9 @@ namespace pslang::jit::aarch64
void apply(ast::binary_operation const & node) void apply(ast::binary_operation const & node)
{ {
bool const is_fp = types::is_floating_point_type(*node.inferred_type); auto arg1_type = ast::get_type(*node.arg1);
std::uint8_t const fp_mode = fp_mode_for(*node.inferred_type); bool const is_fp = types::is_floating_point_type(*arg1_type);
std::uint8_t const fp_mode = fp_mode_for(*arg1_type);
if (is_fp) if (is_fp)
{ {
@ -433,40 +437,88 @@ namespace pslang::jit::aarch64
builder.xor_reg(1, 0, 0); builder.xor_reg(1, 0, 0);
break; break;
case ast::binary_operation_type::equals: case ast::binary_operation_type::equals:
builder.cmp_reg(1, 0); if (is_fp)
builder.csetm(0, 0b0000); {
builder.fcmp(1, 0, fp_mode);
builder.csetm(0, 0b0000);
}
else
{
builder.cmp_reg(1, 0);
builder.csetm(0, 0b0000);
}
break; break;
case ast::binary_operation_type::not_equals: case ast::binary_operation_type::not_equals:
builder.cmp_reg(1, 0); if (is_fp)
builder.csetm(0, 0b0001); {
builder.fcmp(1, 0, fp_mode);
builder.csetm(0, 0b0001);
}
else
{
builder.cmp_reg(1, 0);
builder.csetm(0, 0b0001);
}
break; break;
case ast::binary_operation_type::less: case ast::binary_operation_type::less:
builder.cmp_reg(0, 1); if (is_fp)
if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1))) {
builder.csetm(0, 0b1000); builder.fcmp(1, 0, fp_mode);
builder.csetm(0, 0b0100);
}
else else
builder.csetm(0, 0b1100); {
builder.cmp_reg(0, 1);
if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1)))
builder.csetm(0, 0b1000);
else
builder.csetm(0, 0b1100);
}
break; break;
case ast::binary_operation_type::greater: case ast::binary_operation_type::greater:
builder.cmp_reg(1, 0); if (is_fp)
if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1))) {
builder.csetm(0, 0b1000); builder.fcmp(0, 1, fp_mode);
builder.csetm(0, 0b0100);
}
else else
builder.csetm(0, 0b1100); {
builder.cmp_reg(1, 0);
if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1)))
builder.csetm(0, 0b1000);
else
builder.csetm(0, 0b1100);
}
break; break;
case ast::binary_operation_type::less_equals: case ast::binary_operation_type::less_equals:
builder.cmp_reg(1, 0); if (is_fp)
if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1))) {
builder.fcmp(1, 0, fp_mode);
builder.csetm(0, 0b1001); builder.csetm(0, 0b1001);
}
else else
builder.csetm(0, 0b1101); {
builder.cmp_reg(1, 0);
if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1)))
builder.csetm(0, 0b1001);
else
builder.csetm(0, 0b1101);
}
break; break;
case ast::binary_operation_type::greater_equals: case ast::binary_operation_type::greater_equals:
builder.cmp_reg(0, 1); if (is_fp)
if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1))) {
builder.fcmp(0, 1, fp_mode);
builder.csetm(0, 0b1001); builder.csetm(0, 0b1001);
}
else else
builder.csetm(0, 0b1101); {
builder.cmp_reg(0, 1);
if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1)))
builder.csetm(0, 0b1001);
else
builder.csetm(0, 0b1101);
}
break; break;
default: default:
throw std::runtime_error("Not implemented"); throw std::runtime_error("Not implemented");

View file

@ -204,6 +204,11 @@ namespace pslang::jit::aarch64
do_push(0x1e201800u | (reg_dst & REG_MASK) | ((reg_src1 & REG_MASK) << 5) | ((reg_src2 & REG_MASK) << 16) | (((mode & 0x3u) ^ 0x2u) << 22)); do_push(0x1e201800u | (reg_dst & REG_MASK) | ((reg_src1 & REG_MASK) << 5) | ((reg_src2 & REG_MASK) << 16) | (((mode & 0x3u) ^ 0x2u) << 22));
} }
void instruction_builder::fcmp(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t mode)
{
do_push(0x1e202000u | ((reg_src1 & REG_MASK) << 5) | ((reg_src2 & REG_MASK) << 16) | (((mode & 0x3u) ^ 0x2u) << 22));
}
void instruction_builder::ret(std::uint8_t reg) void instruction_builder::ret(std::uint8_t reg)
{ {
do_push(0xd65f0000u | ((reg & REG_MASK) << 5)); do_push(0xd65f0000u | ((reg & REG_MASK) << 5));