Fix floating-point function arguments in Aarch64 compiler v2

This commit is contained in:
Nikita Lisitsa 2026-03-27 16:17:37 +03:00
parent 0dc740d656
commit 62fc4c88de

View file

@ -4,6 +4,7 @@
#include <pslang/ir/compiler.hpp> #include <pslang/ir/compiler.hpp>
#include <pslang/ast/type.hpp> #include <pslang/ast/type.hpp>
#include <pslang/ast/struct.hpp> #include <pslang/ast/struct.hpp>
#include <pslang/ast/function.hpp>
#include <pslang/types/type_visitor.hpp> #include <pslang/types/type_visitor.hpp>
#include <sstream> #include <sstream>
@ -222,6 +223,7 @@ namespace pslang::jit::aarch64
local_context & lcontext; local_context & lcontext;
instruction_builder & builder; instruction_builder & builder;
std::vector<std::int32_t> argument_position;
std::unordered_map<ir::node_ref, std::int32_t> stack_position; std::unordered_map<ir::node_ref, std::int32_t> stack_position;
std::int32_t stack_size = 0; std::int32_t stack_size = 0;
@ -581,9 +583,7 @@ namespace pslang::jit::aarch64
void apply(ir::node_ref it, ir::argument const & node, types::type_ptr const & type) void apply(ir::node_ref it, ir::argument const & node, types::type_ptr const & type)
{ {
// TODO: compute argument layout before compiling the function? // Nothing to do: arguments already pushed on stack in function preamble
// TODO: floating-point arguments? struct/array arguments?
store(it, node.index);
} }
void apply(ir::node_ref it, ir::instruction_address const & node, types::type_ptr const &) void apply(ir::node_ref it, ir::instruction_address const & node, types::type_ptr const &)
@ -729,16 +729,27 @@ namespace pslang::jit::aarch64
builder.ret(); builder.ret();
} }
void compile(ir::node_ref begin, ir::node_ref end) void compile(ast::function_definition const * function_definition, ir::node_ref begin, ir::node_ref end)
{ {
stack_size = 0; stack_size = 0;
if (lcontext.use_frame_pointer) if (lcontext.use_frame_pointer)
stack_size += 16; stack_size += 16;
for (auto const & argument : function_definition->arguments)
{
auto size = ast::type_size(*argument.inferred_type);
stack_size += ((size + 7) / 8) * 8;
argument_position.push_back(stack_size);
}
for (auto it = begin; it != end; ++it) for (auto it = begin; it != end; ++it)
{ {
if (ir::is_value_instruction(it->instruction)) if (auto argument = std::get_if<ir::argument>(&it->instruction))
{
stack_position[it] = argument_position[argument->index];
}
else if (ir::is_value_instruction(it->instruction))
{ {
auto size = ast::type_size(*it->inferred_type); auto size = ast::type_size(*it->inferred_type);
if (size == 0) continue; if (size == 0) continue;
@ -766,6 +777,21 @@ namespace pslang::jit::aarch64
builder.str(30, 31, (stack_size - 8) / 8); builder.str(30, 31, (stack_size - 8) / 8);
builder.add_imm(31, 29, stack_size - 16); builder.add_imm(31, 29, stack_size - 16);
} }
std::uint8_t reg = 0;
std::uint8_t fp_reg = 0;
for (std::size_t i = 0; i < function_definition->arguments.size(); ++i)
{
auto const & argument = function_definition->arguments[i];
auto size = ast::type_size(*argument.inferred_type);
auto fp_mode = fp_mode_for(*argument.inferred_type);
if (size == 0) continue;
if (types::is_integer_like_type(*argument.inferred_type))
builder.str(reg++, 31, (stack_size - argument_position[i]) / 8);
else if (types::is_floating_point_type(*argument.inferred_type))
builder.str_fp(fp_reg++, fp_mode, 31, (stack_size - argument_position[i]) / fp_size(fp_mode));
else
throw std::runtime_error("Unknown argument type");
}
++it; ++it;
for (; it != end; ++it) for (; it != end; ++it)
@ -869,7 +895,7 @@ namespace pslang::jit::aarch64
{ {
pcontext.symbols[symbol.first] = pcontext.code.size(); pcontext.symbols[symbol.first] = pcontext.code.size();
compile_visitor visitor{pcontext, mcontext, lcontext, builder}; compile_visitor visitor{pcontext, mcontext, lcontext, builder};
visitor.compile(symbol.second.begin, symbol.second.end); visitor.compile(symbol.first, symbol.second.begin, symbol.second.end);
} }
pcontext.entry_point = lcontext.nodes.at(mcontext.entry_point); pcontext.entry_point = lcontext.nodes.at(mcontext.entry_point);