Aarch64 jit compiler wip: support f16 literals
This commit is contained in:
parent
7ddc8ba25d
commit
c253041068
4 changed files with 39 additions and 9 deletions
|
|
@ -1,2 +1,2 @@
|
|||
func test() -> f32:
|
||||
return - 3.1415
|
||||
func test() -> f16:
|
||||
return - 3.1415h
|
||||
|
|
|
|||
|
|
@ -124,14 +124,19 @@ namespace pslang::jit::aarch64
|
|||
|
||||
// Load a floating-point value from the address stored in @reg_addr plus a
|
||||
// 12-bit unsigned @offset multiplied by 4, and store it in floating-point
|
||||
// register @reg_dst. @mode should be 0 for 32-bit values and 1 for 64-bit
|
||||
// register @reg_dst. @mode should be 1 for 16-bit values, 2 for 32-bit, 3 for 64-bit
|
||||
void ldr_fp(std::uint8_t reg_dst, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset);
|
||||
|
||||
// Store a floating-point value from @reg_src into the address stored in
|
||||
// @reg_addr plus a 12-bit unsigned @offset multiplied by 4. @mode should
|
||||
// be 0 for 32-bit values and 1 for 64-bit
|
||||
// @reg_addr plus a 12-bit unsigned @offset multiplied by 4.
|
||||
// @mode should be 1 for 16-bit values, 2 for 32-bit, 3 for 64-bit
|
||||
void str_fp(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset);
|
||||
|
||||
// Convert a floating-point value from @reg_src stored in format @mode_src
|
||||
// into a floating-point value in @reg_dst in format @mode_dst
|
||||
// Modes are the same as for ldr_fp/str_fp
|
||||
void fcvt(std::uint8_t reg_src, std::uint8_t mode_src, std::uint8_t reg_dst, std::uint8_t mode_dst);
|
||||
|
||||
// Negate the floating-point register @reg_src and store the result in @reg_dst
|
||||
void fneg(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_dst);
|
||||
|
||||
|
|
|
|||
|
|
@ -21,15 +21,18 @@ namespace pslang::jit::aarch64
|
|||
std::vector<std::uint8_t> code;
|
||||
std::unordered_map<std::string, std::size_t> code_symbol_table;
|
||||
|
||||
std::unordered_map<float, std::int32_t> f16_constants;
|
||||
std::unordered_map<float, std::int32_t> f32_constants;
|
||||
std::unordered_map<double, std::int32_t> f64_constants;
|
||||
};
|
||||
|
||||
std::uint8_t fp_mode_for(types::type const & type)
|
||||
{
|
||||
if (types::equal(type, types::primitive_type(types::f16_type{})))
|
||||
return 1;
|
||||
if (types::equal(type, types::primitive_type(types::f32_type{})))
|
||||
return 0;
|
||||
return 1;
|
||||
return 2;
|
||||
return 3;
|
||||
}
|
||||
|
||||
struct populate_constants_visitor
|
||||
|
|
@ -46,6 +49,15 @@ namespace pslang::jit::aarch64
|
|||
void apply(ast::primitive_literal_base<T> const &)
|
||||
{}
|
||||
|
||||
void apply(ast::f16_literal const & node)
|
||||
{
|
||||
if (!context.f16_constants.contains(node.value.repr))
|
||||
{
|
||||
context.f16_constants[node.value.repr] = context.code.size();
|
||||
push_bytes(node.value.repr);
|
||||
}
|
||||
}
|
||||
|
||||
void apply(ast::f32_literal const & node)
|
||||
{
|
||||
if (!context.f32_constants.contains(node.value))
|
||||
|
|
@ -290,6 +302,14 @@ namespace pslang::jit::aarch64
|
|||
}
|
||||
}
|
||||
|
||||
void apply(ast::f16_literal const & node)
|
||||
{
|
||||
auto offset = context.f16_constants.at(node.value.repr);
|
||||
std::int32_t current = context.code.size();
|
||||
builder.ldr_fp_pc(0, 0, (offset - current) / 4);
|
||||
builder.fcvt(0, 0b10, 0, 0b01);
|
||||
}
|
||||
|
||||
void apply(ast::f32_literal const & node)
|
||||
{
|
||||
auto offset = context.f32_constants.at(node.value);
|
||||
|
|
|
|||
|
|
@ -166,12 +166,17 @@ namespace pslang::jit::aarch64
|
|||
|
||||
void instruction_builder::ldr_fp(std::uint8_t reg_dst, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset)
|
||||
{
|
||||
do_push(0xbd400000u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((offset & 0xfffu) << 10) | ((mode & 0x1u) << 30));
|
||||
do_push(0xbd400000u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((offset & 0xfffu) << 10) | ((mode & 0x3u) << 30));
|
||||
}
|
||||
|
||||
void instruction_builder::str_fp(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset)
|
||||
{
|
||||
do_push(0xbd000000u | (reg_src & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((offset & 0xfffu) << 10) | ((mode & 0x1u) << 30));
|
||||
do_push(0xbd000000u | (reg_src & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((offset & 0xfffu) << 10) | ((mode & 0x3u) << 30));
|
||||
}
|
||||
|
||||
void instruction_builder::fcvt(std::uint8_t reg_src, std::uint8_t mode_src, std::uint8_t reg_dst, std::uint8_t mode_dst)
|
||||
{
|
||||
do_push(0x1e224000u | (reg_dst & REG_MASK) | ((reg_src & REG_MASK) << 5) | (((mode_dst & 0x3u) ^ 0x2u) << 15) | (((mode_src & 0x3u) ^ 0x2u) << 22));
|
||||
}
|
||||
|
||||
void instruction_builder::fneg(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_dst)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue