// pslang/libs/jit/source/arch/aarch64/compiler.cpp
//
// 1447 lines
// 38 KiB
// C++

#include <pslang/jit/arch/aarch64/compiler.hpp>
#include <pslang/jit/arch/aarch64/instruction_builder.hpp>
#include <pslang/jit/executable.hpp>
#include <pslang/ast/expression_visitor.hpp>
#include <pslang/ast/statement_visitor.hpp>
#include <pslang/types/type_visitor.hpp>

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <functional>
#include <memory>
#include <optional>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <typeinfo>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace pslang::jit::aarch64
{
namespace
{
// Description of a homogeneous floating-point aggregate (HFA): a struct that,
// once nested structs are flattened, consists of up to 4 members that all
// share one floating-point type
struct hfa_data
{
    // The floating-point type shared by every flattened member
    types::type_ptr type{};
    // How many flattened members the struct has
    std::size_t count{};
};
struct local_context
{
std::unordered_map<float, std::int32_t> f16_constants;
std::unordered_map<float, std::int32_t> f32_constants;
std::unordered_map<double, std::int32_t> f64_constants;
std::unordered_map<std::string, std::int32_t> foreign_address;
std::unordered_map<ast::function_definition const *, std::int32_t> functions;
struct struct_data
{
std::optional<hfa_data> hfa = {};
};
std::unordered_map<ast::struct_definition const *, struct_data> structs;
struct scope
{
std::unordered_set<std::string> foreign_functions;
std::unordered_map<std::string, ast::function_definition const *> functions;
std::unordered_map<std::string, ast::struct_definition const *> structs;
};
std::vector<scope> scopes;
struct resolve_info
{
ast::function_definition const * node;
// Must be 'adr' instruction
std::int32_t instruction_offset;
};
std::vector<resolve_info> resolve;
bool is_foreign(std::string const & name)
{
for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
{
if (it->foreign_functions.contains(name))
return true;
if (it->functions.contains(name))
return false;
if (it->structs.contains(name))
return false;
}
return false;
}
ast::function_definition const * is_function(std::string const & name)
{
for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
{
if (auto jt = it->functions.find(name); jt != it->functions.end())
return jt->second;
if (it->foreign_functions.contains(name))
return nullptr;
if (it->structs.contains(name))
return nullptr;
}
return nullptr;
}
ast::struct_definition const * is_struct(std::string const & name)
{
for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
{
if (auto jt = it->structs.find(name); jt != it->structs.end())
return jt->second;
if (it->foreign_functions.contains(name))
return nullptr;
if (it->functions.contains(name))
return nullptr;
}
return nullptr;
}
};
// Precision selector passed to instruction_builder's floating-point helpers:
// 1 for f16, 2 for f32 and 3 for any other type (f64 in practice)
std::uint8_t fp_mode_for(types::type const & type)
{
    std::uint8_t mode = 3;
    if (types::equal(type, types::primitive_type(types::f16_type{})))
        mode = 1;
    else if (types::equal(type, types::primitive_type(types::f32_type{})))
        mode = 2;
    return mode;
}
// Whether the right operand of this binary operation is only evaluated
// conditionally (logical and / logical or)
bool is_short_circuiting(ast::binary_operation_type type)
{
    return type == ast::binary_operation_type::logical_and
        || type == ast::binary_operation_type::logical_or;
}
// Add all f16, f32 and f64 constants as read-only data entries
// Add extern pointers for all foreign functions as read-only data entries
//
// Entries are appended to pcontext.code, one per distinct value / foreign
// function name, and their offsets are recorded in lcontext. Each entry is
// zero-padded to its natural alignment (relative to the start of
// pcontext.code), so 8-byte entries (f64 constants, foreign function
// pointers) are not placed at merely 4-byte-aligned offsets.
struct populate_constants_visitor
    : ast::const_expression_visitor<populate_constants_visitor>
    , ast::const_statement_visitor<populate_constants_visitor>
{
    using const_expression_visitor::apply;
    using const_statement_visitor::apply;
    program_context & pcontext;
    local_context & lcontext;
    // Non-floating-point literals are not pooled
    template <typename T>
        requires(!std::is_floating_point_v<T>)
    void apply(ast::primitive_literal_base<T> const &)
    {}
    void apply(ast::f16_literal const & node)
    {
        // The f16 constant is pooled through its 'repr' value
        add_entry(lcontext.f16_constants, node.value.repr, node.value.repr);
    }
    void apply(ast::f32_literal const & node)
    {
        add_entry(lcontext.f32_constants, node.value, node.value);
    }
    void apply(ast::f64_literal const & node)
    {
        add_entry(lcontext.f64_constants, node.value, node.value);
    }
    void apply(ast::identifier const &)
    {}
    void apply(ast::unary_operation const & node)
    {
        apply(*node.arg1);
    }
    void apply(ast::binary_operation const & node)
    {
        apply(*node.arg1);
        apply(*node.arg2);
    }
    void apply(ast::cast_operation const & node)
    {
        apply(*node.expression);
    }
    void apply(ast::function_call const & node)
    {
        if (node.function)
            apply(*node.function);
        for (auto const & argument : node.arguments)
            apply(*argument);
    }
    void apply(ast::array const & node)
    {
        for (auto const & element : node.elements)
            apply(*element);
    }
    void apply(ast::array_access const & node)
    {
        apply(*node.array);
        apply(*node.index);
    }
    void apply(ast::field_access const & node)
    {
        apply(*node.object);
    }
    void apply(ast::expression_ptr const & node)
    {
        apply(*node);
    }
    void apply(ast::assignment const & node)
    {
        apply(*node.lhs);
        apply(*node.rhs);
    }
    void apply(ast::variable_declaration const & node)
    {
        apply(*node.initializer);
    }
    void apply(ast::if_block const &)
    {}
    void apply(ast::else_block const &)
    {}
    void apply(ast::else_if_block const &)
    {}
    void apply(ast::if_chain const & node)
    {
        for (auto const & block : node.blocks)
        {
            if (block.condition)
                apply(*block.condition);
            apply(*block.statements);
        }
    }
    void apply(ast::while_block const & node)
    {
        apply(*node.condition);
        apply(*node.statements);
    }
    void apply(ast::function_definition const & node)
    {
        apply(*node.statements);
    }
    void apply(ast::foreign_function_declaration const & foreign_function_declaration)
    {
        // Null placeholder for the foreign function's address; compile()
        // reports the slot offsets through pcontext.foreign_resolve
        add_entry(lcontext.foreign_address, foreign_function_declaration.name, static_cast<void *>(nullptr));
    }
    void apply(ast::return_statement const & node)
    {
        if (node.value)
            apply(*node.value);
    }
    void apply(ast::field_definition const &)
    {}
    void apply(ast::struct_definition const &)
    {}
private:
    // Append the bytes of @value as a new entry of @pool under @key, unless
    // @key already has an entry
    template <typename Pool, typename Key, typename Value>
    void add_entry(Pool & pool, Key const & key, Value const & value)
    {
        if (pool.contains(key))
            return;
        align_to(sizeof(Value));
        pool[key] = static_cast<std::int32_t>(pcontext.code.size());
        push_bytes(value);
    }
    // Zero-pad pcontext.code to a multiple of @alignment bytes
    void align_to(std::size_t alignment)
    {
        auto const size = pcontext.code.size();
        auto const misalignment = size % alignment;
        if (misalignment != 0)
            pcontext.code.resize(size + (alignment - misalignment));
    }
    // Append the object representation of @value to pcontext.code
    template <typename T>
    void push_bytes(T const & value)
    {
        auto begin = reinterpret_cast<std::uint8_t const *>(&value);
        auto end = begin + sizeof(value);
        pcontext.code.insert(pcontext.code.end(), begin, end);
    }
};
// Classify a struct as a homogeneous floating-point aggregate (HFA): after
// recursively flattening nested structs it must consist of 1 to 4 members,
// all of the same floating-point type. Returns that type and the member
// count, or std::nullopt if the struct is not an HFA. Results already cached
// in lcontext.structs are returned as-is; this function never fills the
// cache itself.
std::optional<hfa_data> get_hfa_data(ast::struct_definition const & node, local_context & lcontext)
{
    if (auto it = lcontext.structs.find(&node); it != lcontext.structs.end())
        return it->second.hfa;
    types::type_ptr type = nullptr;
    std::size_t count = 0;
    for (auto const & field : node.fields)
    {
        // Flattened member type and member count contributed by this field
        types::type_ptr member_type = nullptr;
        std::size_t member_count = 0;
        if (types::is_builtin_type(*field.inferred_type))
        {
            if (!types::is_floating_point_type(*field.inferred_type))
                return std::nullopt;
            member_type = field.inferred_type;
            member_count = 1;
        }
        else if (auto struct_type = std::get_if<types::struct_type>(field.inferred_type.get()))
        {
            // NB: recursion must be impossible due to prior checks in type checker
            auto subdata = get_hfa_data(*struct_type->node, lcontext);
            if (!subdata)
                return std::nullopt;
            member_type = subdata->type;
            member_count = subdata->count;
        }
        else
            return std::nullopt;
        if (type && !types::equal(*type, *member_type))
            return std::nullopt;
        type = member_type;
        count += member_count;
        if (count > 4)
            return std::nullopt;
    }
    // A struct without members (count == 0) has no common member type and
    // is not an HFA; reporting it as one would hand callers a null 'type'
    if (count == 0)
        return std::nullopt;
    return hfa_data{type, count};
}
// Register the functions, foreign functions and structs declared directly in
// one statement list into the innermost lcontext scope. Nested bodies
// (if/while blocks, function bodies) are deliberately not descended into.
struct populate_symbols_visitor
    : ast::const_statement_visitor<populate_symbols_visitor>
{
    local_context & lcontext;
    using const_statement_visitor::apply;

    // Statements that declare no symbols in the current scope
    void apply(ast::expression_ptr const &) {}
    void apply(ast::assignment const &) {}
    void apply(ast::variable_declaration const &) {}
    void apply(ast::if_block const &) {}
    void apply(ast::else_block const &) {}
    void apply(ast::else_if_block const &) {}
    void apply(ast::if_chain const &) {}
    void apply(ast::while_block const &) {}
    void apply(ast::return_statement const &) {}
    void apply(ast::field_definition const &) {}

    void apply(ast::function_definition const & node)
    {
        auto & current = lcontext.scopes.back();
        current.functions[node.name] = &node;
    }
    void apply(ast::foreign_function_declaration const & node)
    {
        auto & current = lcontext.scopes.back();
        current.foreign_functions.insert(node.name);
    }
    void apply(ast::struct_definition const & node)
    {
        lcontext.scopes.back().structs[node.name] = &node;
        if (lcontext.structs.contains(&node))
            return;
        // get_hfa_data consults lcontext.structs as a cache, so the entry for
        // this struct may only be created after its HFA data is computed
        auto computed = get_hfa_data(node, lcontext);
        lcontext.structs[&node].hfa = computed;
    }
};
// Sign- or zero-extend the low bits of register @reg to the full 64 bits,
// according to the width and signedness of the visited primitive type.
// bool, f32, f64, 64-bit types and element types that are neither signed
// nor unsigned need no extension; any other type kind is rejected.
struct reg_extend_visitor
    : types::const_visitor<reg_extend_visitor>
{
    using const_visitor::apply;
    instruction_builder & builder;
    std::uint8_t reg;
    void apply(types::bool_type const &) {}
    void apply(types::f32_type const &) {}
    void apply(types::f64_type const &) {}
    template <typename T>
    void apply(types::primitive_type_base<T> const &)
    {
        constexpr auto width = sizeof(T) * 8;
        if constexpr (sizeof(T) == 8)
        {
            // Already occupies the whole register
        }
        else if constexpr (std::is_signed_v<T>)
            builder.sbfm(reg, reg, width);
        else if constexpr (std::is_unsigned_v<T>)
            builder.ubfm(reg, reg, width);
    }
    template <typename T>
    void apply(T const &)
    {
        throw std::runtime_error(std::string("reg_extend_visitor is not implemented for ") + typeid(T).name());
    }
};
// Compile a single function and store the entry point offset
// in local_context
//
// Conventions of the generated code (as implemented below):
//  * an expression leaves its value in x0 (bool, integer and function
//    values) or in v0 (floating-point values); a struct value is instead
//    materialized on the stack, in a block at sp whose size is the struct
//    size rounded up to 16 bytes
//  * binary operations hold the left operand in x1 / v1 and the right
//    operand in x0 / v0
//  * every push, pop and stack allocation moves sp by a multiple of 16
//  * bool values are 0 (false) or all bits set (true)
//  * scalar local variables occupy one 16-byte slot each; integer-class
//    variables are read back with a full 64-bit load
struct compile_function_visitor
    : ast::const_statement_visitor<compile_function_visitor>
    , ast::const_expression_visitor<compile_function_visitor>
{
    using const_statement_visitor::apply;
    using const_expression_visitor::apply;
    program_context & pcontext;
    local_context & lcontext;
    instruction_builder builder{pcontext.code};
    // Difference between initial stack pointer at function enter
    // and current virtual stack pointer value. The actual stack pointer
    // value is rounded down to a multiple of 16
    std::uint32_t stack_offset = 0;
    struct variable_data
    {
        // Difference between initial stack pointer at function enter
        // and the variable address
        // Must be a multiple of 16
        std::uint32_t frame_offset;
    };
    struct scope
    {
        std::unordered_map<std::string, variable_data> variables = {};
        // Difference between initial virtual stack pointer at scope enter
        // and current virtual stack pointer value
        std::uint32_t stack_offset = 0;
    };
    // Variable scopes of the function being compiled, innermost last
    std::vector<scope> scopes;

    // Fallback for node kinds without code generation
    template <typename Node>
    void apply(Node const &)
    {
        throw std::runtime_error(std::string("compile_function_visitor is not implemented for ") + typeid(Node).name());
    }
    void apply(ast::bool_literal const & node)
    {
        if (node.value)
            set_m1(0);
        else
            builder.movz(0, 0);
    }
    // Integer literal: built 16 bits at a time with movz/movk (all-zero upper
    // chunks are skipped), then negative values of types narrower than 64 bits
    // are sign-extended to the canonical register form
    template <typename T>
        requires(std::is_integral_v<T> && !std::is_same_v<T, bool>)
    void apply(ast::primitive_literal_base<T> const & node)
    {
        auto const bits = std::uint64_t(node.value);
        builder.movz(0, bits);
        for (std::size_t i = 2; i < sizeof(T); i += 2)
        {
            auto const chunk = std::uint16_t(bits >> (i * 8));
            if (chunk != 0)
                builder.movk(0, chunk, i / 2);
        }
        if constexpr (std::is_signed_v<T> && sizeof(T) < 8)
        {
            if (node.value < 0)
                builder.sbfm(0, 0, sizeof(T) * 8);
        }
    }
    // Floating-point literals are loaded pc-relative from the pool entries
    // created by populate_constants_visitor
    void apply(ast::f16_literal const & node)
    {
        auto offset = lcontext.f16_constants.at(node.value.repr);
        std::int32_t current = pcontext.code.size();
        // Loaded with precision 0b10 (f32) and converted to 0b01 (f16)
        builder.ldr_fp_pc(0, 0, (offset - current) / 4);
        builder.fcvt(0, 0b10, 0, 0b01);
    }
    void apply(ast::f32_literal const & node)
    {
        auto offset = lcontext.f32_constants.at(node.value);
        std::int32_t current = pcontext.code.size();
        builder.ldr_fp_pc(0, 0, (offset - current) / 4);
    }
    void apply(ast::f64_literal const & node)
    {
        auto offset = lcontext.f64_constants.at(node.value);
        std::int32_t current = pcontext.code.size();
        builder.ldr_fp_pc(0, 1, (offset - current) / 4);
    }
    // Identifier: a local variable (innermost declaration wins), else a
    // foreign function (address read from its pool slot), else a compiled
    // function (address patched into an 'adr' by compile())
    void apply(ast::identifier const & node)
    {
        if (auto variable = find_variable(node.name))
        {
            auto const & type = node.inferred_type;
            if (auto struct_type = std::get_if<types::struct_type>(type.get()))
            {
                // Copy the whole struct into a fresh stack block, 16 bytes at a time
                std::size_t const stack_size = round_up_16(struct_type->node->layout.size);
                allocate(stack_size);
                std::size_t const source = stack_offset - variable->frame_offset;
                for (std::size_t offset = 0; offset < stack_size; offset += 16)
                {
                    builder.ldr(0, 31, (source + offset) / 8);
                    builder.ldr(1, 31, (source + offset) / 8 + 1);
                    builder.str(0, 31, offset / 8);
                    builder.str(1, 31, offset / 8 + 1);
                }
            }
            else if (types::is_unit_type(*type))
            {}
            else if (types::is_floating_point_type(*type))
                builder.ldr_fp(0, fp_mode_for(*type), 31, (stack_offset - variable->frame_offset) / types::type_size(*type));
            else
                builder.ldr(0, 31, (stack_offset - variable->frame_offset) / 8);
            return;
        }
        if (lcontext.is_foreign(node.name))
        {
            builder.ldr_pc(0, (lcontext.foreign_address.at(node.name) - static_cast<std::int32_t>(pcontext.code.size())) / 4);
        }
        else if (auto function_node = lcontext.is_function(node.name))
        {
            lcontext.resolve.push_back({function_node, static_cast<std::int32_t>(pcontext.code.size())});
            builder.adr(0, 0);
        }
        else
        {
            throw std::runtime_error("unknown identifier \"" + node.name + "\"");
        }
    }
    void apply(ast::unary_operation const & node)
    {
        switch (node.type)
        {
        case ast::unary_operation_type::negation:
            apply(*node.arg1);
            if (types::is_integer_type(*node.inferred_type))
            {
                // x0 = 0 - x0 (register 31 encodes the zero register here)
                builder.sub_reg(31, 0, 0);
                extend(0, node.inferred_type);
            }
            else if (types::is_floating_point_type(*node.inferred_type))
            {
                builder.fneg(0, fp_mode_for(*node.inferred_type), 0);
            }
            break;
        case ast::unary_operation_type::logical_not:
            apply(*node.arg1);
            // x0 = ~x0, which maps the bool values 0 and all-ones onto each other
            builder.or_not_reg(31, 0, 0);
            if (types::is_integer_type(*node.inferred_type))
                extend(0, node.inferred_type);
            break;
        }
    }
    void apply(ast::binary_operation const & node)
    {
        auto arg1_type = ast::get_type(*node.arg1);
        bool const is_fp = types::is_floating_point_type(*arg1_type);
        std::uint8_t const fp_mode = fp_mode_for(*arg1_type);
        // bool and unsigned operands are compared with unsigned condition codes
        bool const is_unsigned = types::is_bool_type(*arg1_type) || types::is_unsigned_integer_type(*arg1_type);
        auto const not_implemented = [&node] {
            std::ostringstream os;
            os << "binary operation " << node.type << " is not implemented";
            return std::runtime_error(os.str());
        };
        apply(*node.arg1);
        if (!is_short_circuiting(node.type))
        {
            // Left operand ends up in x1 / v1, right operand in x0 / v0
            if (is_fp)
            {
                push_fp(0, fp_mode);
                apply(*node.arg2);
                pop_fp(1, fp_mode);
            }
            else
            {
                push(0);
                apply(*node.arg2);
                pop(1);
            }
        }
        // Condition codes used with csetm below:
        // 0000 EQ, 0001 NE, 0100 MI, 1000 HI, 1001 LS, 1100 GT, 1101 LE
        switch (node.type)
        {
        case ast::binary_operation_type::addition:
            if (is_fp)
                builder.fadd(1, 0, fp_mode, 0);
            else
            {
                builder.add_reg(1, 0, 0);
                extend(0, node.inferred_type);
            }
            break;
        case ast::binary_operation_type::subtraction:
            if (is_fp)
                builder.fsub(1, 0, fp_mode, 0);
            else
            {
                builder.sub_reg(1, 0, 0);
                extend(0, node.inferred_type);
            }
            break;
        case ast::binary_operation_type::multiplication:
            if (is_fp)
                builder.fmul(1, 0, fp_mode, 0);
            else
            {
                builder.mul_reg(1, 0, 0);
                extend(0, node.inferred_type);
            }
            break;
        case ast::binary_operation_type::division:
            if (is_fp)
                builder.fdiv(1, 0, fp_mode, 0);
            else
            {
                if (types::is_signed_integer_type(*node.inferred_type))
                    builder.sdiv_reg(1, 0, 0);
                else
                    builder.udiv_reg(1, 0, 0);
                extend(0, node.inferred_type);
            }
            break;
        case ast::binary_operation_type::remainder:
            // x2 = x1 / x0; x0 = x1 - x2 * x0
            if (types::is_signed_integer_type(*node.inferred_type))
            {
                builder.sdiv_reg(1, 0, 2);
                builder.mul_reg(0, 2, 0);
                builder.sub_reg(1, 0, 0);
            }
            else if (types::is_unsigned_integer_type(*node.inferred_type))
            {
                builder.udiv_reg(1, 0, 2);
                builder.mul_reg(0, 2, 0);
                builder.sub_reg(1, 0, 0);
            }
            else
            {
                // No remainder sequence for other (e.g. floating-point) types;
                // previously this silently yielded the right operand
                throw not_implemented();
            }
            break;
        case ast::binary_operation_type::binary_and:
            builder.and_reg(1, 0, 0);
            break;
        case ast::binary_operation_type::logical_and:
        {
            // Skip the right operand when the left one is false (x0 == 0)
            std::int32_t start = pcontext.code.size();
            builder.cbz(0, 0);
            push(0);
            apply(*node.arg2);
            pop(1);
            builder.and_reg(1, 0, 0);
            std::int32_t end = pcontext.code.size();
            builder.cb_inject(pcontext.code.data() + start, (end - start) / 4);
        }
            break;
        case ast::binary_operation_type::binary_or:
            builder.or_reg(1, 0, 0);
            break;
        case ast::binary_operation_type::logical_or:
        {
            // x1 = ~x0 (within the operand type); skip the right operand when
            // that is zero, i.e. when the left one is true (all ones)
            set_m1(1);
            extend(1, arg1_type);
            builder.xor_reg(0, 1, 1);
            std::int32_t start = pcontext.code.size();
            builder.cbz(1, 0);
            push(0);
            apply(*node.arg2);
            pop(1);
            builder.or_reg(1, 0, 0);
            std::int32_t end = pcontext.code.size();
            builder.cb_inject(pcontext.code.data() + start, (end - start) / 4);
        }
            break;
        case ast::binary_operation_type::logical_xor:
            builder.xor_reg(1, 0, 0);
            break;
        case ast::binary_operation_type::equals:
            if (is_fp)
                builder.fcmp(1, 0, fp_mode);
            else
                builder.cmp_reg(1, 0);
            builder.csetm(0, 0b0000);
            break;
        case ast::binary_operation_type::not_equals:
            if (is_fp)
                builder.fcmp(1, 0, fp_mode);
            else
                builder.cmp_reg(1, 0);
            builder.csetm(0, 0b0001);
            break;
        case ast::binary_operation_type::less:
            if (is_fp)
            {
                builder.fcmp(1, 0, fp_mode);
                builder.csetm(0, 0b0100);
            }
            else
            {
                // rhs > lhs
                builder.cmp_reg(0, 1);
                builder.csetm(0, is_unsigned ? 0b1000 : 0b1100);
            }
            break;
        case ast::binary_operation_type::greater:
            if (is_fp)
            {
                // rhs < lhs
                builder.fcmp(0, 1, fp_mode);
                builder.csetm(0, 0b0100);
            }
            else
            {
                builder.cmp_reg(1, 0);
                builder.csetm(0, is_unsigned ? 0b1000 : 0b1100);
            }
            break;
        case ast::binary_operation_type::less_equals:
            if (is_fp)
            {
                builder.fcmp(1, 0, fp_mode);
                builder.csetm(0, 0b1001);
            }
            else
            {
                builder.cmp_reg(1, 0);
                builder.csetm(0, is_unsigned ? 0b1001 : 0b1101);
            }
            break;
        case ast::binary_operation_type::greater_equals:
            if (is_fp)
            {
                // rhs <= lhs
                builder.fcmp(0, 1, fp_mode);
                builder.csetm(0, 0b1001);
            }
            else
            {
                builder.cmp_reg(0, 1);
                builder.csetm(0, is_unsigned ? 0b1001 : 0b1101);
            }
            break;
        default:
            throw not_implemented();
        }
    }
    void apply(ast::cast_operation const & node)
    {
        auto src_type = ast::get_type(*node.expression);
        auto const & dst_type = node.inferred_type;
        apply(*node.expression);
        if (types::equal(*src_type, *dst_type))
            return;
        if (types::is_integer_type(*src_type))
        {
            if (types::is_integer_type(*dst_type))
            {
                extend(0, dst_type);
            }
            else if (types::is_floating_point_type(*dst_type))
            {
                // Convert the (already 64-bit extended) integer to f64 first,
                // then narrow to the destination precision if needed
                auto dst_mode = fp_mode_for(*dst_type);
                if (types::is_signed_integer_type(*src_type))
                {
                    builder.fmov(0, 0, 3, 1);
                    builder.scvtf(0, 0, 3);
                    if (dst_mode != 3)
                        builder.fcvt(0, 3, 0, dst_mode);
                }
                else if (types::is_unsigned_integer_type(*src_type))
                {
                    builder.fmov(0, 0, 3, 1);
                    builder.ucvtf(0, 0, 3);
                    if (dst_mode != 3)
                        builder.fcvt(0, 3, 0, dst_mode);
                }
            }
        }
        else if (types::is_floating_point_type(*src_type))
        {
            auto src_mode = fp_mode_for(*src_type);
            if (types::is_integer_type(*dst_type))
            {
                // NOTE(review): FCVTNS/FCVTNU round to nearest (ties to even)
                // instead of truncating toward zero; confirm this is the
                // intended cast semantics
                if (types::is_signed_integer_type(*dst_type))
                {
                    builder.fcvtns(0, 0, src_mode);
                    extend(0, dst_type);
                }
                else if (types::is_unsigned_integer_type(*dst_type))
                {
                    builder.fcvtnu(0, 0, src_mode);
                    extend(0, dst_type);
                }
            }
            else if (types::is_floating_point_type(*dst_type))
            {
                auto dst_mode = fp_mode_for(*dst_type);
                builder.fcvt(0, src_mode, 0, dst_mode);
            }
        }
    }
    void apply(ast::function_call const & node)
    {
        if (node.function)
            compile_call(node);
        else // if (node.type)
            compile_construction(node);
    }
    // Field read: the object is materialized at sp, the field is loaded from
    // it and the object's stack block is released again
    void apply(ast::field_access const & node)
    {
        auto object_type = ast::get_type(*node.object);
        auto struct_type = std::get_if<types::struct_type>(object_type.get());
        if (!struct_type)
            throw std::runtime_error("Unknown object in field access");
        auto & struct_node = *struct_type->node;
        auto field_id = find_field(struct_node, node.field_name);
        if (!field_id)
            throw std::runtime_error("Unknown field \"" + node.field_name + "\" in struct \"" + struct_node.name + "\"");
        auto const & field = struct_node.fields[*field_id];
        auto const & field_type = *field.inferred_type;
        if (std::get_if<types::struct_type>(field.inferred_type.get()))
        {
            // TODO: copy the struct-typed field on stack, overriding
            // the struct itself, and update the stack offset
            throw std::runtime_error("Not implemented");
        }
        apply(*node.object);
        if (types::is_unit_type(field_type))
        {}
        else if (types::is_floating_point_type(field_type))
            builder.ldur_fp(0, fp_mode_for(field_type), 31, field.layout.offset);
        else if (is_gp_type(field_type))
        {
            load_sized(types::type_size(field_type), field.layout.offset);
            // Narrow loads zero-extend: bring the value back to its canonical
            // register form (sign extension, bools as 0 / all ones)
            if (types::is_bool_type(field_type))
            {
                builder.tst(0, 0);
                builder.csetm(0, 0b0001);
            }
            else if (types::is_integer_type(field_type))
                extend(0, field.inferred_type);
        }
        release(round_up_16(struct_node.layout.size));
    }
    // Expression statement: drop whatever the expression left on the stack
    // (a struct result)
    void apply(ast::expression_ptr const & node)
    {
        auto const stack_offset_before = stack_offset;
        apply(*node);
        if (stack_offset > stack_offset_before)
            release(stack_offset - stack_offset_before);
    }
    void apply(ast::assignment const & node)
    {
        auto frame_offset = lvalue_offset(node.lhs);
        // A plain variable owns a whole 8-byte slot that is read back with a
        // 64-bit load, so it must be written in full; struct fields are
        // written with their exact size
        bool const is_variable = std::get_if<ast::identifier>(node.lhs.get()) != nullptr;
        apply(*node.rhs);
        auto type = ast::get_type(*node.rhs);
        std::size_t const sp_offset = stack_offset - frame_offset;
        if (types::is_unit_type(*type))
        {}
        else if (types::is_floating_point_type(*type))
            builder.str_fp(0, fp_mode_for(*type), 31, sp_offset / types::type_size(*type));
        else if (is_gp_type(*type))
        {
            if (is_variable)
                builder.str(0, 31, sp_offset / 8);
            else
            {
                // NOTE(review): the stur* forms only encode offsets below 256
                store_sized(types::type_size(*type), sp_offset);
            }
        }
        else if (std::get_if<types::struct_type>(type.get()))
        {
            // TODO: whole-struct assignment
            throw std::runtime_error("Not implemented");
        }
    }
    void apply(ast::variable_declaration const & node)
    {
        apply(*node.initializer);
        auto type = ast::get_type(*node.initializer);
        if (std::get_if<types::struct_type>(type.get()))
        {
            // Nothing to be done: the struct is already on the stack
            // Just record the stack offset as variable location
        }
        else if (types::is_floating_point_type(*type))
            push_fp(0, fp_mode_for(*type));
        else if (types::is_unit_type(*type))
        {
            // Nothing to be done: unit type has zero size
            // Its stack position is recorded but de facto unused
        }
        else
            push(0);
        scopes.back().variables[node.name] = {.frame_offset = stack_offset};
    }
    void apply(ast::if_chain const & node)
    {
        // Offsets of the 'b' instructions that jump past the whole chain
        std::vector<std::size_t> branch_to_end;
        for (std::size_t i = 0; i < node.blocks.size(); ++i)
        {
            auto const & block = node.blocks[i];
            std::optional<std::size_t> branch_skip;
            if (block.condition)
            {
                apply(*block.condition);
                branch_skip = pcontext.code.size();
                builder.cbz(0, 0);
            }
            scopes.emplace_back();
            apply(*block.statements);
            scope_cleanup();
            scopes.pop_back();
            if (i + 1 < node.blocks.size())
            {
                branch_to_end.push_back(pcontext.code.size());
                builder.b(0);
            }
            if (branch_skip)
            {
                auto branch_offset = pcontext.code.size() - *branch_skip;
                builder.cb_inject(pcontext.code.data() + *branch_skip, branch_offset / 4);
            }
        }
        auto end = pcontext.code.size();
        for (auto instruction : branch_to_end)
        {
            auto delta = end - instruction;
            builder.b_inject(pcontext.code.data() + instruction, delta / 4);
        }
    }
    void apply(ast::while_block const & node)
    {
        std::int32_t start = pcontext.code.size();
        apply(*node.condition);
        // Leave the loop once the condition is false (patched below)
        std::int32_t skip = pcontext.code.size();
        builder.cbz(0, 0);
        scopes.emplace_back();
        apply(*node.statements);
        scope_cleanup();
        scopes.pop_back();
        // Jump back to re-evaluate the condition (patched below)
        std::int32_t loop = pcontext.code.size();
        builder.b(0);
        std::int32_t end = pcontext.code.size();
        builder.cb_inject(pcontext.code.data() + skip, (end - skip) / 4);
        builder.b_inject(pcontext.code.data() + loop, (start - loop) / 4);
    }
    void apply(ast::return_statement const & node)
    {
        // TODO: struct return value
        if (node.value)
            apply(*node.value);
        do_return();
    }
    void apply(ast::function_definition const &)
    {
        // Must be handled prior to that in populate_symbols_visitor
    }
    void apply(ast::foreign_function_declaration const &)
    {
        // Must be handled prior to that in populate_symbols_visitor
    }
    void apply(ast::struct_definition const &)
    {
        // Must be handled prior to that in populate_symbols_visitor
    }
    void apply(ast::statement_list const & node)
    {
        lcontext.scopes.emplace_back();
        populate_symbols_visitor{{}, lcontext}.apply(node);
        for (auto const & statement : node.statements)
            apply(*statement);
        lcontext.scopes.pop_back();
    }
    // Compile function @node: record its entry offset, spill the register
    // arguments into stack slots, compile the body and terminate it with a
    // return
    void do_apply(ast::function_definition const & node)
    {
        lcontext.functions[&node] = pcontext.code.size();
        // TODO: struct arguments
        scopes.emplace_back();
        // Arguments arrive in x0, x1, ... / v0, v1, ... in declaration order
        std::uint8_t reg = 0;
        std::uint8_t fp_reg = 0;
        for (auto const & argument : node.arguments)
        {
            auto type = ast::get_type(*argument.type);
            if (types::is_bool_type(*type))
            {
                // Normalize any non-zero value to all ones
                builder.tst(reg, reg);
                builder.csetm(reg, 0b0001);
                push(reg);
                ++reg;
            }
            else if (types::is_integer_type(*type))
            {
                extend(reg, type);
                push(reg);
                ++reg;
            }
            else if (types::is_function_type(*type))
            {
                // Function values are plain 64-bit addresses
                push(reg);
                ++reg;
            }
            else if (types::is_floating_point_type(*type))
            {
                push_fp(fp_reg, fp_mode_for(*type));
                ++fp_reg;
            }
            scopes.back().variables[argument.name] = {.frame_offset = stack_offset};
        }
        apply(*node.statements);
        // Unless the body already ends with a return, emit one: it is the
        // normal exit of unit functions and keeps any other function from
        // running off its end into the code that follows
        auto const & body = node.statements->statements;
        if (body.empty() || !std::get_if<ast::return_statement>(body.back().get()))
            do_return();
        scopes.pop_back();
    }
    // Release everything allocated since function entry and return
    void do_return()
    {
        if (stack_offset > 0)
            builder.add_imm(31, 31, stack_offset);
        builder.ret();
    }
private:
    // Whether values of this type are kept in general-purpose registers
    static bool is_gp_type(types::type const & type)
    {
        return types::is_bool_type(type) || types::is_integer_type(type) || types::is_function_type(type);
    }
    // Size of the 16-byte-aligned stack block that holds @size bytes
    static std::size_t round_up_16(std::size_t size)
    {
        return ((size + 15) / 16) * 16;
    }
    // Index of the field called @name in @definition, if any
    static std::optional<std::size_t> find_field(ast::struct_definition const & definition, std::string const & name)
    {
        for (std::size_t i = 0; i < definition.fields.size(); ++i)
            if (definition.fields[i].name == name)
                return i;
        return std::nullopt;
    }
    // Innermost visible local variable called @name, or nullptr
    variable_data const * find_variable(std::string const & name) const
    {
        for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
            if (auto jt = it->variables.find(name); jt != it->variables.end())
                return &jt->second;
        return nullptr;
    }
    // Move sp down by @size bytes and account for it
    void allocate(std::size_t size)
    {
        builder.sub_imm(31, 31, size);
        stack_offset += size;
        scopes.back().stack_offset += size;
    }
    // Move sp up by @size bytes and account for it
    void release(std::size_t size)
    {
        builder.add_imm(31, 31, size);
        stack_offset -= size;
        scopes.back().stack_offset -= size;
    }
    // Zero-extending load of a @size-byte value at [sp + offset] into x0
    void load_sized(std::size_t size, std::size_t offset)
    {
        switch (size)
        {
        case 1: builder.ldurb(0, 31, offset); break;
        case 2: builder.ldurh(0, 31, offset); break;
        case 4: builder.ldurw(0, 31, offset); break;
        case 8: builder.ldur(0, 31, offset); break;
        default: break;
        }
    }
    // Store of the low @size bytes of x0 to [sp + offset]
    void store_sized(std::size_t size, std::size_t offset)
    {
        switch (size)
        {
        case 1: builder.sturb(0, 31, offset); break;
        case 2: builder.sturh(0, 31, offset); break;
        case 4: builder.sturw(0, 31, offset); break;
        case 8: builder.stur(0, 31, offset); break;
        default: break;
        }
    }
    // Call through a function value. The callee and the arguments are
    // evaluated onto the stack, then popped into x0.. / v0.. and the callee
    // is invoked with x30 saved around the call.
    // NOTE(review): there is no limit on the argument count; beyond 8 values
    // of either register class this no longer matches AAPCS64
    void compile_call(ast::function_call const & node)
    {
        apply(*node.function);
        push(0);
        // Arguments are pushed right to left so they pop left to right
        for (std::size_t i = node.arguments.size(); i-- > 0;)
        {
            auto const & arg = node.arguments[i];
            auto type = ast::get_type(*arg);
            if (std::get_if<types::struct_type>(type.get()))
            {
                // TODO: struct arguments (the struct value would stay on the
                // stack and break the push/pop pairing below)
                throw std::runtime_error("Not implemented");
            }
            apply(*arg);
            if (is_gp_type(*type))
                push(0);
            else if (types::is_floating_point_type(*type))
                push_fp(0, fp_mode_for(*type));
        }
        std::uint8_t reg = 0;
        std::uint8_t fp_reg = 0;
        for (auto const & arg : node.arguments)
        {
            auto type = ast::get_type(*arg);
            if (is_gp_type(*type))
            {
                pop(reg);
                ++reg;
            }
            else if (types::is_floating_point_type(*type))
            {
                pop_fp(fp_reg, fp_mode_for(*type));
                ++fp_reg;
            }
        }
        // The callee address was pushed first, so it comes off last
        pop(reg);
        push(30);
        builder.b_reg(reg);
        pop(30);
    }
    // Construct a value of the call's type: builtin types are
    // zero-initialized, structs are built on the stack field by field from
    // the arguments
    void compile_construction(ast::function_call const & node)
    {
        auto const & type = node.inferred_type;
        if (types::is_unit_type(*type))
        {
            // Do nothing
        }
        else if (is_gp_type(*type))
        {
            builder.xor_reg(0, 0, 0);
        }
        else if (types::is_floating_point_type(*type))
        {
            builder.xor_reg(0, 0, 0);
            builder.fmov(0, 0, fp_mode_for(*type), 1);
        }
        else if (auto struct_type = std::get_if<types::struct_type>(type.get()))
        {
            auto & struct_node = *struct_type->node;
            // Allocate stack space for the struct
            allocate(round_up_16(struct_node.layout.size));
            // Evaluate each field of the struct (i.e. each constructor argument)
            // and copy it to the corresponding place in the struct
            for (std::size_t i = 0; i < node.arguments.size(); ++i)
            {
                auto arg_type = ast::get_type(*node.arguments[i]);
                auto const field_offset = struct_node.fields[i].layout.offset;
                if (std::get_if<types::struct_type>(arg_type.get()))
                {
                    // TODO: struct field
                    throw std::runtime_error("Not implemented");
                }
                apply(*node.arguments[i]);
                if (types::is_floating_point_type(*arg_type))
                    builder.stur_fp(0, fp_mode_for(*arg_type), 31, field_offset);
                else
                    store_sized(types::type_size(*arg_type), field_offset);
            }
        }
    }
    void push(std::uint8_t reg)
    {
        allocate(16);
        builder.str(reg, 31, 0);
    }
    void pop(std::uint8_t reg)
    {
        builder.ldr(reg, 31, 0);
        release(16);
    }
    void push_fp(std::uint8_t reg, std::uint8_t mode)
    {
        allocate(16);
        // Store the requested register (previously v0 was always stored,
        // which corrupted every floating-point argument after the first)
        builder.str_fp(reg, mode, 31, 0);
    }
    void pop_fp(std::uint8_t reg, std::uint8_t mode)
    {
        builder.ldr_fp(reg, mode, 31, 0);
        release(16);
    }
    // Set register @reg to -1 (all bits = 1)
    void set_m1(std::uint8_t reg)
    {
        builder.or_not_reg(31, 31, reg);
    }
    // Sign- or zero-extend the register depending on the exact type
    void extend(std::uint8_t reg, types::type_ptr const & type)
    {
        reg_extend_visitor{{}, builder, reg}.apply(*type);
    }
    // Release everything the innermost scope still holds, before it is popped
    void scope_cleanup()
    {
        if (scopes.back().stack_offset > 0)
            release(scopes.back().stack_offset);
    }
    // Returns offset from function entry stack frame
    std::size_t lvalue_offset(ast::expression_ptr const & node)
    {
        if (auto identifier = std::get_if<ast::identifier>(node.get()))
        {
            // Resolved exactly like variable reads in apply(ast::identifier)
            if (auto variable = find_variable(identifier->name))
                return variable->frame_offset;
            throw std::runtime_error("Non-lvalue identifier: \"" + identifier->name + "\"");
        }
        else if (auto field_access = std::get_if<ast::field_access>(node.get()))
        {
            auto base_offset = lvalue_offset(field_access->object);
            auto type = ast::get_type(*field_access->object);
            if (auto struct_type = std::get_if<types::struct_type>(type.get()))
            {
                auto & struct_node = *struct_type->node;
                auto field_id = find_field(struct_node, field_access->field_name);
                if (!field_id)
                    throw std::runtime_error("Invalid field \"" + field_access->field_name + "\"");
                // The field lies field.offset bytes above the struct's address
                return base_offset - struct_node.fields[*field_id].layout.offset;
            }
            else
                throw std::runtime_error("Invalid field access node");
        }
        else if (std::get_if<ast::array_access>(node.get()))
        {
            throw std::runtime_error("Not implemented");
        }
        throw std::runtime_error("Unknown lvalue node");
    }
};
// Main compilation visitor
//
// Walks the statement tree only to find function definitions, at any depth
// (inside other functions, if chains and while loops), and compiles each of
// them with a fresh compile_function_visitor. Every statement list entered
// gets its own lcontext symbol scope, mirroring the lexical structure.
struct compile_visitor
    : ast::const_statement_visitor<compile_visitor>
{
    program_context & pcontext;
    local_context & lcontext;
    instruction_builder & builder;
    using const_statement_visitor::apply;

    // Statements that cannot contain function definitions
    void apply(ast::expression_ptr const &) {}
    void apply(ast::assignment const &) {}
    void apply(ast::variable_declaration const &) {}
    void apply(ast::if_block const &) {}
    void apply(ast::else_block const &) {}
    void apply(ast::else_if_block const &) {}
    void apply(ast::foreign_function_declaration const &) {}
    void apply(ast::return_statement const &) {}
    void apply(ast::field_definition const &) {}
    void apply(ast::struct_definition const &) {}

    // Statements with nested bodies
    void apply(ast::if_chain const & node)
    {
        for (auto const & branch : node.blocks)
            apply(*branch.statements);
    }
    void apply(ast::while_block const & node)
    {
        apply(*node.statements);
    }
    void apply(ast::function_definition const & node)
    {
        // The function itself first, then any functions nested in its body
        compile_function_visitor{{}, {}, pcontext, lcontext}.do_apply(node);
        apply(*node.statements);
    }
    void apply(ast::statement_list const & node)
    {
        lcontext.scopes.emplace_back();
        populate_symbols_visitor{{}, lcontext}.apply(node);
        for (auto const & statement : node.statements)
            apply(*statement);
        // The outermost (entry point) scope is kept on purpose: compile()
        // exports the top-level functions from it afterwards
        if (lcontext.scopes.size() == 1)
            return;
        lcontext.scopes.pop_back();
    }
};
}
// Compile a whole program into @pcontext
//
// The top-level statements are wrapped into a synthetic unit-returning
// function that becomes the entry point. Constant pool entries and foreign
// function pointer slots are emitted first, then all functions. Afterwards
// the 'adr' function-address fixups are patched, and the foreign slots, the
// top-level function symbols and the entry point offset are reported through
// @pcontext.
void compile(program_context & pcontext, ast::statement_list_ptr const & statements)
{
    // Add a fake AST node for the entry point
    auto root = std::make_shared<ast::statement>(ast::function_definition{
        {
            {},
            {},
            std::make_shared<ast::type>(types::unit_type{}),
            {},
        },
        statements,
        {},
    });
    local_context lcontext;
    instruction_builder builder{pcontext.code};
    populate_constants_visitor{{}, {}, pcontext, lcontext}.apply(*statements);
    compile_visitor{{}, pcontext, lcontext, builder}.apply(*root);
    // All function offsets are known now: patch every recorded 'adr'
    for (auto const & [target, instruction_offset] : lcontext.resolve)
    {
        auto const distance = lcontext.functions.at(target) - instruction_offset;
        builder.adr_inject(pcontext.code.data() + instruction_offset, distance);
    }
    for (auto const & [name, slot_offset] : lcontext.foreign_address)
        pcontext.foreign_resolve.push_back({name, slot_offset});
    auto const & top_level = lcontext.scopes.front();
    for (auto const & [name, definition] : top_level.functions)
        pcontext.symbols[name] = lcontext.functions.at(definition);
    auto const * entry = std::get_if<ast::function_definition>(root.get());
    pcontext.entry_point = lcontext.functions.at(entry);
}
}