// pslang/libs/jit/source/arch/aarch64/compiler_v2.cpp
//
// 829 lines
// 22 KiB
// C++

#include <pslang/jit/arch/aarch64/compiler.hpp>
#include <pslang/jit/arch/aarch64/instruction_builder.hpp>
#include <pslang/ir/node.hpp>
#include <pslang/ir/compiler.hpp>
#include <pslang/ast/type.hpp>
#include <pslang/types/type_visitor.hpp>
#include <sstream>
namespace pslang::jit::aarch64
{
namespace
{
struct local_context
{
bool use_frame_pointer = true;
std::unordered_map<std::string, std::int32_t> extern_symbols;
std::unordered_map<ir::node_ref, std::int32_t> nodes;
std::unordered_map<float, std::int32_t> f16_constants;
std::unordered_map<float, std::int32_t> f32_constants;
std::unordered_map<double, std::int32_t> f64_constants;
struct resolve_data
{
std::int32_t offset;
ir::node_ref target;
};
std::vector<resolve_data> branch_resolve;
std::vector<resolve_data> cbranch_resolve;
std::vector<resolve_data> adr_resolve;
};
// Mode code handed to the floating-point builder calls for a given type:
// 1 for f16, 2 for f32 and 3 for everything else (f64, but also any type
// that is not floating-point at all).
std::uint8_t fp_mode_for(types::type const & type)
{
    constexpr std::uint8_t half_mode = 1;
    constexpr std::uint8_t single_mode = 2;
    constexpr std::uint8_t double_mode = 3;

    if (types::equal(type, types::primitive_type(types::f16_type{})))
        return half_mode;

    return types::equal(type, types::primitive_type(types::f32_type{}))
        ? single_mode
        : double_mode;
}
// Size in bytes of a value with the given mode code, i.e. 2 to the power of
// @mode: 2, 4 and 8 bytes for the codes 1, 2 and 3 produced by fp_mode_for.
std::int32_t fp_size(std::uint8_t mode)
{
    std::int32_t const one_byte = 1;
    return one_byte << mode;
}
// Appends the constant data referenced by IR nodes to the code buffer and
// records where each piece landed:
//  * floating-point literals are copied byte-for-byte and their offset is
//    stored in the matching lcontext.fXX_constants table;
//  * every extern symbol gets a pointer-sized slot initialised to nullptr,
//    registered both in lcontext.extern_symbols and pcontext.foreign_resolve.
// All other node kinds are ignored.
struct populate_const_data_visitor
{
    program_context & pcontext;
    local_context & lcontext;

    // Fallback: nodes that carry no constant data
    template <typename Node>
    void apply(Node const &, types::type_ptr const &)
    {}

    void apply(ir::literal const & node, types::type_ptr const &)
    {
        if (auto f16_literal = std::get_if<ast::f16_literal>(&node.value))
        {
            lcontext.f16_constants[f16_literal->value.repr] = current_offset();
            push_bytes(f16_literal->value.repr);
        }
        else if (auto f32_literal = std::get_if<ast::f32_literal>(&node.value))
        {
            lcontext.f32_constants[f32_literal->value] = current_offset();
            push_bytes(f32_literal->value);
        }
        else if (auto f64_literal = std::get_if<ast::f64_literal>(&node.value))
        {
            // Must go into the f64 table: literal_visitor looks f64 literals up
            // in f64_constants, and writing to f32_constants could also
            // overwrite the entry of an f32 literal with the same value.
            lcontext.f64_constants[f64_literal->value] = current_offset();
            push_bytes(f64_literal->value);
        }
    }

    void apply(ir::extern_symbol const & node, types::type_ptr const &)
    {
        std::int32_t const offset = current_offset();
        lcontext.extern_symbols[node.name] = offset;
        pcontext.foreign_resolve.push_back({node.name, offset});
        // Placeholder slot; its real value is not known to this visitor
        push_bytes<void *>(nullptr);
    }

private:
    // Offset at which the next appended byte will be placed
    std::int32_t current_offset() const
    {
        return static_cast<std::int32_t>(pcontext.code.size());
    }

    // Append the object representation of @value to the code buffer
    template <typename T>
    void push_bytes(T const & value)
    {
        auto const begin = reinterpret_cast<std::uint8_t const *>(&value);
        auto const end = begin + sizeof(value);
        pcontext.code.insert(pcontext.code.end(), begin, end);
    }
};
// Fill register @reg with all-one bits, i.e. the value -1
void set_m1(instruction_builder & builder, std::uint8_t reg)
{
    // Both sources are register number 31.
    // NOTE(review): this yields -1 only if the builder reads number 31 as zero
    // for this instruction (presumably "orn reg, xzr, xzr") - confirm.
    constexpr int zero_source = 31;
    builder.or_not_reg(zero_source, zero_source, reg);
}
// Emits the instructions that place a literal's value in register 0
// (general register for bool/integer literals, floating-point register for
// fp literals, which are read back from the data emitted by
// populate_const_data_visitor).
struct literal_visitor
{
    program_context & pcontext;
    local_context & lcontext;
    instruction_builder & builder;

