diff --git a/libs/jit/source/arch/aarch64/compiler_v2.cpp b/libs/jit/source/arch/aarch64/compiler_v2.cpp index 661e457..01b96e1 100644 --- a/libs/jit/source/arch/aarch64/compiler_v2.cpp +++ b/libs/jit/source/arch/aarch64/compiler_v2.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -222,6 +223,7 @@ namespace pslang::jit::aarch64 local_context & lcontext; instruction_builder & builder; + std::vector argument_position; std::unordered_map stack_position; std::int32_t stack_size = 0; @@ -581,9 +583,7 @@ namespace pslang::jit::aarch64 void apply(ir::node_ref it, ir::argument const & node, types::type_ptr const & type) { - // TODO: compute argument layout before compiling the function? - // TODO: floating-point arguments? struct/array arguments? - store(it, node.index); + // Nothing to do: arguments already pushed on stack in function preamble } void apply(ir::node_ref it, ir::instruction_address const & node, types::type_ptr const &) @@ -729,16 +729,27 @@ namespace pslang::jit::aarch64 builder.ret(); } - void compile(ir::node_ref begin, ir::node_ref end) + void compile(ast::function_definition const * function_definition, ir::node_ref begin, ir::node_ref end) { stack_size = 0; if (lcontext.use_frame_pointer) stack_size += 16; + for (auto const & argument : function_definition->arguments) + { + auto size = ast::type_size(*argument.inferred_type); + stack_size += ((size + 7) / 8) * 8; + argument_position.push_back(stack_size); + } + for (auto it = begin; it != end; ++it) { - if (ir::is_value_instruction(it->instruction)) + if (auto argument = std::get_if(&it->instruction)) + { + stack_position[it] = argument_position[argument->index]; + } + else if (ir::is_value_instruction(it->instruction)) { auto size = ast::type_size(*it->inferred_type); if (size == 0) continue; @@ -766,6 +777,21 @@ namespace pslang::jit::aarch64 builder.str(30, 31, (stack_size - 8) / 8); builder.add_imm(31, 29, stack_size - 16); } + std::uint8_t reg = 0; + std::uint8_t fp_reg = 0; + for (std::size_t i = 0; i < function_definition->arguments.size(); ++i) + { + auto const & argument = function_definition->arguments[i]; + auto size = ast::type_size(*argument.inferred_type); + auto fp_mode = fp_mode_for(*argument.inferred_type); + if (size == 0) continue; + if (types::is_integer_like_type(*argument.inferred_type)) + builder.str(reg++, 31, (stack_size - argument_position[i]) / 8); + else if (types::is_floating_point_type(*argument.inferred_type)) + builder.str_fp(fp_reg++, fp_mode, 31, (stack_size - argument_position[i]) / fp_size(fp_mode)); + else + throw std::runtime_error("Unknown argument type"); + } ++it; for (; it != end; ++it) @@ -869,7 +895,7 @@ namespace pslang::jit::aarch64 { pcontext.symbols[symbol.first] = pcontext.code.size(); compile_visitor visitor{pcontext, mcontext, lcontext, builder}; - visitor.compile(symbol.second.begin, symbol.second.end); + visitor.compile(symbol.first, symbol.second.begin, symbol.second.end); } pcontext.entry_point = lcontext.nodes.at(mcontext.entry_point);