pslang/libs/jit/source/arch/aarch64/compiler.cpp

509 lines
No EOL
12 KiB
C++

#include "pslang/types/type_fwd.hpp"
#include <pslang/jit/arch/aarch64/compiler.hpp>
#include <pslang/jit/arch/aarch64/instruction_builder.hpp>
#include <pslang/jit/executable.hpp>
#include <pslang/ast/expression_visitor.hpp>
#include <pslang/ast/statement_visitor.hpp>
#include <pslang/types/type_visitor.hpp>
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <variant>
#include <vector>
namespace pslang::jit::aarch64
{
namespace
{
// Mutable state shared by the visitors while a module is being compiled.
struct context
{
    // Machine code emitted so far, as raw bytes.
    std::vector<std::uint8_t> code{};
    // Function name -> byte offset of that function's first instruction in `code`.
    std::unordered_map<std::string, std::size_t> code_symbol_table{};
};
// Emits the instruction(s) that widen the value held in general-purpose
// register `reg` from its declared type to a full 64-bit register value.
struct reg_extend_visitor
    : types::const_visitor<reg_extend_visitor>
{
    using const_visitor::apply;

    instruction_builder & builder;
    std::uint8_t reg;

    // No extension is emitted for booleans or floating-point types.
    void apply(types::bool_type const &) {}
    void apply(types::f32_type const &) {}
    void apply(types::f64_type const &) {}

    // Integers narrower than 64 bits are sign-extended (signed T) or
    // zero-extended (unsigned T); 64-bit integers are left untouched.
    template <typename T>
    void apply(types::primitive_type_base<T> const &)
    {
        if constexpr (sizeof(T) != 8)
        {
            constexpr auto bit_width = sizeof(T) * 8;
            if constexpr (std::is_signed_v<T>)
                builder.sbfm(reg, reg, bit_width);
            else if constexpr (std::is_unsigned_v<T>)
                builder.ubfm(reg, reg, bit_width);
        }
    }

    // Any other type kind is not supported yet.
    template <typename T>
    void apply(T const &)
    {
        throw std::runtime_error("Not implemented");
    }
};
// Compiles the body of a single function definition into context.code.
// Expression results are produced in x0; every local variable and every
// spilled temporary lives in its own 16-byte stack slot.
struct compile_function_visitor
    : ast::const_statement_visitor<compile_function_visitor>
    , ast::const_expression_visitor<compile_function_visitor>
{
    using const_statement_visitor::apply;
    using const_expression_visitor::apply;

    // Spelled with an elaborated type specifier: a plain `context & context;`
    // changes the meaning of the name `context` inside this class, which is
    // ill-formed (no diagnostic required) and is diagnosed by GCC.
    struct context & context;
    instruction_builder builder{context.code};

    // Difference, in bytes, between the stack pointer at function entry and the
    // current virtual stack pointer. push/pop always move sp by 16, so the real
    // sp stays 16-byte aligned and equals the virtual one.
    std::uint32_t stack_offset = 0;

    struct variable_data
    {
        // Value of the function-level stack_offset right after the variable was
        // pushed, i.e. the distance from the function-entry stack pointer down
        // to the variable's slot. It is a multiple of 16, so
        // (stack_offset - frame_offset) / 8 is an exact ldr/str slot index.
        std::uint32_t frame_offset;
    };

    struct scope
    {
        std::unordered_map<std::string, variable_data> variables = {};
        // Bytes pushed since this scope was entered that are still on the stack
        std::uint32_t stack_offset = 0;
    };

    // Lexical scopes, innermost last
    std::vector<scope> scopes;

    template <typename Node>
    void apply(Node const &)
    {
        throw std::runtime_error("Not implemented");
    }

    // Booleans are materialised as all-ones (true) or zero (false)
    void apply(ast::bool_literal const & node)
    {
        if (node.value)
            set_m1(0);
        else
            builder.movz(0, 0);
    }

    // Build the literal 16 bits at a time (movz for the low half-word, movk for
    // the non-zero higher ones), then sign-extend negative values of narrow types
    template <typename T>
        requires(std::is_integral_v<T> && !std::is_same_v<T, bool>)
    void apply(ast::primitive_literal_base<T> const & node)
    {
        for (std::size_t i = 0; i < sizeof(T); i += 2)
        {
            if (i == 0)
            {
                builder.movz(0, std::uint64_t(node.value));
            }
            else
            {
                auto val = std::uint16_t(std::uint64_t(node.value) >> (i * 8));
                if (val != 0)
                    builder.movk(0, val, i / 2);
            }
        }
        if constexpr (sizeof(T) < 8 && std::is_signed_v<T>)
        {
            if (node.value < 0)
                builder.sbfm(0, 0, sizeof(T) * 8);
        }
    }

    void apply(ast::identifier const & node)
    {
        auto const & variable = lookup_variable(node.name);
        builder.ldr(0, 31, (stack_offset - variable.frame_offset) / 8);
    }

    void apply(ast::unary_operation const & node)
    {
        // TODO: floating-point
        switch (node.type)
        {
        case ast::unary_operation_type::negation:
            apply(*node.arg1);
            builder.sub_reg(31, 0, 0); // x0 = xzr - x0
            extend(0, node.inferred_type);
            break;
        case ast::unary_operation_type::logical_not:
            apply(*node.arg1);
            builder.or_not_reg(31, 0, 0); // x0 = ~x0 (bools are 0 / all-ones)
            break;
        default:
            throw std::runtime_error("Not implemented");
        }
    }

    void apply(ast::binary_operation const & node)
    {
        // TODO: floating-point
        // Every case starts with evaluate_operands(): arg1 in x1, arg2 in x0.
        // csetm condition codes used: 0000 EQ, 0001 NE, 1000 HI, 1001 LS,
        // 1100 GT, 1101 LE.
        switch (node.type)
        {
        case ast::binary_operation_type::addition:
            evaluate_operands(node);
            builder.add_reg(1, 0, 0);
            extend(0, node.inferred_type);
            break;
        case ast::binary_operation_type::subtraction:
            evaluate_operands(node);
            builder.sub_reg(1, 0, 0);
            extend(0, node.inferred_type);
            break;
        case ast::binary_operation_type::multiplication:
            evaluate_operands(node);
            builder.mul_reg(1, 0, 0);
            extend(0, node.inferred_type);
            break;
        case ast::binary_operation_type::division:
            evaluate_operands(node);
            divide(1, 0, 0, node.inferred_type);
            extend(0, node.inferred_type);
            break;
        case ast::binary_operation_type::remainder:
            // x1 % x0 == x1 - (x1 / x0) * x0, with x2 as scratch. All live values
            // are on the stack or in x0/x1 at this point, so x2 is free.
            evaluate_operands(node);
            divide(1, 0, 2, node.inferred_type);
            builder.mul_reg(2, 0, 2);
            builder.sub_reg(1, 2, 0);
            extend(0, node.inferred_type);
            break;
        case ast::binary_operation_type::logical_and:
            evaluate_operands(node);
            builder.and_reg(1, 0, 0);
            break;
        case ast::binary_operation_type::logical_or:
            evaluate_operands(node);
            builder.or_reg(1, 0, 0);
            break;
        case ast::binary_operation_type::logical_xor:
            evaluate_operands(node);
            builder.xor_reg(1, 0, 0);
            break;
        case ast::binary_operation_type::equals:
            evaluate_operands(node);
            builder.cmp_reg(1, 0);
            builder.csetm(0, 0b0000); // EQ
            break;
        case ast::binary_operation_type::not_equals:
            evaluate_operands(node);
            builder.cmp_reg(1, 0);
            builder.csetm(0, 0b0001); // NE
            break;
        case ast::binary_operation_type::less:
            // arg1 < arg2  <=>  arg2 > arg1
            evaluate_operands(node);
            builder.cmp_reg(0, 1);
            builder.csetm(0, compares_unsigned(node) ? 0b1000 : 0b1100); // HI : GT
            break;
        case ast::binary_operation_type::greater:
            evaluate_operands(node);
            builder.cmp_reg(1, 0);
            builder.csetm(0, compares_unsigned(node) ? 0b1000 : 0b1100); // HI : GT
            break;
        case ast::binary_operation_type::less_equals:
            evaluate_operands(node);
            builder.cmp_reg(1, 0);
            builder.csetm(0, compares_unsigned(node) ? 0b1001 : 0b1101); // LS : LE
            break;
        case ast::binary_operation_type::greater_equals:
            // arg1 >= arg2  <=>  arg2 <= arg1
            evaluate_operands(node);
            builder.cmp_reg(0, 1);
            builder.csetm(0, compares_unsigned(node) ? 0b1001 : 0b1101); // LS : LE
            break;
        default:
            throw std::runtime_error("Not implemented");
        }
    }

    void apply(ast::assignment const & node)
    {
        auto identifier = std::get_if<ast::identifier>(node.lhs.get());
        if (!identifier)
            throw std::runtime_error("Not implemented");
        apply(*node.rhs);
        auto const & variable = lookup_variable(identifier->name);
        builder.str(0, 31, (stack_offset - variable.frame_offset) / 8);
    }

    // The initialised value is pushed and its slot stays reserved until the
    // enclosing scope is cleaned up
    void apply(ast::variable_declaration const & node)
    {
        apply(*node.initializer);
        push(0);
        scopes.back().variables[node.name] = {.frame_offset = stack_offset};
    }

    void apply(ast::if_chain const & node)
    {
        // Offsets of the `b` placeholders that jump past the whole chain
        std::vector<std::size_t> branches_to_end;
        for (std::size_t i = 0; i < node.blocks.size(); ++i)
        {
            auto const & block = node.blocks[i];
            // Offset of the `cbz` placeholder that skips this block
            std::optional<std::size_t> skip_branch;
            if (block.condition)
            {
                apply(*block.condition);
                skip_branch = context.code.size();
                builder.cbz(0, 0);
            }
            compile_scoped(*block.statements);
            if (i + 1 < node.blocks.size())
            {
                branches_to_end.push_back(context.code.size());
                builder.b(0);
            }
            if (skip_branch)
            {
                auto const delta = context.code.size() - *skip_branch;
                builder.cb_inject(context.code.data() + *skip_branch, delta / 4);
            }
        }
        auto const end = context.code.size();
        for (auto const instruction : branches_to_end)
            builder.b_inject(context.code.data() + instruction, (end - instruction) / 4);
    }

    void apply(ast::while_block const & node)
    {
        // Signed offsets: the jump back to the condition is negative
        auto const start = static_cast<std::int32_t>(context.code.size());
        apply(*node.condition);
        auto const skip = static_cast<std::int32_t>(context.code.size());
        builder.cbz(0, 0);
        compile_scoped(*node.statements);
        auto const loop = static_cast<std::int32_t>(context.code.size());
        builder.b(0);
        auto const end = static_cast<std::int32_t>(context.code.size());
        builder.cb_inject(context.code.data() + skip, (end - skip) / 4);
        builder.b_inject(context.code.data() + loop, (start - loop) / 4);
    }

    // Restore sp to its function-entry value and return with the result in x0
    void apply(ast::return_statement const & node)
    {
        apply(*node.value);
        if (stack_offset > 0)
            builder.add_imm(31, 31, stack_offset);
        builder.ret();
    }

    void apply(ast::function_definition const &)
    {
        // Don't handle internal functions
    }

    void apply(ast::statement_list const & node)
    {
        for (auto const & statement : node.statements)
            apply(*statement);
    }

    // Entry point: spill the register arguments, compile the body and append a
    // fall-through epilogue.
    void do_apply(ast::function_definition const & node)
    {
        // TODO: floating-point / struct / stack-passed arguments
        // AAPCS64 passes only the first eight integer arguments in x0-x7.
        if (node.arguments.size() > 8)
            throw std::runtime_error("Not implemented");
        scopes.emplace_back();
        for (std::size_t i = 0; i < node.arguments.size(); ++i)
        {
            auto const reg = static_cast<std::uint8_t>(i);
            auto type = ast::get_type(*node.arguments[i].type);
            if (types::is_bool_type(*type))
            {
                // AAPCS64 makes the callee narrow integral arguments, so only the
                // low byte of a bool is meaningful. Canonicalise to 0 / all-ones.
                builder.ubfm(reg, reg, 8);
                builder.tst(reg, reg);
                builder.csetm(reg, 0b0001); // NE
            }
            else if (types::is_integer_type(*type))
                extend(reg, type);
            push(reg);
            scopes.back().variables[node.arguments[i].name] = {.frame_offset = stack_offset};
        }
        apply(*node.statements);
        // Control reaching the end of the body must not run into whatever code
        // follows; this is dead code when the body already ends in a return.
        if (stack_offset > 0)
            builder.add_imm(31, 31, stack_offset);
        builder.ret();
        scopes.pop_back();
    }

private:
    // Push x<reg> into a new 16-byte stack slot
    void push(std::uint8_t reg)
    {
        builder.sub_imm(31, 31, 16);
        builder.str(reg, 31, 0);
        stack_offset += 16;
        scopes.back().stack_offset += 16;
    }

    // Pop the top 16-byte stack slot into x<reg>
    void pop(std::uint8_t reg)
    {
        builder.ldr(reg, 31, 0);
        builder.add_imm(31, 31, 16);
        stack_offset -= 16;
        scopes.back().stack_offset -= 16;
    }

    // Set register @reg to -1 (all bits = 1)
    void set_m1(std::uint8_t reg)
    {
        builder.or_not_reg(31, 31, reg);
    }

    // Sign- or zero-extend the register depending on the exact type
    void extend(std::uint8_t reg, types::type_ptr const & type)
    {
        reg_extend_visitor{{}, builder, reg}.apply(*type);
    }

    // Release the stack slots still held by the innermost scope and keep the
    // function-level stack_offset in sync with the real stack pointer
    void scope_cleanup()
    {
        auto & released = scopes.back().stack_offset;
        if (released > 0)
        {
            builder.add_imm(31, 31, released);
            stack_offset -= released;
            released = 0;
        }
    }

    // Compile @statements inside a fresh lexical scope, releasing its slots afterwards
    template <typename Statements>
    void compile_scoped(Statements const & statements)
    {
        scopes.emplace_back();
        apply(statements);
        scope_cleanup();
        scopes.pop_back();
    }

    // Innermost visible variable named @name; throws if there is none
    variable_data const & lookup_variable(std::string const & name) const
    {
        for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
        {
            if (auto jt = it->variables.find(name); jt != it->variables.end())
                return jt->second;
        }
        throw std::runtime_error("Unknown identifier: " + name);
    }

    // Evaluate both operands: arg1 ends in x1, arg2 in x0. arg1 is spilled while
    // arg2 is computed because nested binary operations overwrite x1.
    void evaluate_operands(ast::binary_operation const & node)
    {
        apply(*node.arg1);
        push(0);
        apply(*node.arg2);
        pop(1);
    }

    // x<dst> = x<lhs> / x<rhs>, signed or unsigned according to @type
    void divide(std::uint8_t lhs, std::uint8_t rhs, std::uint8_t dst, types::type_ptr const & type)
    {
        if (types::is_signed_integer_type(*type))
            builder.sdiv_reg(lhs, rhs, dst);
        else
            builder.udiv_reg(lhs, rhs, dst);
    }

    // Bools (0 / all-ones) and unsigned integers use unsigned condition codes
    static bool compares_unsigned(ast::binary_operation const & node)
    {
        auto const type = ast::get_type(*node.arg1);
        return types::is_bool_type(*type) || types::is_unsigned_integer_type(*type);
    }
};
// Walks the top-level statements of a module, recording each function's start
// offset in the symbol table and compiling its body with compile_function_visitor.
struct compile_visitor
    : ast::const_statement_visitor<compile_visitor>
{
    using const_statement_visitor::apply;

    // Elaborated type specifier: a plain `context & context;` would change the
    // meaning of the name `context` inside this class (ill-formed, NDR).
    struct context & context;

    // Only function definitions are allowed at the top level for now
    template <typename Statement>
    void apply(Statement const &)
    {
        throw std::runtime_error("Not implemented");
    }

    // Visit every top-level statement. Declared explicitly so that a
    // statement_list never falls into the throwing catch-all above
    // (mirrors compile_function_visitor).
    void apply(ast::statement_list const & node)
    {
        for (auto const & statement : node.statements)
            apply(*statement);
    }

    void apply(ast::function_definition const & node)
    {
        context.code_symbol_table[node.name] = context.code.size();
        compile_function_visitor visitor{{}, {}, context};
        visitor.do_apply(node);
    }
};
}
// Compiles a module's top-level statements into AArch64 machine code and
// packages it, together with its function symbol table, as a compiled_module.
compiled_module compile(ast::statement_list_ptr const & statements)
{
    // Emit code for every top-level definition into a scratch byte buffer
    context ctx;
    compile_visitor visitor{{}, ctx};
    visitor.apply(*statements);

    // Move the bytes into memory obtained from allocate()
    auto memory = allocate(ctx.code.size());
    std::copy(ctx.code.begin(), ctx.code.end(), memory.data.get());

    return compiled_module{
        .data = {},
        .code = {
            .memory = std::move(memory),
            .symbol_table = std::move(ctx.code_symbol_table),
        },
        .entry_point = 0,
        .abi = abi::armv8,
    };
}
}