    // true is represented as -1 (all bits set), false as 0
    void operator()(ast::bool_literal const & node)
    {
        if (!node.value)
        {
            builder.movz(0, 0);
            return;
        }
        set_m1(builder, 0);
    }

    // Integer literals are assembled 16 bits at a time: one movz for the
    // lowest chunk, then a movk for every higher non-zero chunk of T.
    template <typename T>
        requires(std::is_integral_v<T> && !std::is_same_v<T, bool>)
    void operator()(ast::primitive_literal_base<T> const & node)
    {
        builder.movz(0, std::uint64_t(node.value));

        for (std::size_t byte = 2; byte < sizeof(T); byte += 2)
        {
            auto const chunk = std::uint16_t(std::uint64_t(node.value) >> (byte * 8));
            if (chunk == 0)
                continue;
            builder.movk(0, chunk, byte / 2);
        }

        // Only sizeof(T) bytes were written above; negative values of narrow
        // signed types additionally get an sbfm over sizeof(T) * 8 bits.
        if constexpr (sizeof(T) < 8 && std::is_signed_v<T>)
        {
            if (node.value < 0)
                builder.sbfm(0, 0, sizeof(T) * 8);
        }
    }

    void operator()(ast::f16_literal const & node)
    {
        load_constant(lcontext.f16_constants.at(node.value.repr), 0);
        // NOTE(review): looks like the stored 32-bit value is narrowed to f16
        // here (mode 2 -> mode 1) - confirm against instruction_builder::fcvt.
        builder.fcvt(0, 0b10, 0, 0b01);
    }

    void operator()(ast::f32_literal const & node)
    {
        load_constant(lcontext.f32_constants.at(node.value), 0);
    }

