diff --git a/libs/jit/source/arch/aarch64/compiler_v2.cpp b/libs/jit/source/arch/aarch64/compiler_v2.cpp index 01b96e1..c241868 100644 --- a/libs/jit/source/arch/aarch64/compiler_v2.cpp +++ b/libs/jit/source/arch/aarch64/compiler_v2.cpp @@ -15,10 +15,20 @@ namespace pslang::jit::aarch64 namespace { + // Homogeneous floating-point aggregate: up to 4 floating-point members + // of the same type (after struct flattening) + struct hfa_data + { + types::type_ptr element_type; + std::size_t count; + }; + struct local_context { bool use_frame_pointer = true; + std::unordered_map> struct_hfa; + std::unordered_map extern_symbols; std::unordered_map nodes; @@ -51,6 +61,61 @@ namespace pslang::jit::aarch64 return 1 << mode; } + std::optional get_hfa_data(local_context & lcontext, ast::struct_definition const * node); + + std::optional compute_hfa_data(local_context & lcontext, ast::struct_definition const * node) + { + types::type_ptr type = nullptr; + std::size_t count = 0; + + for (std::size_t i = 0; i < node->fields.size(); ++i) + { + auto const & field = node->fields[i]; + if (types::is_builtin_type(*field.inferred_type)) + { + if (!types::is_floating_point_type(*field.inferred_type)) + return std::nullopt; + + if (type && !types::equal(*type, *field.inferred_type)) + return std::nullopt; + + type = field.inferred_type; + ++count; + } + else if (auto struct_type = std::get_if(field.inferred_type.get())) + { + // NB: recursion must be impossible due to prior checks in type checker + if (auto subdata = get_hfa_data(lcontext, struct_type->node)) + { + if (type && !types::equal(*type, *subdata->element_type)) + return std::nullopt; + + type = subdata->element_type; + count += subdata->count; + } + else + return std::nullopt; + } + else + return std::nullopt; + } + + if (count == 0) + return std::nullopt; + + return hfa_data{type, count}; + } + + std::optional get_hfa_data(local_context & lcontext, ast::struct_definition const * node) + { + if (auto it = lcontext.struct_hfa.find(node); it != lcontext.struct_hfa.end()) + return it->second; + + auto result = compute_hfa_data(lcontext, node); + lcontext.struct_hfa[node] = result; + return result; + } + struct populate_const_data_visitor { program_context & pcontext; @@ -295,7 +360,7 @@ namespace pslang::jit::aarch64 else if (types::is_floating_point_type(*type)) { auto mode = fp_mode_for(*type); - load_fp(it, 0, mode); + load_fp(node.arg1, 0, mode); builder.fneg(0, mode, 0); store_fp(it, 0, mode); } @@ -319,6 +384,7 @@ namespace pslang::jit::aarch64 { auto arg1_type = node.arg1->inferred_type; bool const is_fp = types::is_floating_point_type(*arg1_type); + bool const result_is_fp = types::is_floating_point_type(*type); std::uint8_t const fp_mode = fp_mode_for(*arg1_type); if (is_fp) @@ -504,7 +570,7 @@ namespace pslang::jit::aarch64 } } - if (is_fp) + if (result_is_fp) store_fp(it, 0, fp_mode); else store(it, 0); @@ -636,14 +702,48 @@ namespace pslang::jit::aarch64 builder.cbnz(0, 0); } - void apply(ir::node_ref it, ir::call const & node, types::type_ptr const & type) + template + void apply_call(ir::node_ref it, Node const & node, types::type_ptr const & type, DoCall && do_call) { - // TODO: struct/array arguments? + // TODO: array arguments? + // TODO: handle the case when there weren't enough registers std::uint8_t reg = 0; std::uint8_t fp_reg = 0; for (auto const & argument : node.arguments) { - if (types::is_integer_like_type(*argument->inferred_type)) + if (auto struct_type = std::get_if(argument->inferred_type.get())) + { + auto node = struct_type->node; + if (auto hfa = get_hfa_data(lcontext, node); hfa && hfa->count <= 4) + { + // HFA - passed in consecutive FP registers + std::int32_t base_offset = stack_size - stack_position.at(argument); + auto fp_mode = fp_mode_for(*hfa->element_type); + auto size = fp_size(fp_mode); + for (std::size_t i = 0; i < hfa->count; ++i) + builder.ldr_fp(fp_reg++, fp_mode, 31, (base_offset + i * size) / size); + } + else if (node->layout.size <= 16) + { + // Small struct - passed in up to 2 GP registers + std::int32_t base_offset = stack_size - stack_position.at(argument); + std::int32_t size = node->layout.size; + std::int32_t offset = 0; + while (size > 0) + { + builder.ldr(reg++, 31, (base_offset + offset) / 8); + size -= 8; + offset += 8; + } + } + else + { + // Large struct - passed by pointer + std::int32_t base_offset = stack_size - stack_position.at(argument); + builder.add_imm(31, reg++, base_offset); + } + } + else if (types::is_integer_like_type(*argument->inferred_type)) load(argument, reg++); else if (types::is_floating_point_type(*argument->inferred_type)) load_fp(argument, fp_reg++, fp_mode_for(*argument->inferred_type)); @@ -655,8 +755,7 @@ namespace pslang::jit::aarch64 builder.sub_imm(31, 31, 16); builder.str(30, 31, 0); } - lcontext.branch_resolve.emplace_back(pcontext.code.size(), node.target); - builder.bl(0); + do_call(); if (!lcontext.use_frame_pointer) { builder.ldr(30, 31, 0); @@ -674,36 +773,20 @@ namespace pslang::jit::aarch64 throw std::runtime_error("Unsupported return value type"); } + void apply(ir::node_ref it, ir::call const & node, types::type_ptr const & type) + { + apply_call(it, node, type, [&]{ + lcontext.branch_resolve.emplace_back(pcontext.code.size(), node.target); + builder.bl(0); + }); + } + void apply(ir::node_ref it, ir::call_pointer const & node, types::type_ptr const & type) { - // TODO: struct/array arguments? - std::uint8_t reg = 0; - std::uint8_t fp_reg = 0; - for (auto const & argument : node.arguments) - { - if (types::is_integer_like_type(*argument->inferred_type)) - load(argument, reg++); - else if (types::is_floating_point_type(*argument->inferred_type)) - load_fp(argument, fp_reg++, fp_mode_for(*argument->inferred_type)); - else - throw std::runtime_error("Unsupported function argument type"); - } - load(node.pointer, reg); - builder.sub_imm(31, 31, 16); - builder.str(30, 31, 0); - builder.bl_reg(reg); - builder.ldr(30, 31, 0); - builder.add_imm(31, 31, 16); - - // TODO: struct/array return value? - if (types::is_unit_type(*type)) - {} - else if (types::is_integer_like_type(*type)) - store(it, 0); - else if (types::is_floating_point_type(*type)) - store_fp(it, 0, fp_mode_for(*type)); - else - throw std::runtime_error("Unsupported return value type"); + apply_call(it, node, type, [&]{ + load(node.pointer, 9); + builder.bl_reg(9); + }); } void apply(ir::node_ref, ir::return_value const & node, types::type_ptr const &) @@ -777,18 +860,54 @@ namespace pslang::jit::aarch64 builder.str(30, 31, (stack_size - 8) / 8); builder.add_imm(31, 29, stack_size - 16); } + + // TODO: handle the case when there weren't enough registers std::uint8_t reg = 0; std::uint8_t fp_reg = 0; for (std::size_t i = 0; i < function_definition->arguments.size(); ++i) { auto const & argument = function_definition->arguments[i]; auto size = ast::type_size(*argument.inferred_type); - auto fp_mode = fp_mode_for(*argument.inferred_type); if (size == 0) continue; - if (types::is_integer_like_type(*argument.inferred_type)) + if (auto struct_type = std::get_if(argument.inferred_type.get())) + { + auto node = struct_type->node; + if (auto hfa = get_hfa_data(lcontext, node); hfa && hfa->count <= 4) + { + // HFA - passed in consecutive FP registers + std::int32_t base_offset = stack_size - argument_position[i]; + auto fp_mode = fp_mode_for(*hfa->element_type); + auto size = fp_size(fp_mode); + for (std::size_t i = 0; i < hfa->count; ++i) + builder.str_fp(fp_reg++, fp_mode, 31, (base_offset + i * size) / size); + } + else if (node->layout.size <= 16) + { + // Small struct - passed in up to 2 GP registers + std::int32_t base_offset = stack_size - argument_position[i]; + std::int32_t size = node->layout.size; + std::int32_t offset = 0; + while (size > 0) + { + builder.str(reg++, 31, (base_offset + offset) / 8); + offset += 8; + size -= 8; + } + } + else + { + // Large struct - passed by pointer + std::int32_t dst_offset = stack_size - argument_position[i]; + copy_memory(reg++, 0, 31, dst_offset, size, 9); + } + } + else if (types::is_integer_like_type(*argument.inferred_type)) builder.str(reg++, 31, (stack_size - argument_position[i]) / 8); else if (types::is_floating_point_type(*argument.inferred_type)) + { + auto fp_mode = fp_mode_for(*argument.inferred_type); builder.str_fp(fp_reg++, fp_mode, 31, (stack_size - argument_position[i]) / fp_size(fp_mode)); + } else throw std::runtime_error("Unknown argument type"); } @@ -804,27 +923,27 @@ namespace pslang::jit::aarch64 } private: - void load(ir::node_ref ir, std::uint8_t reg) + void load(ir::node_ref it, std::uint8_t reg) { - std::int32_t offset = stack_size - stack_position.at(ir); + std::int32_t offset = stack_size - stack_position.at(it); builder.ldr(reg, 31, offset / 8); } - void load_fp(ir::node_ref ir, std::uint8_t reg, std::uint8_t mode) + void load_fp(ir::node_ref it, std::uint8_t reg, std::uint8_t mode) { - std::int32_t offset = stack_size - stack_position.at(ir); + std::int32_t offset = stack_size - stack_position.at(it); builder.ldr_fp(reg, mode, 31, offset / fp_size(mode)); } - void store(ir::node_ref ir, std::uint8_t reg) + void store(ir::node_ref it, std::uint8_t reg) { - std::int32_t offset = stack_size - stack_position.at(ir); + std::int32_t offset = stack_size - stack_position.at(it); builder.str(reg, 31, offset / 8); } - void store_fp(ir::node_ref ir, std::uint8_t reg, std::uint8_t mode) + void store_fp(ir::node_ref it, std::uint8_t reg, std::uint8_t mode) { - std::int32_t offset = stack_size - stack_position.at(ir); + std::int32_t offset = stack_size - stack_position.at(it); builder.str_fp(reg, mode, 31, offset / fp_size(mode)); }