Support struct return values in aarch64 compiler

This commit is contained in:
Nikita Lisitsa 2026-03-27 20:25:16 +03:00
parent 7622a882b5
commit 041513f33e
2 changed files with 142 additions and 24 deletions

View file

@ -2,28 +2,78 @@ func print(c: u8):
foreign func putchar(c: i32) -> i32
putchar(c as i32)
func print32(n: u32):
if n >= 10u:
print32(n / 10u)
print('0' + ((n % 10u) as u8))
func print_i32(x: i32):
if x < 0:
print('-')
print_i32(-x)
return
if x >= 10:
print_i32(x / 10)
print('0' + (x % 10 as u8))
func alloc(size: u64) -> unit mut*:
foreign func malloc(size: u64) -> unit mut*
return malloc(size)
foreign func free(ptr: unit*)
let count = 30
let array = alloc(4 * count as u64) as u32 mut*
array[0] = 0u
array[1] = 1u
print32(array[0])
print('\n')
print32(array[1])
print('\n')
mut i = 2
while i < count:
array[i] = array[i - 1] + array[i - 2]
print32(array[i])
print('\n')
func print_f32(x: f32):
if x < 0.0:
print('-')
print_f32(-x)
return
foreign func floorf(x: f32) -> f32
let floor = floorf(x) as i32
print_i32(floor)
print('.')
mut y = x - (floor as f32)
mut i = 0
while i < 5:
y = y * 10.0
let yfloor = floorf(y) as i32
print('0' + (yfloor as u8))
y = y - (yfloor as f32)
i = i + 1
struct vec3:
x: f32
y: f32
z: f32
func print_vec3(v: vec3):
print('(')
print_f32(v.x)
print(',')
print_f32(v.y)
print(',')
print_f32(v.z)
print(')')
func dot(a: vec3, b: vec3) -> f32:
return a.x * b.x + a.y * b.y + a.z * b.z
func add(a: vec3, b: vec3) -> vec3:
return vec3(a.x + b.x, a.y + b.y, a.z + b.z)
func mult(a: vec3, b: f32) -> vec3:
return vec3(a.x * b, a.y * b, a.z * b)
func normalized(v: vec3) -> vec3:
foreign func sqrtf(x: f32) -> f32
return mult(v, 1.0 / sqrtf(dot(v, v)))
struct ray:
origin: vec3
direction: vec3
func intersect_plane(ray: ray, normal: vec3, value: f32) -> f32:
return (value - dot(ray.origin, normal)) / dot(ray.direction, normal)
let t = intersect_plane(ray(vec3(0.0, 0.0, 5.0), vec3(0.0, 1.0, -1.0)), vec3(0.0, 0.0, 1.0), 0.0)
let a = vec3(5.0, 1.0, 4.0)
let n = normalized(vec3(1.0, 2.0, 3.0))
let b = add(a, mult(n, -dot(a, n)))
print_vec3(a)
print('\n')
print_vec3(n)
print('\n')
print_vec3(b)
print('\n')
print_f32(dot(b, n))
print('\n')

View file

@ -291,6 +291,7 @@ namespace pslang::jit::aarch64
std::vector<std::int32_t> argument_position;
std::unordered_map<ir::node_ref, std::int32_t> stack_position;
std::int32_t stack_size = 0;
bool return_value_is_large_struct = false;
void apply(ir::node_ref, ir::label const &, types::type_ptr const &)
{}
@ -750,6 +751,11 @@ namespace pslang::jit::aarch64
else
throw std::runtime_error("Unsupported function argument type");
}
if (return_value_is_large_struct)
{
builder.sub_imm(31, 31, 16);
builder.str(8, 31, 0);
}
if (!lcontext.use_frame_pointer)
{
builder.sub_imm(31, 31, 16);
@ -761,10 +767,40 @@ namespace pslang::jit::aarch64
builder.ldr(30, 31, 0);
builder.add_imm(31, 31, 16);
}
if (return_value_is_large_struct)
{
builder.ldr(8, 31, 0);
builder.add_imm(31, 31, 16);
}
// TODO: struct/array return value?
if (types::is_unit_type(*type))
// TODO: array return value?
auto size = ast::type_size(*type);
if (size == 0)
{}
else if (auto struct_type = std::get_if<types::struct_type>(type.get()))
{
auto base_offset = stack_size - stack_position.at(it);
if (auto hfa = get_hfa_data(lcontext, struct_type->node); hfa && hfa->count <= 4)
{
auto fp_mode = fp_mode_for(*hfa->element_type);
auto size = fp_size(fp_mode);
// HFA - returned in consecutive FP registers
for (std::size_t i = 0; i < hfa->count; ++i)
builder.str_fp(i, fp_mode, 31, (base_offset + i * size) / size);
}
else if (size <= 16)
{
// Small struct - returned in x0-x1 registers
builder.str(0, 31, base_offset / 8);
if (size > 8)
builder.str(1, 31, (base_offset + 8) / 8);
}
else
{
// Large struct - returned by pointer in x8 register
copy_memory(8, 0, 31, base_offset, size, 0);
}
}
else if (types::is_integer_like_type(*type))
store(it, 0);
else if (types::is_floating_point_type(*type))
@ -791,11 +827,39 @@ namespace pslang::jit::aarch64
void apply(ir::node_ref, ir::return_value const & node, types::type_ptr const &)
{
// TODO: struct/array return value?
// TODO: array return value?
if (node.value)
{
auto type = (*node.value)->inferred_type;
if (types::is_integer_like_type(*type))
auto size = ast::type_size(*type);
if (size == 0)
{}
else if (auto struct_type = std::get_if<types::struct_type>(type.get()))
{
auto base_offset = stack_size - stack_position.at(*node.value);
auto struct_node = struct_type->node;
if (auto hfa = get_hfa_data(lcontext, struct_node); hfa && hfa->count <= 4)
{
auto fp_mode = fp_mode_for(*hfa->element_type);
auto size = fp_size(fp_mode);
// HFA - returned in consecutive FP registers
for (std::size_t i = 0; i < hfa->count; ++i)
builder.ldr_fp(i, fp_mode, 31, (base_offset + i * size) / size);
}
else if (size <= 16)
{
// Small struct - returned in x0-x1 registers
builder.ldr(0, 31, base_offset / 8);
if (size > 8)
builder.ldr(1, 31, (base_offset + 8) / 8);
}
else
{
// Large struct - returned by pointer in x8 register
copy_memory(31, base_offset, 8, 0, size, 0);
}
}
else if (types::is_integer_like_type(*type))
load(*node.value, 0);
else if (types::is_floating_point_type(*type))
load_fp(*node.value, 0, fp_mode_for(*type));
@ -814,6 +878,10 @@ namespace pslang::jit::aarch64
void compile(ast::function_definition const * function_definition, ir::node_ref begin, ir::node_ref end)
{
if (auto struct_type = std::get_if<types::struct_type>(function_definition->inferred_result_type.get()))
if (!get_hfa_data(lcontext, struct_type->node) && struct_type->node->layout.size > 16)
return_value_is_large_struct = true;
stack_size = 0;
if (lcontext.use_frame_pointer)