Aarch64 compiler wip: replace csel with csetm for comparisons

This commit is contained in:
Nikita Lisitsa 2026-01-04 12:36:24 +03:00
parent 7ed008543c
commit db4a8ac264
4 changed files with 21 additions and 20 deletions

View file

@ -1,2 +1,2 @@
func test(x : i32) -> i32: func test(x : i32) -> bool:
return x * x return x <= 42

View file

@ -78,13 +78,16 @@ namespace pslang::jit::aarch64
// Compare the values of @reg_src1 and @reg_src2 and set the flags // Compare the values of @reg_src1 and @reg_src2 and set the flags
void cmp_reg(std::uint8_t reg_src1, std::uint8_t reg_src2); void cmp_reg(std::uint8_t reg_src1, std::uint8_t reg_src2);
// Set the value of @reg_dst to 1 is condition @cond is true // Set the value of @reg_dst to 1 if condition @cond is true, and to 0 otherwise
void cset(std::uint8_t reg_dst, std::uint8_t cond); void cset(std::uint8_t reg_dst, std::uint8_t cond);
// Set the value of @reg_dst to @reg_src1 if condition @cond is true, otherwise // Set the value of @reg_dst to @reg_src1 if condition @cond is true, otherwise
// to @reg_src2. // to @reg_src2.
void csel(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t reg_dst, std::uint8_t cond); void csel(std::uint8_t reg_src1, std::uint8_t reg_src2, std::uint8_t reg_dst, std::uint8_t cond);
// Set the value of @reg_dst to all 1s if condition @cond is true, and to 0 otherwise
void csetm(std::uint8_t reg_dst, std::uint8_t cond);
// Take @bit_count lowest bits from @reg_src, copy them to @reg_dst, and sign-extend @reg_dst // Take @bit_count lowest bits from @reg_src, copy them to @reg_dst, and sign-extend @reg_dst
void sbfm(std::uint8_t reg_src, std::uint8_t reg_dst, std::uint8_t bit_count); void sbfm(std::uint8_t reg_src, std::uint8_t reg_dst, std::uint8_t bit_count);

View file

@ -240,8 +240,7 @@ namespace pslang::jit::aarch64
apply(*node.arg2); apply(*node.arg2);
pop(1); pop(1);
builder.cmp_reg(1, 0); builder.cmp_reg(1, 0);
set_m1(0); builder.csetm(0, 0b0000);
builder.csel(0, 31, 0, 0b0000);
break; break;
case ast::binary_operation_type::not_equals: case ast::binary_operation_type::not_equals:
apply(*node.arg1); apply(*node.arg1);
@ -249,8 +248,7 @@ namespace pslang::jit::aarch64
apply(*node.arg2); apply(*node.arg2);
pop(1); pop(1);
builder.cmp_reg(1, 0); builder.cmp_reg(1, 0);
set_m1(0); builder.csetm(0, 0b0001);
builder.csel(0, 31, 0, 0b0001);
break; break;
case ast::binary_operation_type::less: case ast::binary_operation_type::less:
apply(*node.arg1); apply(*node.arg1);
@ -258,23 +256,20 @@ namespace pslang::jit::aarch64
apply(*node.arg2); apply(*node.arg2);
pop(1); pop(1);
builder.cmp_reg(0, 1); builder.cmp_reg(0, 1);
set_m1(0);
if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1))) if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1)))
builder.csel(0, 31, 0, 0b1000); builder.csetm(0, 0b1000);
else else
builder.csel(0, 31, 0, 0b1100); builder.csetm(0, 0b1100);
break; break;
case ast::binary_operation_type::greater: case ast::binary_operation_type::greater:
apply(*node.arg1); apply(*node.arg1);
push(0); push(0);
apply(*node.arg2); apply(*node.arg2);
pop(1); pop(1);
builder.cmp_reg(1, 0);
set_m1(0);
if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1))) if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1)))
builder.csel(0, 31, 0, 0b1000); builder.csetm(0, 0b1000);
else else
builder.csel(0, 31, 0, 0b1100); builder.csetm(0, 0b1100);
break; break;
case ast::binary_operation_type::less_equals: case ast::binary_operation_type::less_equals:
apply(*node.arg1); apply(*node.arg1);
@ -282,11 +277,10 @@ namespace pslang::jit::aarch64
apply(*node.arg2); apply(*node.arg2);
pop(1); pop(1);
builder.cmp_reg(1, 0); builder.cmp_reg(1, 0);
set_m1(0);
if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1))) if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1)))
builder.csel(0, 31, 0, 0b1001); builder.csetm(0, 0b1001);
else else
builder.csel(0, 31, 0, 0b1101); builder.csetm(0, 0b1101);
break; break;
case ast::binary_operation_type::greater_equals: case ast::binary_operation_type::greater_equals:
apply(*node.arg1); apply(*node.arg1);
@ -294,11 +288,10 @@ namespace pslang::jit::aarch64
apply(*node.arg2); apply(*node.arg2);
pop(1); pop(1);
builder.cmp_reg(0, 1); builder.cmp_reg(0, 1);
set_m1(0);
if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1))) if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1)))
builder.csel(0, 31, 0, 0b1001); builder.csetm(0, 0b1001);
else else
builder.csel(0, 31, 0, 0b1101); builder.csetm(0, 0b1101);
break; break;
default: default:
throw std::runtime_error("Not implemented"); throw std::runtime_error("Not implemented");

View file

@ -110,6 +110,11 @@ namespace pslang::jit::aarch64
do_push(0x9a800000u | (reg_dst & REG_MASK) | ((reg_src1 & REG_MASK) << 5) | ((reg_src2 & REG_MASK) << 16) | ((cond & 0xfu) << 12)); do_push(0x9a800000u | (reg_dst & REG_MASK) | ((reg_src1 & REG_MASK) << 5) | ((reg_src2 & REG_MASK) << 16) | ((cond & 0xfu) << 12));
} }
void instruction_builder::csetm(std::uint8_t reg_dst, std::uint8_t cond)
{
do_push(0xda9f03e0u | (reg_dst & REG_MASK) | (((cond & 0xfu) ^ 1) << 12));
}
void instruction_builder::sbfm(std::uint8_t reg_src, std::uint8_t reg_dst, std::uint8_t bit_count) void instruction_builder::sbfm(std::uint8_t reg_src, std::uint8_t reg_dst, std::uint8_t bit_count)
{ {
do_push(0x93400000u | (reg_dst & REG_MASK) | ((reg_src & REG_MASK) << 5) | (((bit_count - 1) & 0x3fu) << 10)); do_push(0x93400000u | (reg_dst & REG_MASK) | ((reg_src & REG_MASK) << 5) | (((bit_count - 1) & 0x3fu) << 10));