Aarch64 jit compiler wip: support f16 literals
This commit is contained in:
parent
7ddc8ba25d
commit
c253041068
4 changed files with 39 additions and 9 deletions
|
|
@ -1,2 +1,2 @@
|
||||||
func test() -> f32:
|
func test() -> f16:
|
||||||
return - 3.1415
|
return - 3.1415h
|
||||||
|
|
|
||||||
|
|
@ -124,14 +124,19 @@ namespace pslang::jit::aarch64
|
||||||
|
|
||||||
// Load a floating-point value from the address stored in @reg_addr plus a
|
// Load a floating-point value from the address stored in @reg_addr plus a
|
||||||
// 12-bit unsigned @offset multiplied by 4, and store it in floating-point
|
// 12-bit unsigned @offset multiplied by 4, and store it in floating-point
|
||||||
// register @reg_dst. @mode should be 0 for 32-bit values and 1 for 64-bit
|
// register @reg_dst. @mode should be 1 for 16-bit values, 2 for 32-bit, 3 for 64-bit
|
||||||
void ldr_fp(std::uint8_t reg_dst, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset);
|
void ldr_fp(std::uint8_t reg_dst, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset);
|
||||||
|
|
||||||
// Store a floating-point value from @reg_src into the address stored in
|
// Store a floating-point value from @reg_src into the address stored in
|
||||||
// @reg_addr plus a 12-bit unsigned @offset multiplied by 4. @mode should
|
// @reg_addr plus a 12-bit unsigned @offset multiplied by 4.
|
||||||
// be 0 for 32-bit values and 1 for 64-bit
|
// @mode should be 1 for 16-bit values, 2 for 32-bit, 3 for 64-bit
|
||||||
void str_fp(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset);
|
void str_fp(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset);
|
||||||
|
|
||||||
|
// Convert a floating-point value from @reg_src stored in format @mode_src
|
||||||
|
// into a floating-point value in @reg_dst in format @mode_dst
|
||||||
|
// Modes are the same as for ldr_fp/str_fp
|
||||||
|
void fcvt(std::uint8_t reg_src, std::uint8_t mode_src, std::uint8_t reg_dst, std::uint8_t mode_dst);
|
||||||
|
|
||||||
// Negate the floating-point register @reg_src and store the result in @reg_dst
|
// Negate the floating-point register @reg_src and store the result in @reg_dst
|
||||||
void fneg(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_dst);
|
void fneg(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_dst);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -21,15 +21,18 @@ namespace pslang::jit::aarch64
|
||||||
std::vector<std::uint8_t> code;
|
std::vector<std::uint8_t> code;
|
||||||
std::unordered_map<std::string, std::size_t> code_symbol_table;
|
std::unordered_map<std::string, std::size_t> code_symbol_table;
|
||||||
|
|
||||||
|
std::unordered_map<float, std::int32_t> f16_constants;
|
||||||
std::unordered_map<float, std::int32_t> f32_constants;
|
std::unordered_map<float, std::int32_t> f32_constants;
|
||||||
std::unordered_map<double, std::int32_t> f64_constants;
|
std::unordered_map<double, std::int32_t> f64_constants;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::uint8_t fp_mode_for(types::type const & type)
|
std::uint8_t fp_mode_for(types::type const & type)
|
||||||
{
|
{
|
||||||
|
if (types::equal(type, types::primitive_type(types::f16_type{})))
|
||||||
|
return 1;
|
||||||
if (types::equal(type, types::primitive_type(types::f32_type{})))
|
if (types::equal(type, types::primitive_type(types::f32_type{})))
|
||||||
return 0;
|
return 2;
|
||||||
return 1;
|
return 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct populate_constants_visitor
|
struct populate_constants_visitor
|
||||||
|
|
@ -46,6 +49,15 @@ namespace pslang::jit::aarch64
|
||||||
void apply(ast::primitive_literal_base<T> const &)
|
void apply(ast::primitive_literal_base<T> const &)
|
||||||
{}
|
{}
|
||||||
|
|
||||||
|
void apply(ast::f16_literal const & node)
|
||||||
|
{
|
||||||
|
if (!context.f16_constants.contains(node.value.repr))
|
||||||
|
{
|
||||||
|
context.f16_constants[node.value.repr] = context.code.size();
|
||||||
|
push_bytes(node.value.repr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void apply(ast::f32_literal const & node)
|
void apply(ast::f32_literal const & node)
|
||||||
{
|
{
|
||||||
if (!context.f32_constants.contains(node.value))
|
if (!context.f32_constants.contains(node.value))
|
||||||
|
|
@ -290,6 +302,14 @@ namespace pslang::jit::aarch64
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void apply(ast::f16_literal const & node)
|
||||||
|
{
|
||||||
|
auto offset = context.f16_constants.at(node.value.repr);
|
||||||
|
std::int32_t current = context.code.size();
|
||||||
|
builder.ldr_fp_pc(0, 0, (offset - current) / 4);
|
||||||
|
builder.fcvt(0, 0b10, 0, 0b01);
|
||||||
|
}
|
||||||
|
|
||||||
void apply(ast::f32_literal const & node)
|
void apply(ast::f32_literal const & node)
|
||||||
{
|
{
|
||||||
auto offset = context.f32_constants.at(node.value);
|
auto offset = context.f32_constants.at(node.value);
|
||||||
|
|
|
||||||
|
|
@ -166,12 +166,17 @@ namespace pslang::jit::aarch64
|
||||||
|
|
||||||
void instruction_builder::ldr_fp(std::uint8_t reg_dst, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset)
|
void instruction_builder::ldr_fp(std::uint8_t reg_dst, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset)
|
||||||
{
|
{
|
||||||
do_push(0xbd400000u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((offset & 0xfffu) << 10) | ((mode & 0x1u) << 30));
|
do_push(0xbd400000u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((offset & 0xfffu) << 10) | ((mode & 0x3u) << 30));
|
||||||
}
|
}
|
||||||
|
|
||||||
void instruction_builder::str_fp(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset)
|
void instruction_builder::str_fp(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset)
|
||||||
{
|
{
|
||||||
do_push(0xbd000000u | (reg_src & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((offset & 0xfffu) << 10) | ((mode & 0x1u) << 30));
|
do_push(0xbd000000u | (reg_src & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((offset & 0xfffu) << 10) | ((mode & 0x3u) << 30));
|
||||||
|
}
|
||||||
|
|
||||||
|
void instruction_builder::fcvt(std::uint8_t reg_src, std::uint8_t mode_src, std::uint8_t reg_dst, std::uint8_t mode_dst)
|
||||||
|
{
|
||||||
|
do_push(0x1e224000u | (reg_dst & REG_MASK) | ((reg_src & REG_MASK) << 5) | (((mode_dst & 0x3u) ^ 0x2u) << 15) | (((mode_src & 0x3u) ^ 0x2u) << 22));
|
||||||
}
|
}
|
||||||
|
|
||||||
void instruction_builder::fneg(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_dst)
|
void instruction_builder::fneg(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_dst)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue