Fix floating-point function arguments in Aarch64 compiler v2

This commit is contained in:
Nikita Lisitsa 2026-03-27 16:17:37 +03:00
parent 0dc740d656
commit 62fc4c88de

View file

@ -4,6 +4,7 @@
#include <pslang/ir/compiler.hpp>
#include <pslang/ast/type.hpp>
#include <pslang/ast/struct.hpp>
#include <pslang/ast/function.hpp>
#include <pslang/types/type_visitor.hpp>
#include <sstream>
@ -222,6 +223,7 @@ namespace pslang::jit::aarch64
local_context & lcontext;
instruction_builder & builder;
std::vector<std::int32_t> argument_position;
std::unordered_map<ir::node_ref, std::int32_t> stack_position;
std::int32_t stack_size = 0;
@ -581,9 +583,7 @@ namespace pslang::jit::aarch64
void apply(ir::node_ref it, ir::argument const & node, types::type_ptr const & type)
{
// TODO: compute argument layout before compiling the function?
// TODO: floating-point arguments? struct/array arguments?
store(it, node.index);
// Nothing to do: arguments already pushed on stack in function preamble
}
void apply(ir::node_ref it, ir::instruction_address const & node, types::type_ptr const &)
@ -729,16 +729,27 @@ namespace pslang::jit::aarch64
builder.ret();
}
void compile(ir::node_ref begin, ir::node_ref end)
void compile(ast::function_definition const * function_definition, ir::node_ref begin, ir::node_ref end)
{
stack_size = 0;
if (lcontext.use_frame_pointer)
stack_size += 16;
for (auto const & argument : function_definition->arguments)
{
auto size = ast::type_size(*argument.inferred_type);
stack_size += ((size + 7) / 8) * 8;
argument_position.push_back(stack_size);
}
for (auto it = begin; it != end; ++it)
{
if (ir::is_value_instruction(it->instruction))
if (auto argument = std::get_if<ir::argument>(&it->instruction))
{
stack_position[it] = argument_position[argument->index];
}
else if (ir::is_value_instruction(it->instruction))
{
auto size = ast::type_size(*it->inferred_type);
if (size == 0) continue;
@ -766,6 +777,21 @@ namespace pslang::jit::aarch64
builder.str(30, 31, (stack_size - 8) / 8);
builder.add_imm(31, 29, stack_size - 16);
}
std::uint8_t reg = 0;
std::uint8_t fp_reg = 0;
for (std::size_t i = 0; i < function_definition->arguments.size(); ++i)
{
auto const & argument = function_definition->arguments[i];
auto size = ast::type_size(*argument.inferred_type);
auto fp_mode = fp_mode_for(*argument.inferred_type);
if (size == 0) continue;
if (types::is_integer_like_type(*argument.inferred_type))
builder.str(reg++, 31, (stack_size - argument_position[i]) / 8);
else if (types::is_floating_point_type(*argument.inferred_type))
builder.str_fp(fp_reg++, fp_mode, 31, (stack_size - argument_position[i]) / fp_size(fp_mode));
else
throw std::runtime_error("Unknown argument type");
}
++it;
for (; it != end; ++it)
@ -869,7 +895,7 @@ namespace pslang::jit::aarch64
{
pcontext.symbols[symbol.first] = pcontext.code.size();
compile_visitor visitor{pcontext, mcontext, lcontext, builder};
visitor.compile(symbol.second.begin, symbol.second.end);
visitor.compile(symbol.first, symbol.second.begin, symbol.second.end);
}
pcontext.entry_point = lcontext.nodes.at(mcontext.entry_point);