diff --git a/examples/ir_test.psl b/examples/ir_test.psl index 8297680..f813b2a 100644 --- a/examples/ir_test.psl +++ b/examples/ir_test.psl @@ -63,9 +63,21 @@ struct ray: func intersect_plane(ray: ray, normal: vec3, value: f32) -> f32: return (value - dot(ray.origin, normal)) / dot(ray.direction, normal) +func test() -> i32[3]: + return [70, 60, 50] + +func print_i32_3(a: i32[3]): + print('[') + print_i32(a[0]) + print(',') + print_i32(a[1]) + print(',') + print_i32(a[2]) + print(']') + mut a = [1, 2, 3] +a = test() a[0] = 10 -let p = a as i32* -print_i32(p[1]) +print_i32_3(test()) print('\n') diff --git a/libs/jit/source/arch/aarch64/compiler_v2.cpp b/libs/jit/source/arch/aarch64/compiler_v2.cpp index 49348ee..41ab352 100644 --- a/libs/jit/source/arch/aarch64/compiler_v2.cpp +++ b/libs/jit/source/arch/aarch64/compiler_v2.cpp @@ -61,7 +61,7 @@ namespace pslang::jit::aarch64 return 1 << mode; } - std::optional get_hfa_data(local_context & lcontext, ast::struct_definition const * node); + std::optional get_hfa_data(local_context & lcontext, types::type_ptr const & type); std::optional compute_hfa_data(local_context & lcontext, ast::struct_definition const * node) { @@ -71,30 +71,15 @@ namespace pslang::jit::aarch64 for (std::size_t i = 0; i < node->fields.size(); ++i) { auto const & field = node->fields[i]; - if (types::is_builtin_type(*field.inferred_type)) + + // NB: recursion must be impossible due to prior checks in type checker + if (auto subdata = get_hfa_data(lcontext, field.inferred_type)) { - if (!types::is_floating_point_type(*field.inferred_type)) + if (type && !types::equal(*type, *subdata->element_type)) return std::nullopt; - if (type && !types::equal(*type, *field.inferred_type)) - return std::nullopt; - - type = field.inferred_type; - ++count; - } - else if (auto struct_type = std::get_if(field.inferred_type.get())) - { - // NB: recursion must be impossible due to prior checks in type checker - if (auto subdata = get_hfa_data(lcontext, struct_type->node)) - { - if (type && !types::equal(*type, *subdata->element_type)) - return std::nullopt; - - type = subdata->element_type; - count += subdata->count; - } - else - return std::nullopt; + type = subdata->element_type; + count += subdata->count; } else return std::nullopt; @@ -106,14 +91,28 @@ namespace pslang::jit::aarch64 return hfa_data{type, count}; } - std::optional get_hfa_data(local_context & lcontext, ast::struct_definition const * node) + std::optional get_hfa_data(local_context & lcontext, types::type_ptr const & type) { - if (auto it = lcontext.struct_hfa.find(node); it != lcontext.struct_hfa.end()) - return it->second; + if (auto struct_type = std::get_if(type.get())) + { + if (auto it = lcontext.struct_hfa.find(struct_type->node); it != lcontext.struct_hfa.end()) + return it->second; - auto result = compute_hfa_data(lcontext, node); - lcontext.struct_hfa[node] = result; - return result; + auto result = compute_hfa_data(lcontext, struct_type->node); + lcontext.struct_hfa[struct_type->node] = result; + return result; + } + else if (auto array_type = std::get_if(type.get())) + { + if (auto subdata = get_hfa_data(lcontext, array_type->element_type)) + return hfa_data{subdata->element_type, subdata->count * array_type->size}; + else + return std::nullopt; + } + else if (types::is_floating_point_type(*type)) + return hfa_data{type, 1}; + else + return std::nullopt; } struct populate_const_data_visitor @@ -737,16 +736,20 @@ namespace pslang::jit::aarch64 template void apply_call(ir::node_ref it, Node const & node, types::type_ptr const & type, DoCall && do_call) { - // TODO: array arguments? // TODO: handle the case when there weren't enough registers std::uint8_t reg = 0; std::uint8_t fp_reg = 0; for (auto const & argument : node.arguments) { - if (auto struct_type = std::get_if(argument->inferred_type.get())) + auto struct_type = std::get_if(argument->inferred_type.get()); + auto array_type = std::get_if(argument->inferred_type.get()); + if (struct_type || array_type) { - auto node = struct_type->node; - if (auto hfa = get_hfa_data(lcontext, node); hfa && hfa->count <= 4) + // NB: fixed-size arrays are handled in the same way + // as structs of N identical fields + auto size = ast::type_size(*argument->inferred_type); + + if (auto hfa = get_hfa_data(lcontext, argument->inferred_type); hfa && hfa->count <= 4) { // HFA - passed in consecutive FP registers std::int32_t base_offset = stack_size - stack_position.at(argument); @@ -755,16 +758,14 @@ namespace pslang::jit::aarch64 for (std::size_t i = 0; i < hfa->count; ++i) builder.ldr_fp(fp_reg++, fp_mode, 31, (base_offset + i * size) / size); } - else if (node->layout.size <= 16) + else if (size <= 16) { // Small struct - passed in up to 2 GP registers std::int32_t base_offset = stack_size - stack_position.at(argument); - std::int32_t size = node->layout.size; std::int32_t offset = 0; - while (size > 0) + while (offset < size) { builder.ldr(reg++, 31, (base_offset + offset) / 8); - size -= 8; offset += 8; } } @@ -806,12 +807,16 @@ namespace pslang::jit::aarch64 // TODO: array return value? auto size = ast::type_size(*type); + auto struct_type = std::get_if(type.get()); + auto array_type = std::get_if(type.get()); if (size == 0) {} - else if (auto struct_type = std::get_if(type.get())) + else if (struct_type || array_type) { + // NB: fixed-size arrays are handled in the same way + // as structs of N identical fields auto base_offset = stack_size - stack_position.at(it); - if (auto hfa = get_hfa_data(lcontext, struct_type->node); hfa && hfa->count <= 4) + if (auto hfa = get_hfa_data(lcontext, type); hfa && hfa->count <= 4) { auto fp_mode = fp_mode_for(*hfa->element_type); auto size = fp_size(fp_mode); @@ -863,13 +868,14 @@ namespace pslang::jit::aarch64 { auto type = (*node.value)->inferred_type; auto size = ast::type_size(*type); + auto struct_type = std::get_if(type.get()); + auto array_type = std::get_if(type.get()); if (size == 0) {} - else if (auto struct_type = std::get_if(type.get())) + else if (struct_type || array_type) { auto base_offset = stack_size - stack_position.at(*node.value); - auto struct_node = struct_type->node; - if (auto hfa = get_hfa_data(lcontext, struct_node); hfa && hfa->count <= 4) + if (auto hfa = get_hfa_data(lcontext, type); hfa && hfa->count <= 4) { auto fp_mode = fp_mode_for(*hfa->element_type); auto size = fp_size(fp_mode); @@ -909,8 +915,11 @@ namespace pslang::jit::aarch64 void compile(ast::function_definition const * function_definition, ir::node_ref begin, ir::node_ref end) { - if (auto struct_type = std::get_if(function_definition->inferred_result_type.get())) - if (!get_hfa_data(lcontext, struct_type->node) && struct_type->node->layout.size > 16) + auto result_type = function_definition->inferred_result_type; + auto struct_type = std::get_if(result_type.get()); + auto array_type = std::get_if(result_type.get()); + if (struct_type || array_type) + if (!get_hfa_data(lcontext, result_type) && ast::type_size(*result_type) > 16) return_value_is_large_struct = true; stack_size = 0; @@ -967,11 +976,12 @@ namespace pslang::jit::aarch64 { auto const & argument = function_definition->arguments[i]; auto size = ast::type_size(*argument.inferred_type); + auto struct_type = std::get_if(argument.inferred_type.get()); + auto array_type = std::get_if(argument.inferred_type.get()); if (size == 0) continue; - if (auto struct_type = std::get_if(argument.inferred_type.get())) + if (struct_type || array_type) { - auto node = struct_type->node; - if (auto hfa = get_hfa_data(lcontext, node); hfa && hfa->count <= 4) + if (auto hfa = get_hfa_data(lcontext, argument.inferred_type); hfa && hfa->count <= 4) { // HFA - passed in consecutive FP registers std::int32_t base_offset = stack_size - argument_position[i]; @@ -980,17 +990,15 @@ namespace pslang::jit::aarch64 for (std::size_t i = 0; i < hfa->count; ++i) builder.str_fp(fp_reg++, fp_mode, 31, (base_offset + i * size) / size); } - else if (node->layout.size <= 16) + else if (size <= 16) { // Small struct - passed in up to 2 GP registers std::int32_t base_offset = stack_size - argument_position[i]; - std::int32_t size = node->layout.size; std::int32_t offset = 0; - while (size > 0) + while (offset < size) { builder.str(reg++, 31, (base_offset + offset) / 8); offset += 8; - size -= 8; } } else