Support struct function arguments + fixes
This commit is contained in:
parent
62fc4c88de
commit
7622a882b5
1 changed files with 164 additions and 45 deletions
|
|
@ -15,10 +15,20 @@ namespace pslang::jit::aarch64
|
||||||
namespace
|
namespace
|
||||||
{
|
{
|
||||||
|
|
||||||
|
// Homogeneous floating-point aggregate: up to 4 floating-point members
|
||||||
|
// of the same type (after struct flattening)
|
||||||
|
struct hfa_data
|
||||||
|
{
|
||||||
|
types::type_ptr element_type;
|
||||||
|
std::size_t count;
|
||||||
|
};
|
||||||
|
|
||||||
struct local_context
|
struct local_context
|
||||||
{
|
{
|
||||||
bool use_frame_pointer = true;
|
bool use_frame_pointer = true;
|
||||||
|
|
||||||
|
std::unordered_map<ast::struct_definition const *, std::optional<hfa_data>> struct_hfa;
|
||||||
|
|
||||||
std::unordered_map<std::string, std::int32_t> extern_symbols;
|
std::unordered_map<std::string, std::int32_t> extern_symbols;
|
||||||
std::unordered_map<ir::node_ref, std::int32_t> nodes;
|
std::unordered_map<ir::node_ref, std::int32_t> nodes;
|
||||||
|
|
||||||
|
|
@ -51,6 +61,61 @@ namespace pslang::jit::aarch64
|
||||||
return 1 << mode;
|
return 1 << mode;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::optional<hfa_data> get_hfa_data(local_context & lcontext, ast::struct_definition const * node);
|
||||||
|
|
||||||
|
std::optional<hfa_data> compute_hfa_data(local_context & lcontext, ast::struct_definition const * node)
|
||||||
|
{
|
||||||
|
types::type_ptr type = nullptr;
|
||||||
|
std::size_t count = 0;
|
||||||
|
|
||||||
|
for (std::size_t i = 0; i < node->fields.size(); ++i)
|
||||||
|
{
|
||||||
|
auto const & field = node->fields[i];
|
||||||
|
if (types::is_builtin_type(*field.inferred_type))
|
||||||
|
{
|
||||||
|
if (!types::is_floating_point_type(*field.inferred_type))
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
if (type && !types::equal(*type, *field.inferred_type))
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
type = field.inferred_type;
|
||||||
|
++count;
|
||||||
|
}
|
||||||
|
else if (auto struct_type = std::get_if<types::struct_type>(field.inferred_type.get()))
|
||||||
|
{
|
||||||
|
// NB: recursion must be impossible due to prior checks in type checker
|
||||||
|
if (auto subdata = get_hfa_data(lcontext, struct_type->node))
|
||||||
|
{
|
||||||
|
if (type && !types::equal(*type, *subdata->element_type))
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
type = subdata->element_type;
|
||||||
|
count += subdata->count;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (count == 0)
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
return hfa_data{type, count};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<hfa_data> get_hfa_data(local_context & lcontext, ast::struct_definition const * node)
|
||||||
|
{
|
||||||
|
if (auto it = lcontext.struct_hfa.find(node); it != lcontext.struct_hfa.end())
|
||||||
|
return it->second;
|
||||||
|
|
||||||
|
auto result = compute_hfa_data(lcontext, node);
|
||||||
|
lcontext.struct_hfa[node] = result;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
struct populate_const_data_visitor
|
struct populate_const_data_visitor
|
||||||
{
|
{
|
||||||
program_context & pcontext;
|
program_context & pcontext;
|
||||||
|
|
@ -295,7 +360,7 @@ namespace pslang::jit::aarch64
|
||||||
else if (types::is_floating_point_type(*type))
|
else if (types::is_floating_point_type(*type))
|
||||||
{
|
{
|
||||||
auto mode = fp_mode_for(*type);
|
auto mode = fp_mode_for(*type);
|
||||||
load_fp(it, 0, mode);
|
load_fp(node.arg1, 0, mode);
|
||||||
builder.fneg(0, mode, 0);
|
builder.fneg(0, mode, 0);
|
||||||
store_fp(it, 0, mode);
|
store_fp(it, 0, mode);
|
||||||
}
|
}
|
||||||
|
|
@ -319,6 +384,7 @@ namespace pslang::jit::aarch64
|
||||||
{
|
{
|
||||||
auto arg1_type = node.arg1->inferred_type;
|
auto arg1_type = node.arg1->inferred_type;
|
||||||
bool const is_fp = types::is_floating_point_type(*arg1_type);
|
bool const is_fp = types::is_floating_point_type(*arg1_type);
|
||||||
|
bool const result_is_fp = types::is_floating_point_type(*type);
|
||||||
std::uint8_t const fp_mode = fp_mode_for(*arg1_type);
|
std::uint8_t const fp_mode = fp_mode_for(*arg1_type);
|
||||||
|
|
||||||
if (is_fp)
|
if (is_fp)
|
||||||
|
|
@ -504,7 +570,7 @@ namespace pslang::jit::aarch64
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (is_fp)
|
if (result_is_fp)
|
||||||
store_fp(it, 0, fp_mode);
|
store_fp(it, 0, fp_mode);
|
||||||
else
|
else
|
||||||
store(it, 0);
|
store(it, 0);
|
||||||
|
|
@ -636,14 +702,48 @@ namespace pslang::jit::aarch64
|
||||||
builder.cbnz(0, 0);
|
builder.cbnz(0, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
void apply(ir::node_ref it, ir::call const & node, types::type_ptr const & type)
|
template <typename Node, typename DoCall>
|
||||||
|
void apply_call(ir::node_ref it, Node const & node, types::type_ptr const & type, DoCall && do_call)
|
||||||
{
|
{
|
||||||
// TODO: struct/array arguments?
|
// TODO: array arguments?
|
||||||
|
// TODO: handle the case when there weren't enough registers
|
||||||
std::uint8_t reg = 0;
|
std::uint8_t reg = 0;
|
||||||
std::uint8_t fp_reg = 0;
|
std::uint8_t fp_reg = 0;
|
||||||
for (auto const & argument : node.arguments)
|
for (auto const & argument : node.arguments)
|
||||||
{
|
{
|
||||||
if (types::is_integer_like_type(*argument->inferred_type))
|
if (auto struct_type = std::get_if<types::struct_type>(argument->inferred_type.get()))
|
||||||
|
{
|
||||||
|
auto node = struct_type->node;
|
||||||
|
if (auto hfa = get_hfa_data(lcontext, node); hfa && hfa->count <= 4)
|
||||||
|
{
|
||||||
|
// HFA - passed in consecutive FP registers
|
||||||
|
std::int32_t base_offset = stack_size - stack_position.at(argument);
|
||||||
|
auto fp_mode = fp_mode_for(*hfa->element_type);
|
||||||
|
auto size = fp_size(fp_mode);
|
||||||
|
for (std::size_t i = 0; i < hfa->count; ++i)
|
||||||
|
builder.ldr_fp(fp_reg++, fp_mode, 31, (base_offset + i * size) / size);
|
||||||
|
}
|
||||||
|
else if (node->layout.size <= 16)
|
||||||
|
{
|
||||||
|
// Small struct - passed in up to 2 GP registers
|
||||||
|
std::int32_t base_offset = stack_size - stack_position.at(argument);
|
||||||
|
std::int32_t size = node->layout.size;
|
||||||
|
std::int32_t offset = 0;
|
||||||
|
while (size > 0)
|
||||||
|
{
|
||||||
|
builder.ldr(reg++, 31, (base_offset + offset) / 8);
|
||||||
|
size -= 8;
|
||||||
|
offset += 8;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// Large struct - passed by pointer
|
||||||
|
std::int32_t base_offset = stack_size - stack_position.at(argument);
|
||||||
|
builder.add_imm(31, reg++, base_offset);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (types::is_integer_like_type(*argument->inferred_type))
|
||||||
load(argument, reg++);
|
load(argument, reg++);
|
||||||
else if (types::is_floating_point_type(*argument->inferred_type))
|
else if (types::is_floating_point_type(*argument->inferred_type))
|
||||||
load_fp(argument, fp_reg++, fp_mode_for(*argument->inferred_type));
|
load_fp(argument, fp_reg++, fp_mode_for(*argument->inferred_type));
|
||||||
|
|
@ -655,8 +755,7 @@ namespace pslang::jit::aarch64
|
||||||
builder.sub_imm(31, 31, 16);
|
builder.sub_imm(31, 31, 16);
|
||||||
builder.str(30, 31, 0);
|
builder.str(30, 31, 0);
|
||||||
}
|
}
|
||||||
lcontext.branch_resolve.emplace_back(pcontext.code.size(), node.target);
|
do_call();
|
||||||
builder.bl(0);
|
|
||||||
if (!lcontext.use_frame_pointer)
|
if (!lcontext.use_frame_pointer)
|
||||||
{
|
{
|
||||||
builder.ldr(30, 31, 0);
|
builder.ldr(30, 31, 0);
|
||||||
|
|
@ -674,36 +773,20 @@ namespace pslang::jit::aarch64
|
||||||
throw std::runtime_error("Unsupported return value type");
|
throw std::runtime_error("Unsupported return value type");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void apply(ir::node_ref it, ir::call const & node, types::type_ptr const & type)
|
||||||
|
{
|
||||||
|
apply_call(it, node, type, [&]{
|
||||||
|
lcontext.branch_resolve.emplace_back(pcontext.code.size(), node.target);
|
||||||
|
builder.bl(0);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
void apply(ir::node_ref it, ir::call_pointer const & node, types::type_ptr const & type)
|
void apply(ir::node_ref it, ir::call_pointer const & node, types::type_ptr const & type)
|
||||||
{
|
{
|
||||||
// TODO: struct/array arguments?
|
apply_call(it, node, type, [&]{
|
||||||
std::uint8_t reg = 0;
|
load(node.pointer, 9);
|
||||||
std::uint8_t fp_reg = 0;
|
builder.bl_reg(9);
|
||||||
for (auto const & argument : node.arguments)
|
});
|
||||||
{
|
|
||||||
if (types::is_integer_like_type(*argument->inferred_type))
|
|
||||||
load(argument, reg++);
|
|
||||||
else if (types::is_floating_point_type(*argument->inferred_type))
|
|
||||||
load_fp(argument, fp_reg++, fp_mode_for(*argument->inferred_type));
|
|
||||||
else
|
|
||||||
throw std::runtime_error("Unsupported function argument type");
|
|
||||||
}
|
|
||||||
load(node.pointer, reg);
|
|
||||||
builder.sub_imm(31, 31, 16);
|
|
||||||
builder.str(30, 31, 0);
|
|
||||||
builder.bl_reg(reg);
|
|
||||||
builder.ldr(30, 31, 0);
|
|
||||||
builder.add_imm(31, 31, 16);
|
|
||||||
|
|
||||||
// TODO: struct/array return value?
|
|
||||||
if (types::is_unit_type(*type))
|
|
||||||
{}
|
|
||||||
else if (types::is_integer_like_type(*type))
|
|
||||||
store(it, 0);
|
|
||||||
else if (types::is_floating_point_type(*type))
|
|
||||||
store_fp(it, 0, fp_mode_for(*type));
|
|
||||||
else
|
|
||||||
throw std::runtime_error("Unsupported return value type");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void apply(ir::node_ref, ir::return_value const & node, types::type_ptr const &)
|
void apply(ir::node_ref, ir::return_value const & node, types::type_ptr const &)
|
||||||
|
|
@ -777,18 +860,54 @@ namespace pslang::jit::aarch64
|
||||||
builder.str(30, 31, (stack_size - 8) / 8);
|
builder.str(30, 31, (stack_size - 8) / 8);
|
||||||
builder.add_imm(31, 29, stack_size - 16);
|
builder.add_imm(31, 29, stack_size - 16);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: handle the case when there weren't enough registers
|
||||||
std::uint8_t reg = 0;
|
std::uint8_t reg = 0;
|
||||||
std::uint8_t fp_reg = 0;
|
std::uint8_t fp_reg = 0;
|
||||||
for (std::size_t i = 0; i < function_definition->arguments.size(); ++i)
|
for (std::size_t i = 0; i < function_definition->arguments.size(); ++i)
|
||||||
{
|
{
|
||||||
auto const & argument = function_definition->arguments[i];
|
auto const & argument = function_definition->arguments[i];
|
||||||
auto size = ast::type_size(*argument.inferred_type);
|
auto size = ast::type_size(*argument.inferred_type);
|
||||||
auto fp_mode = fp_mode_for(*argument.inferred_type);
|
|
||||||
if (size == 0) continue;
|
if (size == 0) continue;
|
||||||
if (types::is_integer_like_type(*argument.inferred_type))
|
if (auto struct_type = std::get_if<types::struct_type>(argument.inferred_type.get()))
|
||||||
|
{
|
||||||
|
auto node = struct_type->node;
|
||||||
|
if (auto hfa = get_hfa_data(lcontext, node); hfa && hfa->count <= 4)
|
||||||
|
{
|
||||||
|
// HFA - passed in consecutive FP registers
|
||||||
|
std::int32_t base_offset = stack_size - argument_position[i];
|
||||||
|
auto fp_mode = fp_mode_for(*hfa->element_type);
|
||||||
|
auto size = fp_size(fp_mode);
|
||||||
|
for (std::size_t i = 0; i < hfa->count; ++i)
|
||||||
|
builder.str_fp(fp_reg++, fp_mode, 31, (base_offset + i * size) / size);
|
||||||
|
}
|
||||||
|
else if (node->layout.size <= 16)
|
||||||
|
{
|
||||||
|
// Small struct - passed in up to 2 GP registers
|
||||||
|
std::int32_t base_offset = stack_size - argument_position[i];
|
||||||
|
std::int32_t size = node->layout.size;
|
||||||
|
std::int32_t offset = 0;
|
||||||
|
while (size > 0)
|
||||||
|
{
|
||||||
|
builder.str(reg++, 31, (base_offset + offset) / 8);
|
||||||
|
offset += 8;
|
||||||
|
size -= 8;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// Large struct - passed by pointer
|
||||||
|
std::int32_t dst_offset = stack_size - argument_position[i];
|
||||||
|
copy_memory(reg++, 0, 31, dst_offset, size, 9);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (types::is_integer_like_type(*argument.inferred_type))
|
||||||
builder.str(reg++, 31, (stack_size - argument_position[i]) / 8);
|
builder.str(reg++, 31, (stack_size - argument_position[i]) / 8);
|
||||||
else if (types::is_floating_point_type(*argument.inferred_type))
|
else if (types::is_floating_point_type(*argument.inferred_type))
|
||||||
|
{
|
||||||
|
auto fp_mode = fp_mode_for(*argument.inferred_type);
|
||||||
builder.str_fp(fp_reg++, fp_mode, 31, (stack_size - argument_position[i]) / fp_size(fp_mode));
|
builder.str_fp(fp_reg++, fp_mode, 31, (stack_size - argument_position[i]) / fp_size(fp_mode));
|
||||||
|
}
|
||||||
else
|
else
|
||||||
throw std::runtime_error("Unknown argument type");
|
throw std::runtime_error("Unknown argument type");
|
||||||
}
|
}
|
||||||
|
|
@ -804,27 +923,27 @@ namespace pslang::jit::aarch64
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void load(ir::node_ref ir, std::uint8_t reg)
|
void load(ir::node_ref it, std::uint8_t reg)
|
||||||
{
|
{
|
||||||
std::int32_t offset = stack_size - stack_position.at(ir);
|
std::int32_t offset = stack_size - stack_position.at(it);
|
||||||
builder.ldr(reg, 31, offset / 8);
|
builder.ldr(reg, 31, offset / 8);
|
||||||
}
|
}
|
||||||
|
|
||||||
void load_fp(ir::node_ref ir, std::uint8_t reg, std::uint8_t mode)
|
void load_fp(ir::node_ref it, std::uint8_t reg, std::uint8_t mode)
|
||||||
{
|
{
|
||||||
std::int32_t offset = stack_size - stack_position.at(ir);
|
std::int32_t offset = stack_size - stack_position.at(it);
|
||||||
builder.ldr_fp(reg, mode, 31, offset / fp_size(mode));
|
builder.ldr_fp(reg, mode, 31, offset / fp_size(mode));
|
||||||
}
|
}
|
||||||
|
|
||||||
void store(ir::node_ref ir, std::uint8_t reg)
|
void store(ir::node_ref it, std::uint8_t reg)
|
||||||
{
|
{
|
||||||
std::int32_t offset = stack_size - stack_position.at(ir);
|
std::int32_t offset = stack_size - stack_position.at(it);
|
||||||
builder.str(reg, 31, offset / 8);
|
builder.str(reg, 31, offset / 8);
|
||||||
}
|
}
|
||||||
|
|
||||||
void store_fp(ir::node_ref ir, std::uint8_t reg, std::uint8_t mode)
|
void store_fp(ir::node_ref it, std::uint8_t reg, std::uint8_t mode)
|
||||||
{
|
{
|
||||||
std::int32_t offset = stack_size - stack_position.at(ir);
|
std::int32_t offset = stack_size - stack_position.at(it);
|
||||||
builder.str_fp(reg, mode, 31, offset / fp_size(mode));
|
builder.str_fp(reg, mode, 31, offset / fp_size(mode));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue