Aarch 64 jit compiler wip: floating-point support wip

This commit is contained in:
Nikita Lisitsa 2026-01-05 00:05:20 +03:00
parent 8b560b660b
commit 0d87d35c47
5 changed files with 238 additions and 14 deletions

View file

@ -14,6 +14,8 @@
#include <cstring>
#include <cmath>
#include <chrono>
std::string extract_nth_line(std::filesystem::path const & path, std::size_t n)
{
std::ifstream file(path);
@ -170,8 +172,8 @@ int main(int argc, char ** argv)
{
// TODO: remove, testing-only code; should execute entry point instead
auto offset = module.code.symbol_table.at("test");
auto fptr = (int(*)(int))(module.code.memory.data.get() + offset);
auto x = fptr(10);
auto fptr = (float(*)())(module.code.memory.data.get() + offset);
auto x = fptr();
std::cout << "Result: " << std::boolalpha << x << std::endl;
}
}

View file

@ -1,6 +1,2 @@
func factorial(x : i32) -> i32:
mut r = 1
while x > 0:
r = r * x
x = x - 1
return r
func test() -> f32:
return - 3.1415

View file

@ -117,6 +117,24 @@ namespace pslang::jit::aarch64
// starting at @opcode
void b_inject(std::uint8_t * opcode, std::int32_t offset);
// Load a floating-point value from current program counter plus a
// 19-bit signed @offset multiplied by 4, and store it in floating-point
// register @reg_dst. @mode should be 0 for 32-bit values and 1 for 64-bit
void ldr_fp_pc(std::uint8_t reg_dst, std::uint8_t mode, std::int32_t offset);
// Load a floating-point value from the address stored in @reg_addr plus a
// 12-bit unsigned @offset multiplied by 4, and store it in floating-point
// register @reg_dst. @mode should be 0 for 32-bit values and 1 for 64-bit
void ldr_fp(std::uint8_t reg_dst, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset);
// Store a floating-point value from @reg_src into the address stored in
// @reg_addr plus a 12-bit unsigned @offset multiplied by 4. @mode should
// be 0 for 32-bit values and 1 for 64-bit
void str_fp(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset);
// Negate the floating-point register @reg_src and store the result in @reg_dst
void fneg(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_dst);
// Return from a subroutine, taking the return address from register @reg
void ret(std::uint8_t reg = 30);

View file

