Aarch64 jit compiler wip: support f16 literals

This commit is contained in:
Nikita Lisitsa 2026-01-05 00:49:47 +03:00
parent 7ddc8ba25d
commit c253041068
4 changed files with 39 additions and 9 deletions

View file

@ -1,2 +1,2 @@
func test() -> f32:
return - 3.1415
func test() -> f16:
return - 3.1415h

View file

@ -124,14 +124,19 @@ namespace pslang::jit::aarch64
// Load a floating-point value from the address stored in @reg_addr plus a
// 12-bit unsigned @offset multiplied by 4, and store it in floating-point
// register @reg_dst. @mode should be 0 for 32-bit values and 1 for 64-bit
// register @reg_dst. @mode should be 1 for 16-bit values, 2 for 32-bit, 3 for 64-bit
void ldr_fp(std::uint8_t reg_dst, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset);
// Store a floating-point value from @reg_src into the address stored in
// @reg_addr plus a 12-bit unsigned @offset multiplied by 4. @mode should
// be 0 for 32-bit values and 1 for 64-bit
// @reg_addr plus a 12-bit unsigned @offset multiplied by 4.
// @mode should be 1 for 16-bit values, 2 for 32-bit, 3 for 64-bit
void str_fp(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset);
// Convert a floating-point value from @reg_src stored in format @mode_src
// into a floating-point value in @reg_dst in format @mode_dst
// Modes are the same as for ldr_fp/str_fp
void fcvt(std::uint8_t reg_src, std::uint8_t mode_src, std::uint8_t reg_dst, std::uint8_t mode_dst);
// Negate the floating-point register @reg_src and store the result in @reg_dst
void fneg(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_dst);

View file

@ -21,15 +21,18 @@ namespace pslang::jit::aarch64
std::vector<std::uint8_t> code;
std::unordered_map<std::string, std::size_t> code_symbol_table;
std::unordered_map<float, std::int32_t> f16_constants;
std::unordered_map<float, std::int32_t> f32_constants;
std::unordered_map<double, std::int32_t> f64_constants;
};
std::uint8_t fp_mode_for(types::type const & type)
{
if (types::equal(type, types::primitive_type(types::f16_type{})))
return 1;
if (types::equal(type, types::primitive_type(types::f32_type{})))
return 0;
return 1;
return 2;
return 3;
}
struct populate_constants_visitor
@ -46,6 +49,15 @@ namespace pslang::jit::aarch64
void apply(ast::primitive_literal_base<T> const &)
{}
void apply(ast::f16_literal const & node)
{
if (!context.f16_constants.contains(node.value.repr))
{
context.f16_constants[node.value.repr] = context.code.size();
push_bytes(node.value.repr);
}
}
void apply(ast::f32_literal const & node)
{
if (!context.f32_constants.contains(node.value))
@ -290,6 +302,14 @@ namespace pslang::jit::aarch64
}
}
void apply(ast::f16_literal const & node)
{
auto offset = context.f16_constants.at(node.value.repr);
std::int32_t current = context.code.size();
builder.ldr_fp_pc(0, 0, (offset - current) / 4);
builder.fcvt(0, 0b10, 0, 0b01);
}
void apply(ast::f32_literal const & node)
{
auto offset = context.f32_constants.at(node.value);

View file

@ -166,12 +166,17 @@ namespace pslang::jit::aarch64
void instruction_builder::ldr_fp(std::uint8_t reg_dst, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset)
{
do_push(0xbd400000u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((offset & 0xfffu) << 10) | ((mode & 0x1u) << 30));
do_push(0xbd400000u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((offset & 0xfffu) << 10) | ((mode & 0x3u) << 30));
}
void instruction_builder::str_fp(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset)
{
do_push(0xbd000000u | (reg_src & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((offset & 0xfffu) << 10) | ((mode & 0x1u) << 30));
do_push(0xbd000000u | (reg_src & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((offset & 0xfffu) << 10) | ((mode & 0x3u) << 30));
}
void instruction_builder::fcvt(std::uint8_t reg_src, std::uint8_t mode_src, std::uint8_t reg_dst, std::uint8_t mode_dst)
{
do_push(0x1e224000u | (reg_dst & REG_MASK) | ((reg_src & REG_MASK) << 5) | (((mode_dst & 0x3u) ^ 0x2u) << 15) | (((mode_src & 0x3u) ^ 0x2u) << 22));
}
void instruction_builder::fneg(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_dst)