    void operator()(ast::f64_literal const & node)
    {
        load_constant(lcontext.f64_constants.at(node.value), 1);
    }

private:
    // Emit a PC-relative fp load of the constant stored at byte @offset of the
    // code buffer; the distance is passed to the builder in 4-byte units.
    void load_constant(std::int32_t offset, int size_code)
    {
        std::int32_t current = pcontext.code.size();
        builder.ldr_fp_pc(0, size_code, (offset - current) / 4);
    }
};
// Re-normalises the contents of general register @reg for a given type:
// sbfm for signed and ubfm for unsigned integers narrower than 8 bytes,
// ubfm over 8 bits for bool, nothing for floating-point and 8-byte types.
// Any other type is rejected with std::runtime_error.
struct reg_extend_visitor
    : types::const_visitor<reg_extend_visitor>
{
    using const_visitor::apply;

    instruction_builder & builder;
    std::uint8_t reg;

    void apply(types::bool_type const &)
    {
        builder.ubfm(reg, reg, 8);
    }

    // Floating-point values do not live in general registers: nothing to do
    void apply(types::f16_type const &)
    {}

    void apply(types::f32_type const &)
    {}

    void apply(types::f64_type const &)
    {}

    template <typename T>
    void apply(types::primitive_type_base<T> const &)
    {
        constexpr auto bits = sizeof(T) * 8;
        if constexpr (bits != 64)
        {
            if constexpr (std::is_signed_v<T>)
                builder.sbfm(reg, reg, bits);
            if constexpr (std::is_unsigned_v<T>)
                builder.ubfm(reg, reg, bits);
        }
    }

    // Catch-all for types that have no register extension defined
    template <typename T>
    void apply(T const &)
    {
        std::string message = "reg_extend_visitor is not implemented for ";
        message += typeid(T).name();
        throw std::runtime_error(message);
    }
};
struct compile_visitor
{
program_context & pcontext;
ir::module_context const & mcontext;
local_context & lcontext;
instruction_builder & builder;
std::unordered_map<ir::node_ref, std::int32_t> stack_position;
std::int32_t stack_size = 0;
// Labels emit no instructions. Their code offset is recorded by compile()
// (in lcontext.nodes) before this overload is dispatched.
void apply(ir::node_ref, ir::label const &, types::type_ptr const &)
{}
// Materialise a literal in register 0 and spill it to the node's stack slot.
// Literals that are neither integer-like nor floating-point are materialised
// but not stored.
void apply(ir::node_ref it, ir::literal const & node, types::type_ptr const & type)
{
    literal_visitor materialise{pcontext, lcontext, builder};
    std::visit(materialise, node.value);

    if (types::is_integer_like_type(*type))
    {
        store(it, 0);
        return;
    }

    if (types::is_floating_point_type(*type))
        store_fp(it, 0, fp_mode_for(*type));
}
// Copy the 8-byte stack slot of node.source into this node's slot
void apply(ir::node_ref it, ir::copy const & node, types::type_ptr const &)
{
    // TODO: struct/array copy?
    constexpr std::uint8_t scratch = 0;
    load(node.source, scratch);
    store(it, scratch);
}
// Read a value of the node's type through the pointer held by node.ptr and
// store it in this node's stack slot. Only 1, 2, 4 and 8 byte types are
// supported; anything else throws std::runtime_error.
void apply(ir::node_ref it, ir::load const & node, types::type_ptr const & type)
{
    // TODO: struct/array load?
    load(node.ptr, 0);

    auto const width = ast::type_size(*type);
    switch (width)
    {
    case 1:
        builder.ldurb(0, 0, 0);
        break;
    case 2:
        builder.ldurh(0, 0, 0);
        break;
    case 4:
        builder.ldurw(0, 0, 0);
        break;
    case 8:
        builder.ldur(0, 0, 0);
        break;
    default:
        throw std::runtime_error(std::format("Unsupported load size: {}", width));
    }

    // Bool and integer values are re-extended to the full register width
    bool const needs_extension = types::is_bool_type(*type) || types::is_integer_type(*type);
    if (needs_extension)
        extend(0, type);

    store(it, 0);
}
// Write the value held by node.value through the pointer held by node.ptr.
// The width of the write is the size of this node's type; only 1, 2, 4 and
// 8 bytes are supported, anything else throws std::runtime_error.
void apply(ir::node_ref, ir::store const & node, types::type_ptr const & type)
{
    // TODO: struct/array store?
    load(node.ptr, 0);
    load(node.value, 1);

    auto const width = ast::type_size(*type);
    switch (width)
    {
    case 1:
        builder.sturb(1, 0, 0);
        break;
    case 2:
        builder.sturh(1, 0, 0);
        break;
    case 4:
        builder.sturw(1, 0, 0);
        break;
    case 8:
        builder.stur(1, 0, 0);
        break;
    default:
        throw std::runtime_error(std::format("Unsupported store size: {}", width));
    }
}
// Emit code for a unary operation and store the result in this node's slot.
//  * negation: integer (0 - x, then re-extended) or floating-point (fneg);
//    other operand types emit nothing;
//  * logical_not: bitwise complement, re-extended for integer types;
//  * address_of / mutable_address_of: stack address of the operand's slot;
//  * dereference: 8-byte load through the operand.
// Operation kinds not listed here emit nothing.
void apply(ir::node_ref it, ir::unary_operation const & node, types::type_ptr const & type)
{
    switch (node.type)
    {
    case ast::unary_operation_type::negation:
        if (types::is_integer_type(*type))
        {
            load(node.arg1, 0);
            // Register number 31 as first source: presumably the zero
            // register, giving 0 - x - TODO confirm with instruction_builder
            builder.sub_reg(31, 0, 0);
            extend(0, type);
            store(it, 0);
        }
        else if (types::is_floating_point_type(*type))
        {
            auto mode = fp_mode_for(*type);
            // The operand lives in arg1's slot; this node's own slot has not
            // been written yet and must not be used as the source.
            load_fp(node.arg1, 0, mode);
            builder.fneg(0, mode, 0);
            store_fp(it, 0, mode);
        }
        break;
    case ast::unary_operation_type::logical_not:
        load(node.arg1, 0);
        builder.or_not_reg(31, 0, 0);
        if (types::is_integer_type(*type))
            extend(0, type);
        store(it, 0);
        break;
    case ast::unary_operation_type::address_of:
    case ast::unary_operation_type::mutable_address_of:
        // Same slot offset computation as load()/store(), but the address
        // itself (register 31 + offset) is the result
        builder.add_imm(31, 0, stack_size - stack_position.at(node.arg1));
        store(it, 0);
        break;
    case ast::unary_operation_type::dereference:
        load(node.arg1, 0);
        // NOTE(review): always a full-register load, unlike apply(ir::load)
        // which picks the width from the type - confirm this is intended.
        builder.ldr(0, 0, 0);
        store(it, 0);
        break;
    }
}
// Emit code for a binary operation and store the result in this node's slot.
//
// Operands are loaded into registers 0 and 1 - floating-point registers when
// the first operand's type is floating-point, general registers otherwise.
// Three-register builder calls are used as (lhs, rhs, destination).
// NOTE(review): that operand order is inferred from the remainder sequence
// below; confirm against instruction_builder.
//
// Arithmetic results of fp operands stay in fp register 0. Comparisons always
// produce their result with csetm in general register 0, whatever the operand
// type, so they are stored with store() rather than store_fp().
//
// Throws std::runtime_error for short-circuiting operators (expected to be
// lowered earlier) and for operation kinds that are not implemented.
void apply(ir::node_ref it, ir::binary_operation const & node, types::type_ptr const & type)
{
    auto arg1_type = node.arg1->inferred_type;
    bool const is_fp = types::is_floating_point_type(*arg1_type);
    std::uint8_t const fp_mode = fp_mode_for(*arg1_type);

    // Which register file holds the result once the switch below is done
    bool result_in_fp_register = is_fp;

    // Bool and unsigned operands use the unsigned condition codes
    auto const unsigned_operands = [&]
    {
        return types::is_bool_type(*arg1_type) || types::is_unsigned_integer_type(*arg1_type);
    };

    if (is_fp)
    {
        load_fp(node.arg1, 0, fp_mode);
        load_fp(node.arg2, 1, fp_mode);
    }
    else
    {
        load(node.arg1, 0);
        load(node.arg2, 1);
    }

    switch (node.type)
    {
    case ast::binary_operation_type::addition:
        if (is_fp)
            builder.fadd(0, 1, fp_mode, 0);
        else
        {
            builder.add_reg(0, 1, 0);
            if (!types::is_pointer_type(*type))
                extend(0, type);
        }
        break;
    case ast::binary_operation_type::subtraction:
        if (is_fp)
            builder.fsub(0, 1, fp_mode, 0);
        else
        {
            builder.sub_reg(0, 1, 0);
            if (!types::is_pointer_type(*type))
                extend(0, type);
        }
        break;
    case ast::binary_operation_type::multiplication:
        if (is_fp)
            builder.fmul(0, 1, fp_mode, 0);
        else
        {
            builder.mul_reg(0, 1, 0);
            extend(0, type);
        }
        break;
    case ast::binary_operation_type::division:
        if (is_fp)
            builder.fdiv(0, 1, fp_mode, 0);
        else
        {
            if (types::is_signed_integer_type(*type))
                builder.sdiv_reg(0, 1, 0);
            else
                builder.udiv_reg(0, 1, 0);
            extend(0, type);
        }
        break;
    case ast::binary_operation_type::remainder:
        // a % b is computed as a - (a / b) * b, with register 2 as scratch
        if (types::is_signed_integer_type(*type))
        {
            builder.sdiv_reg(0, 1, 2);
            builder.mul_reg(1, 2, 1);
            builder.sub_reg(0, 1, 0);
        }
        else if (types::is_unsigned_integer_type(*type))
        {
            builder.udiv_reg(0, 1, 2);
            builder.mul_reg(1, 2, 1);
            builder.sub_reg(0, 1, 0);
        }
        break;
    case ast::binary_operation_type::binary_and:
        builder.and_reg(0, 1, 0);
        break;
    case ast::binary_operation_type::logical_and:
        throw std::runtime_error("Short-circuiting operators must have been unwrapped in IR compiler");
    case ast::binary_operation_type::binary_or:
        builder.or_reg(0, 1, 0);
        break;
    case ast::binary_operation_type::logical_or:
        throw std::runtime_error("Short-circuiting operators must have been unwrapped in IR compiler");
    case ast::binary_operation_type::logical_xor:
        builder.xor_reg(0, 1, 0);
        break;
    case ast::binary_operation_type::equals:
        if (is_fp)
            builder.fcmp(0, 1, fp_mode);
        else
            builder.cmp_reg(0, 1);
        builder.csetm(0, 0b0000);
        result_in_fp_register = false;
        break;
    case ast::binary_operation_type::not_equals:
        if (is_fp)
            builder.fcmp(0, 1, fp_mode);
        else
            builder.cmp_reg(0, 1);
        builder.csetm(0, 0b0001);
        result_in_fp_register = false;
        break;
    case ast::binary_operation_type::less:
        if (is_fp)
        {
            builder.fcmp(0, 1, fp_mode);
            builder.csetm(0, 0b0100);
        }
        else
        {
            // Operands swapped: a < b is tested as b > a
            builder.cmp_reg(1, 0);
            if (unsigned_operands())
                builder.csetm(0, 0b1000);
            else
                builder.csetm(0, 0b1100);
        }
        result_in_fp_register = false;
        break;
    case ast::binary_operation_type::greater:
        if (is_fp)
        {
            // Operands swapped: a > b is tested as b < a
            builder.fcmp(1, 0, fp_mode);
            builder.csetm(0, 0b0100);
        }
        else
        {
            builder.cmp_reg(0, 1);
            if (unsigned_operands())
                builder.csetm(0, 0b1000);
            else
                builder.csetm(0, 0b1100);
        }
        result_in_fp_register = false;
        break;
    case ast::binary_operation_type::less_equals:
        if (is_fp)
        {
            builder.fcmp(0, 1, fp_mode);
            builder.csetm(0, 0b1001);
        }
        else
        {
            builder.cmp_reg(0, 1);
            if (unsigned_operands())
                builder.csetm(0, 0b1001);
            else
                builder.csetm(0, 0b1101);
        }
        result_in_fp_register = false;
        break;
    case ast::binary_operation_type::greater_equals:
        if (is_fp)
        {
            // Operands swapped: a >= b is tested as b <= a
            builder.fcmp(1, 0, fp_mode);
            builder.csetm(0, 0b1001);
        }
        else
        {
            builder.cmp_reg(1, 0);
            if (unsigned_operands())
                builder.csetm(0, 0b1001);
            else
                builder.csetm(0, 0b1101);
        }
        result_in_fp_register = false;
        break;
    default:
    {
        std::ostringstream os;
        os << "binary operation " << node.type << " is not implemented";
        throw std::runtime_error(os.str());
    }
    }

    if (result_in_fp_register)
        store_fp(it, 0, fp_mode);
    else
        store(it, 0);
}
// Emit code converting node.arg1 to node.target_type and store the result in
// this node's slot.
//  * identical types and pointer-to-pointer casts copy the slot unchanged;
//  * integer -> integer re-extends the register for the destination type;
//  * integer -> fp converts through mode 3 and narrows afterwards if needed;
//  * fp -> integer converts and re-extends for the destination type;
//  * fp -> fp converts between the two modes.
// Combinations not listed emit no conversion, and the result is stored only
// when the destination is an integer or floating-point type.
void apply(ir::node_ref it, ir::cast_operation const & node, types::type_ptr const &)
{
    auto src_type = node.arg1->inferred_type;
    auto dst_type = node.target_type;

    if (types::equal(*src_type, *dst_type) || (types::is_pointer_type(*src_type) && types::is_pointer_type(*dst_type)))
    {
        load(node.arg1, 0);
        store(it, 0);
        return;
    }

    if (types::is_integer_type(*src_type))
    {
        load(node.arg1, 0);
        if (types::is_integer_type(*dst_type))
        {
            extend(0, dst_type);
        }
        else if (types::is_floating_point_type(*dst_type))
        {
            auto dst_mode = fp_mode_for(*dst_type);
            if (types::is_signed_integer_type(*src_type))
            {
                builder.fmov(0, 0, 3, 1);
                builder.scvtf(0, 0, 3);
                if (dst_mode != 3)
                    builder.fcvt(0, 3, 0, dst_mode);
            }
            else if (types::is_unsigned_integer_type(*src_type))
            {
                builder.fmov(0, 0, 3, 1);
                builder.ucvtf(0, 0, 3);
                if (dst_mode != 3)
                    builder.fcvt(0, 3, 0, dst_mode);
            }
        }
    }
    else if (types::is_floating_point_type(*src_type))
    {
        auto src_mode = fp_mode_for(*src_type);
        // Bring the operand into fp register 0 first; without this the
        // conversions below would read whatever an earlier node left there.
        load_fp(node.arg1, 0, src_mode);
        if (types::is_integer_type(*dst_type))
        {
            // NOTE(review): the "n" variants presumably round to nearest
            // rather than truncate - confirm this matches the language's
            // cast semantics.
            if (types::is_signed_integer_type(*dst_type))
            {
                builder.fcvtns(0, 0, src_mode);
                extend(0, dst_type);
            }
            else if (types::is_unsigned_integer_type(*dst_type))
            {
                builder.fcvtnu(0, 0, src_mode);
                extend(0, dst_type);
            }
        }
        else if (types::is_floating_point_type(*dst_type))
        {
            auto dst_mode = fp_mode_for(*dst_type);
            builder.fcvt(0, src_mode, 0, dst_mode);
        }
    }

    if (types::is_integer_type(*dst_type))
    {
        store(it, 0);
    }
    else if (types::is_floating_point_type(*dst_type))
    {
        store_fp(it, 0, fp_mode_for(*dst_type));
    }
}
// Spill an incoming argument to this node's stack slot. The argument is read
// from the general register whose number equals the argument's index.
void apply(ir::node_ref it, ir::argument const & node, types::type_ptr const & type)
{
    // TODO: compute argument layout before compiling the function?
    // TODO: floating-point arguments? struct/array arguments?
    auto const source_register = node.index;
    store(it, source_register);
}
// Store the code address of node.target in this node's slot. The adr is
// emitted with a zero displacement and recorded in lcontext.adr_resolve so
// that compile() can patch it once all node offsets are known.
void apply(ir::node_ref it, ir::instruction_address const & node, types::type_ptr const &)
{
    auto const patch_site = pcontext.code.size();
    lcontext.adr_resolve.emplace_back(patch_site, node.target);
    builder.adr(0, 0);
    store(it, 0);
}
// Load the value of an extern symbol's slot (reserved in the code buffer by
// populate_const_data_visitor) with a PC-relative load and store it in this
// node's slot. The distance is passed to the builder in 4-byte units.
// Throws std::out_of_range if no slot was registered for the symbol.
void apply(ir::node_ref it, ir::extern_symbol const & node, types::type_ptr const &)
{
    // at() instead of operator[]: an unknown name must not silently create an
    // entry with offset 0 and produce a load from an unrelated location.
    auto const slot = lcontext.extern_symbols.at(node.name);
    auto const here = static_cast<std::int32_t>(pcontext.code.size());
    builder.ldr_pc(0, (slot - here) / 4);
    store(it, 0);
}
// Overwrite the stack slot of node.lhs with the contents of node.rhs's slot
void apply(ir::node_ref, ir::assignment const & node, types::type_ptr const &)
{
    // TODO: struct/array assignment?
    constexpr std::uint8_t scratch = 0;
    load(node.rhs, scratch);
    store(node.lhs, scratch);
}
// Unconditional branch to node.target. Emitted with a zero displacement and
// recorded in lcontext.branch_resolve for patching by compile().
void apply(ir::node_ref, ir::jump const & node, types::type_ptr const &)
{
    auto const patch_site = pcontext.code.size();
    lcontext.branch_resolve.emplace_back(patch_site, node.target);
    builder.b(0);
}
// Branch to node.target when the condition value is zero. The cbz is emitted
// with a zero displacement and recorded in lcontext.cbranch_resolve for
// patching by compile().
void apply(ir::node_ref, ir::jump_if_zero const & node, types::type_ptr const &)
{
    load(node.condition, 0);

    // The patch site is the cbz itself, i.e. the offset after the load
    auto const patch_site = pcontext.code.size();
    lcontext.cbranch_resolve.emplace_back(patch_site, node.target);
    builder.cbz(0, 0);
}
// Branch to node.target when the condition value is non-zero. The cbnz is
// emitted with a zero displacement and recorded in lcontext.cbranch_resolve
// for patching by compile().
void apply(ir::node_ref, ir::jump_if_nonzero const & node, types::type_ptr const &)
{
    load(node.condition, 0);

    // The patch site is the cbnz itself, i.e. the offset after the load
    auto const patch_site = pcontext.code.size();
    lcontext.cbranch_resolve.emplace_back(patch_site, node.target);
    builder.cbnz(0, 0);
}
// Call the function starting at node.target.
// Integer-like arguments go to consecutive general registers and
// floating-point arguments to consecutive fp registers, both starting at 0.
// The bl is emitted with a zero displacement and recorded in
// lcontext.branch_resolve. The result, if any, is taken from register 0.
// Throws std::runtime_error for unsupported argument or return types.
void apply(ir::node_ref it, ir::call const & node, types::type_ptr const & type)
{
    // TODO: struct/array arguments?
    std::uint8_t next_reg = 0;
    std::uint8_t next_fp_reg = 0;
    for (auto const & argument : node.arguments)
    {
        auto const & argument_type = *argument->inferred_type;
        if (types::is_integer_like_type(argument_type))
        {
            load(argument, next_reg);
            ++next_reg;
        }
        else if (types::is_floating_point_type(argument_type))
        {
            load_fp(argument, next_fp_reg, fp_mode_for(argument_type));
            ++next_fp_reg;
        }
        else
        {
            throw std::runtime_error("Unsupported function argument type");
        }
    }

    // Without a frame, register 30 is not saved by the prologue, so it is
    // pushed around the call instead (16 bytes are reserved for it).
    bool const save_link = !lcontext.use_frame_pointer;
    if (save_link)
    {
        builder.sub_imm(31, 31, 16);
        builder.str(30, 31, 0);
    }

    lcontext.branch_resolve.emplace_back(pcontext.code.size(), node.target);
    builder.bl(0);

    if (save_link)
    {
        builder.ldr(30, 31, 0);
        builder.add_imm(31, 31, 16);
    }

    // TODO: struct/array return value?
    if (types::is_unit_type(*type))
        return;

    if (types::is_integer_like_type(*type))
        store(it, 0);
    else if (types::is_floating_point_type(*type))
        store_fp(it, 0, fp_mode_for(*type));
    else
        throw std::runtime_error("Unsupported return value type");
}
// Call through the function pointer held by node.pointer.
// Arguments are placed as for a direct call; the pointer itself is loaded
// into the first general register not used by an argument. The result, if
// any, is taken from register 0.
// Throws std::runtime_error for unsupported argument or return types.
void apply(ir::node_ref it, ir::call_pointer const & node, types::type_ptr const & type)
{
    // TODO: struct/array arguments?
    std::uint8_t next_reg = 0;
    std::uint8_t next_fp_reg = 0;
    for (auto const & argument : node.arguments)
    {
        auto const & argument_type = *argument->inferred_type;
        if (types::is_integer_like_type(argument_type))
        {
            load(argument, next_reg);
            ++next_reg;
        }
        else if (types::is_floating_point_type(argument_type))
        {
            load_fp(argument, next_fp_reg, fp_mode_for(argument_type));
            ++next_fp_reg;
        }
        else
        {
            throw std::runtime_error("Unsupported function argument type");
        }
    }

    std::uint8_t const callee = next_reg;
    load(node.pointer, callee);

    // NOTE(review): register 30 is saved around the call unconditionally
    // here, whereas the direct-call overload only does so when no frame
    // pointer is used - confirm whether the difference is intended.
    builder.sub_imm(31, 31, 16);
    builder.str(30, 31, 0);
    builder.bl_reg(callee);
    builder.ldr(30, 31, 0);
    builder.add_imm(31, 31, 16);

    // TODO: struct/array return value?
    if (types::is_unit_type(*type))
        return;

    if (types::is_integer_like_type(*type))
        store(it, 0);
    else if (types::is_floating_point_type(*type))
        store_fp(it, 0, fp_mode_for(*type));
    else
        throw std::runtime_error("Unsupported return value type");
}
// Function epilogue: move the optional return value into register 0
// (general or floating-point, depending on its type), restore registers
// 29/30 from the frame when a frame pointer is used, release the stack
// frame and return.
// Throws std::runtime_error for unsupported return value types.
void apply(ir::node_ref, ir::return_value const & node, types::type_ptr const &)
{
    // TODO: struct/array return value?
    if (node.value)
    {
        auto const & result = *node.value;
        auto const result_type = result->inferred_type;
        if (types::is_integer_like_type(*result_type))
            load(result, 0);
        else if (types::is_floating_point_type(*result_type))
            load_fp(result, 0, fp_mode_for(*result_type));
        else
            throw std::runtime_error("Unsupported return value type");
    }

    if (lcontext.use_frame_pointer)
    {
        // The saved pair occupies the topmost 16 bytes of the frame
        auto const saved_pair = stack_size - 16;
        builder.ldr(29, 31, saved_pair / 8);
        builder.ldr(30, 31, (saved_pair + 8) / 8);
    }

    if (stack_size > 0)
        builder.add_imm(31, 31, stack_size);

    builder.ret();
}
// Compile the IR nodes in [begin, end) as one function.
// First every value-producing node is given an 8-byte stack slot (after the
// 16 bytes reserved for registers 29/30 when a frame pointer is used) and
// the frame size is rounded up to a multiple of 16. Then the prologue is
// emitted at the leading label, and every following node is dispatched to
// the matching apply() overload, recording each node's code offset in
// lcontext.nodes on the way.
// Throws std::runtime_error if the first node is not a label.
void compile(ir::node_ref begin, ir::node_ref end)
{
    bool const with_frame = lcontext.use_frame_pointer;

    // Frame layout
    stack_size = with_frame ? 16 : 0;
    for (auto node = begin; node != end; ++node)
    {
        if (!ir::is_value_instruction(node->instruction))
            continue;
        stack_size += 8;
        stack_position[node] = stack_size;
    }
    stack_size = ((stack_size + 15) / 16) * 16;

    if (!std::holds_alternative<ir::label>(begin->instruction))
        throw std::runtime_error("First IR node of a function must be a label");

    // Prologue, attributed to the leading label
    lcontext.nodes[begin] = pcontext.code.size();
    if (stack_size > 0)
        builder.sub_imm(31, 31, stack_size);
    if (with_frame)
    {
        auto const saved_pair = stack_size - 16;
        builder.str(29, 31, saved_pair / 8);
        builder.str(30, 31, (saved_pair + 8) / 8);
        builder.add_imm(31, 29, saved_pair);
    }

    // Body: everything after the leading label
    auto it = begin;
    for (++it; it != end; ++it)
    {
        lcontext.nodes[it] = pcontext.code.size();
        auto const dispatch = [&](auto const & instruction)
        {
            apply(it, instruction, it->inferred_type);
        };
        std::visit(dispatch, it->instruction);
    }
}
private:
// Load the stack slot of @ir into general register @reg. The slot's byte
// offset from register 31 is passed to the builder in 8-byte units.
// Throws std::out_of_range if @ir has no stack slot.
void load(ir::node_ref ir, std::uint8_t reg)
{
    builder.ldr(reg, 31, (stack_size - stack_position.at(ir)) / 8);
}
// Load the stack slot of @ir into floating-point register @reg using the
// given mode. The slot's byte offset from register 31 is passed to the
// builder in units of the value size for that mode.
// Throws std::out_of_range if @ir has no stack slot.
void load_fp(ir::node_ref ir, std::uint8_t reg, std::uint8_t mode)
{
    builder.ldr_fp(reg, mode, 31, (stack_size - stack_position.at(ir)) / fp_size(mode));
}
// Store general register @reg to the stack slot of @ir. The slot's byte
// offset from register 31 is passed to the builder in 8-byte units.
// Throws std::out_of_range if @ir has no stack slot.
void store(ir::node_ref ir, std::uint8_t reg)
{
    builder.str(reg, 31, (stack_size - stack_position.at(ir)) / 8);
}
// Store floating-point register @reg to the stack slot of @ir using the
// given mode. The slot's byte offset from register 31 is passed to the
// builder in units of the value size for that mode.
// Throws std::out_of_range if @ir has no stack slot.
void store_fp(ir::node_ref ir, std::uint8_t reg, std::uint8_t mode)
{
    builder.str_fp(reg, mode, 31, (stack_size - stack_position.at(ir)) / fp_size(mode));
}
// Sign- or zero-extend the register depending on the exact type
void extend(std::uint8_t reg, types::type_ptr const & type)
{
reg_extend_visitor{{}, builder, reg}.apply(*type);
}
};
}
// Compile a whole IR module into pcontext.code.
//  1. Constant data (fp literals, extern symbol slots) is appended first.
//  2. Every symbol of the module is compiled as a function; its start offset
//     is published in pcontext.symbols.
//  3. The entry point offset is published in pcontext.entry_point.
//  4. All recorded branch / conditional branch / adr placeholders are patched
//     with the distance to their target, expressed in 4-byte units.
// Throws std::out_of_range if the entry point or a patch target was never
// compiled.
void compile(program_context & pcontext, ir::module_context const & mcontext)
{
    local_context lcontext;
    instruction_builder builder{pcontext.code};

    // 1. Constant data
    populate_const_data_visitor const_data{pcontext, lcontext};
    for (auto const & node : *mcontext.nodes)
    {
        auto const collect = [&](auto const & instruction)
        {
            const_data.apply(instruction, node.inferred_type);
        };
        std::visit(collect, node.instruction);
    }

    // 2. Functions
    for (auto const & symbol : mcontext.symbols)
    {
        auto const & function = symbol.second;
        pcontext.symbols[symbol.first] = pcontext.code.size();
        compile_visitor visitor{pcontext, mcontext, lcontext, builder};
        visitor.compile(function.begin, function.end);
    }

    // 3. Entry point
    pcontext.entry_point = lcontext.nodes.at(mcontext.entry_point);

    // 4. Patch placeholders
    auto const patch_site = [&](auto const & resolve)
    {
        return pcontext.code.data() + resolve.offset;
    };
    auto const displacement = [&](auto const & resolve)
    {
        return (lcontext.nodes.at(resolve.target) - resolve.offset) / 4;
    };

    for (auto const & resolve : lcontext.branch_resolve)
        builder.b_inject(patch_site(resolve), displacement(resolve));
    for (auto const & resolve : lcontext.cbranch_resolve)
        builder.cb_inject(patch_site(resolve), displacement(resolve));
    // NOTE(review): the adr displacement is also divided by 4 - confirm that
    // adr_inject expects 4-byte units rather than a byte offset.
    for (auto const & resolve : lcontext.adr_resolve)
        builder.adr_inject(patch_site(resolve), displacement(resolve));
}
}