@ -1,4 +1,3 @@
#include "pslang/types/type_fwd.hpp"
#include <pslang/jit/arch/aarch64/compiler.hpp>
#include <pslang/jit/arch/aarch64/instruction_builder.hpp>
#include <pslang/jit/executable.hpp>
@ -21,6 +20,155 @@ namespace pslang::jit::aarch64
{
std::vector<std::uint8_t> code;
std::unordered_map<std::string, std::size_t> code_symbol_table;
std::unordered_map<float, std::int32_t> f32_constants;
std::unordered_map<double, std::int32_t> f64_constants;
};
std::uint8_t fp_mode_for(types::type const & type)
{
if (types::equal(type, types::primitive_type(types::f32_type{})))
return 0;
return 1;
}
struct populate_constants_visitor
: ast::const_expression_visitor<populate_constants_visitor>
, ast::const_statement_visitor<populate_constants_visitor>
{
using const_expression_visitor::apply;
using const_statement_visitor::apply;
context & context;
template <typename T>
requires(!std::is_floating_point_v<T>)
void apply(ast::primitive_literal_base<T> const &)
{}
void apply(ast::f32_literal const & node)
{
if (!context.f32_constants.contains(node.value))
{
context.f32_constants[node.value] = context.code.size();
push_bytes(node.value);
}
}
void apply(ast::f64_literal const & node)
{
if (!context.f64_constants.contains(node.value))
{
context.f64_constants[node.value] = context.code.size();
push_bytes(node.value);
}
}
void apply(ast::identifier const &)
{}
void apply(ast::unary_operation const & node)
{
apply(*node.arg1);
}
void apply(ast::binary_operation const & node)
{
apply(*node.arg1);
apply(*node.arg2);
}
void apply(ast::cast_operation const & node)
{
apply(*node.expression);
}
void apply(ast::function_call const & node)
{
apply(*node.function);
for (auto const & argument : node.arguments)
apply(*argument);
}
void apply(ast::array const & node)
{
for (auto const & element : node.elements)
apply(*element);
}
void apply(ast::array_access const & node)
{
apply(*node.array);
apply(*node.index);
}
void apply(ast::field_access const & node)
{
apply(*node.object);
}
void apply(ast::expression_ptr const & node)
{
apply(*node);
}
void apply(ast::assignment const & node)
{
apply(*node.lhs);
apply(*node.rhs);
}
void apply(ast::variable_declaration const & node)
{
apply(*node.initializer);
}
void apply(ast::if_block const &)
{}
void apply(ast::else_block const &)
{}
void apply(ast::else_if_block const &)
{}
void apply(ast::if_chain const & node)
{
for (auto const & block : node.blocks)
apply(*block.statements);
}
void apply(ast::while_block const & node)
{
apply(*node.condition);
apply(*node.statements);
}
void apply(ast::function_definition const & node)
{
apply(*node.statements);
}
void apply(ast::return_statement const & node)
{
apply(*node.value);
}
void apply(ast::field_definition const &)
{}
void apply(ast::struct_definition const &)
{}
private:
template <typename T>
void push_bytes(T const & value)
{
auto begin = (std::uint8_t const *)(&value);
auto end = begin + sizeof(value);
context.code.insert(context.code.end(), begin, end);
}
};
struct reg_extend_visitor
@ -142,6 +290,20 @@ namespace pslang::jit::aarch64
}
}
void apply(ast::f32_literal const & node)
{
auto offset = context.f32_constants.at(node.value);
std::int32_t current = context.code.size();
builder.ldr_fp_pc(0, 0, (offset - current) / 4);
}
void apply(ast::f64_literal const & node)
{
auto offset = context.f64_constants.at(node.value);
std::int32_t current = context.code.size();
builder.ldr_fp_pc(0, 1, (offset - current) / 4);
}
void apply(ast::identifier const & node)
{
for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
@ -156,13 +318,19 @@ namespace pslang::jit::aarch64
void apply(ast::unary_operation const & node)
{
// TODO: floating-point
switch (node.type)
{
case ast::unary_operation_type::negation:
apply(*node.arg1);
builder.sub_reg(31, 0, 0);
extend(0, node.inferred_type);
if (types::is_integer_type(*node.inferred_type))
{
builder.sub_reg(31, 0, 0);
extend(0, node.inferred_type);
}
else if (types::is_floating_point_type(*node.inferred_type))
{
builder.fneg(0, fp_mode_for(*node.inferred_type), 0);
}
break;
case ast::unary_operation_type::logical_not:
apply(*node.arg1);
@ -443,6 +611,24 @@ namespace pslang::jit::aarch64
scopes.back().stack_offset -= 16;
}
// @mode = 0: 32-bit
// @mode = 1: 64-bit
void push_fp(std::uint8_t reg, std::uint8_t mode)
{
builder.sub_imm(31, 31, 16);
builder.str_fp(0, mode, 31, 0);
stack_offset += 16;
scopes.back().stack_offset += 16;
}
void pop_fp(std::uint8_t reg, std::uint8_t mode)
{
builder.ldr_fp(reg, mode, 31, 0);
builder.add_imm(31, 31, 16);
stack_offset -= 16;
scopes.back().stack_offset -= 16;
}
// Set register @reg to -1 (all bits = 1)
void set_m1(std::uint8_t reg)
{
@ -489,8 +675,10 @@ namespace pslang::jit::aarch64
compiled_module compile(ast::statement_list_ptr const & statements)
{
context context;
compile_visitor visitor{{}, context};
visitor.apply(*statements);
populate_constants_visitor{{}, {}, context}.apply(*statements);
compile_visitor{{}, context}.apply(*statements);
auto code = allocate(context.code.size());
std::copy(context.code.data(), context.code.data() + context.code.size(), code.data.get());

View file

@ -159,6 +159,26 @@ namespace pslang::jit::aarch64
*dst |= (std::uint32_t(offset) & 0x3ffffffu);
}
void instruction_builder::ldr_fp_pc(std::uint8_t reg_dst, std::uint8_t mode, std::int32_t offset)
{
do_push(0x1c000000u | ((mode & 0x1u) << 30) | (reg_dst & REG_MASK) | ((std::uint32_t(offset) & 0x7ffffu) << 5));
}
void instruction_builder::ldr_fp(std::uint8_t reg_dst, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset)
{
do_push(0xbd400000u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((offset & 0xfffu) << 10) | ((mode & 0x1u) << 30));
}
void instruction_builder::str_fp(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset)
{
do_push(0xbd000000u | (reg_src & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((offset & 0xfffu) << 10) | ((mode & 0x1u) << 30));
}
void instruction_builder::fneg(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_dst)
{
do_push(0x1e214000u | (reg_dst & REG_MASK) | ((reg_src & REG_MASK) << 5) | ((mode & 0x1u) << 22));
}
void instruction_builder::ret(std::uint8_t reg)
{
do_push(0xd65f0000u | ((reg & REG_MASK) << 5));