From 041513f33e0924118e5de2455d046cffecc9e8a0 Mon Sep 17 00:00:00 2001 From: lisyarus Date: Fri, 27 Mar 2026 20:25:16 +0300 Subject: [PATCH] Support struct return values in aarch64 compiler --- examples/ir_test.psl | 90 +++++++++++++++----- libs/jit/source/arch/aarch64/compiler_v2.cpp | 76 ++++++++++++++++- 2 files changed, 142 insertions(+), 24 deletions(-) diff --git a/examples/ir_test.psl b/examples/ir_test.psl index e74e70c..bf2efc7 100644 --- a/examples/ir_test.psl +++ b/examples/ir_test.psl @@ -2,28 +2,78 @@ func print(c: u8): foreign func putchar(c: i32) -> i32 putchar(c as i32) -func print32(n: u32): - if n >= 10u: - print32(n / 10u) - print('0' + ((n % 10u) as u8)) +func print_i32(x: i32): + if x < 0: + print('-') + print_i32(-x) + return + if x >= 10: + print_i32(x / 10) + print('0' + (x % 10 as u8)) -func alloc(size: u64) -> unit mut*: - foreign func malloc(size: u64) -> unit mut* - return malloc(size) +func print_f32(x: f32): + if x < 0.0: + print('-') + print_f32(-x) + return + foreign func floorf(x: f32) -> f32 + let floor = floorf(x) as i32 + print_i32(floor) + print('.') + mut y = x - (floor as f32) + mut i = 0 + while i < 5: + y = y * 10.0 + let yfloor = floorf(y) as i32 + print('0' + (yfloor as u8)) + y = y - (yfloor as f32) + i = i + 1 -foreign func free(ptr: unit*) +struct vec3: + x: f32 + y: f32 + z: f32 -let count = 30 -let array = alloc(4 * count as u64) as u32 mut* -array[0] = 0u -array[1] = 1u -print32(array[0]) +func print_vec3(v: vec3): + print('(') + print_f32(v.x) + print(',') + print_f32(v.y) + print(',') + print_f32(v.z) + print(')') + +func dot(a: vec3, b: vec3) -> f32: + return a.x * b.x + a.y * b.y + a.z * b.z + +func add(a: vec3, b: vec3) -> vec3: + return vec3(a.x + b.x, a.y + b.y, a.z + b.z) + +func mult(a: vec3, b: f32) -> vec3: + return vec3(a.x * b, a.y * b, a.z * b) + +func normalized(v: vec3) -> vec3: + foreign func sqrtf(x: f32) -> f32 + return mult(v, 1.0 / sqrtf(dot(v, v))) + +struct ray: + origin: vec3 + direction: vec3 
+ +func intersect_plane(ray: ray, normal: vec3, value: f32) -> f32: + return (value - dot(ray.origin, normal)) / dot(ray.direction, normal) + +let t = intersect_plane(ray(vec3(0.0, 0.0, 5.0), vec3(0.0, 1.0, -1.0)), vec3(0.0, 0.0, 1.0), 0.0) + +let a = vec3(5.0, 1.0, 4.0) +let n = normalized(vec3(1.0, 2.0, 3.0)) +let b = add(a, mult(n, -dot(a, n))) + +print_vec3(a) print('\n') -print32(array[1]) +print_vec3(n) +print('\n') +print_vec3(b) +print('\n') +print_f32(dot(b, n)) print('\n') -mut i = 2 -while i < count: - array[i] = array[i - 1] + array[i - 2] - print32(array[i]) - print('\n') - i = i + 1 diff --git a/libs/jit/source/arch/aarch64/compiler_v2.cpp b/libs/jit/source/arch/aarch64/compiler_v2.cpp index c241868..f05f319 100644 --- a/libs/jit/source/arch/aarch64/compiler_v2.cpp +++ b/libs/jit/source/arch/aarch64/compiler_v2.cpp @@ -291,6 +291,7 @@ namespace pslang::jit::aarch64 std::vector argument_position; std::unordered_map stack_position; std::int32_t stack_size = 0; + bool return_value_is_large_struct = false; void apply(ir::node_ref, ir::label const &, types::type_ptr const &) {} @@ -750,6 +751,11 @@ namespace pslang::jit::aarch64 else throw std::runtime_error("Unsupported function argument type"); } + if (return_value_is_large_struct) + { + builder.sub_imm(31, 31, 16); + builder.str(8, 31, 0); + } if (!lcontext.use_frame_pointer) { builder.sub_imm(31, 31, 16); @@ -761,10 +767,40 @@ namespace pslang::jit::aarch64 builder.ldr(30, 31, 0); builder.add_imm(31, 31, 16); } + if (return_value_is_large_struct) + { + builder.ldr(8, 31, 0); + builder.add_imm(31, 31, 16); + } - // TODO: struct/array return value? - if (types::is_unit_type(*type)) + // TODO: array return value? 
+ auto size = ast::type_size(*type); + if (size == 0) {} + else if (auto struct_type = std::get_if(type.get())) + { + auto base_offset = stack_size - stack_position.at(it); + if (auto hfa = get_hfa_data(lcontext, struct_type->node); hfa && hfa->count <= 4) + { + auto fp_mode = fp_mode_for(*hfa->element_type); + auto size = fp_size(fp_mode); + // HFA - returned in consecutive FP registers + for (std::size_t i = 0; i < hfa->count; ++i) + builder.str_fp(i, fp_mode, 31, (base_offset + i * size) / size); + } + else if (size <= 16) + { + // Small struct - returned in x0-x1 registers + builder.str(0, 31, base_offset / 8); + if (size > 8) + builder.str(1, 31, (base_offset + 8) / 8); + } + else + { + // Large struct - returned by pointer in x8 register + copy_memory(8, 0, 31, base_offset, size, 0); + } + } else if (types::is_integer_like_type(*type)) store(it, 0); else if (types::is_floating_point_type(*type)) @@ -791,11 +827,39 @@ namespace pslang::jit::aarch64 void apply(ir::node_ref, ir::return_value const & node, types::type_ptr const &) { - // TODO: struct/array return value? + // TODO: array return value? 
if (node.value) { auto type = (*node.value)->inferred_type; - if (types::is_integer_like_type(*type)) + auto size = ast::type_size(*type); + if (size == 0) + {} + else if (auto struct_type = std::get_if(type.get())) + { + auto base_offset = stack_size - stack_position.at(*node.value); + auto struct_node = struct_type->node; + if (auto hfa = get_hfa_data(lcontext, struct_node); hfa && hfa->count <= 4) + { + auto fp_mode = fp_mode_for(*hfa->element_type); + auto size = fp_size(fp_mode); + // HFA - returned in consecutive FP registers + for (std::size_t i = 0; i < hfa->count; ++i) + builder.ldr_fp(i, fp_mode, 31, (base_offset + i * size) / size); + } + else if (size <= 16) + { + // Small struct - returned in x0-x1 registers + builder.ldr(0, 31, base_offset / 8); + if (size > 8) + builder.ldr(1, 31, (base_offset + 8) / 8); + } + else + { + // Large struct - returned by pointer in x8 register + copy_memory(31, base_offset, 8, 0, size, 0); + } + } + else if (types::is_integer_like_type(*type)) load(*node.value, 0); else if (types::is_floating_point_type(*type)) load_fp(*node.value, 0, fp_mode_for(*type)); @@ -814,6 +878,10 @@ namespace pslang::jit::aarch64 void compile(ast::function_definition const * function_definition, ir::node_ref begin, ir::node_ref end) { + return_value_is_large_struct = false; // reset like stack_size below, so a stale value from an earlier compile() call cannot leak in + if (auto struct_type = std::get_if(function_definition->inferred_result_type.get())) + if (auto hfa = get_hfa_data(lcontext, struct_type->node); !(hfa && hfa->count <= 4)) // same HFA test as the return/call paths + return_value_is_large_struct = struct_type->node->layout.size > 16; stack_size = 0; if (lcontext.use_frame_pointer)