#include "pslang/ast/statement_fwd.hpp"

// NOTE(review): the targets of the remaining #include directives were lost in
// this copy of the file. The standard headers below are the ones the code
// visibly needs; the project headers declaring instruction_builder, the
// ast/types visitors, compiled_module, allocate and abi must be restored.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <variant>
#include <vector>

namespace pslang::jit::aarch64
{

namespace
{

// Register number 31 is used both as the stack pointer (loads, stores and
// add/sub with an immediate) and as the zero register (register-register
// arithmetic and logic), so it gets two names.
constexpr std::uint8_t reg_sp = 31;
constexpr std::uint8_t reg_zr = 31;

// Condition codes passed to instruction_builder::csetm
constexpr std::uint8_t cond_eq = 0b0000; // equal
constexpr std::uint8_t cond_ne = 0b0001; // not equal
constexpr std::uint8_t cond_hi = 0b1000; // unsigned higher
constexpr std::uint8_t cond_ls = 0b1001; // unsigned lower or same
constexpr std::uint8_t cond_gt = 0b1100; // signed greater than
constexpr std::uint8_t cond_le = 0b1101; // signed less than or equal

// Number of arguments that do_apply can take from registers
constexpr std::size_t max_register_arguments = 8;

// State shared by every function compiled into one module
struct context
{
    // Machine code emitted so far
    // NOTE(review): the element type was lost in this copy of the file;
    // a byte buffer is assumed -- confirm against instruction_builder
    std::vector<std::uint8_t> code;

    // Function name -> value of code.size() when its compilation started
    // NOTE(review): key/value types were lost; reconstructed from usage
    std::unordered_map<std::string, std::size_t> code_symbol_table;
};

// Emits the instruction that sign- or zero-extends register @reg according
// to the visited type. Does nothing for bool, floating-point and 8-byte types.
// NOTE(review): the base class template argument was lost; CRTP is assumed
struct reg_extend_visitor : types::const_visitor<reg_extend_visitor>
{
    using const_visitor::apply;

    instruction_builder & builder;
    std::uint8_t reg;

    void apply(types::bool_type const &) {}
    void apply(types::f32_type const &) {}
    void apply(types::f64_type const &) {}

    template <typename T>
    void apply(types::primitive_type_base<T> const &)
    {
        // 8-byte values already occupy the whole register
        if constexpr (sizeof(T) != 8)
        {
            if constexpr (std::is_signed_v<T>)
            {
                builder.sbfm(reg, reg, sizeof(T) * 8);
            }
            if constexpr (std::is_unsigned_v<T>)
            {
                builder.ubfm(reg, reg, sizeof(T) * 8);
            }
        }
    }

    template <typename T>
    void apply(T const &)
    {
        throw std::runtime_error("Not implemented");
    }
};

// Generates code for a single function. Every expression leaves its value in
// register 0; temporaries and variables live in 16-byte stack slots.
// NOTE(review): the base class template arguments were lost; CRTP is assumed
struct compile_function_visitor
    : ast::const_statement_visitor<compile_function_visitor>
    , ast::const_expression_visitor<compile_function_visitor>
{
    using const_statement_visitor::apply;
    using const_expression_visitor::apply;

    context & context;
    instruction_builder builder{context.code};

    // Difference between initial stack pointer at function enter
    // and current virtual stack pointer value. The actual stack pointer
    // value is rounded down to a multiple of 16
    std::uint32_t stack_offset = 0;

    struct variable_data
    {
        // Difference between initial stack pointer at scope enter
        // and the variable address
        // Must be a multiple of 8
        std::uint32_t frame_offset;
    };

    struct scope
    {
        std::unordered_map<std::string, variable_data> variables = {};

        // Difference between initial virtual stack pointer at scope enter
        // and current virtual stack pointer value
        std::uint32_t stack_offset = 0;
    };

    std::vector<scope> scopes;

    // Fallback for every node kind without a dedicated overload
    template <typename Node>
    void apply(Node const &)
    {
        throw std::runtime_error("Not implemented");
    }

    // true is represented as all bits set, false as zero
    void apply(ast::bool_literal const & node)
    {
        if (node.value)
            set_m1(0);
        else
            builder.movz(0, 0);
    }

    // Loads an integer literal into register 0, 16 bits at a time
    template <typename T>
        requires(std::is_integral_v<T> && !std::is_same_v<T, bool>)
    void apply(ast::primitive_literal_base<T> const & node)
    {
        auto const bits = std::uint64_t(node.value);

        builder.movz(0, bits);

        // @byte is a byte offset into the value, hence the shift by byte * 8
        // bits, while movk takes the index of the 16-bit halfword (byte / 2)
        for (std::size_t byte = 2; byte < sizeof(T); byte += 2)
        {
            auto const halfword = std::uint16_t(bits >> (byte * 8));
            if (halfword != 0)
                builder.movk(0, halfword, byte / 2);
        }

        // Only the low sizeof(T) bytes were written above, so a negative
        // value narrower than the register still has to be sign-extended
        if constexpr (sizeof(T) < 8 && std::is_signed_v<T>)
        {
            if (node.value < 0)
                builder.sbfm(0, 0, sizeof(T) * 8);
        }
    }

    // Loads the variable's stack slot into register 0.
    // Throws std::runtime_error if no enclosing scope declares the name.
    void apply(ast::identifier const & node)
    {
        builder.ldr(0, reg_sp, slot_index(find_variable(node.name)));
    }

    void apply(ast::unary_operation const & node)
    {
        // TODO: floating-point
        switch (node.type)
        {
        case ast::unary_operation_type::negation:
            apply(*node.arg1);
            builder.sub_reg(reg_zr, 0, 0);
            extend(0, node.inferred_type);
            break;
        case ast::unary_operation_type::logical_not:
            apply(*node.arg1);
            builder.or_not_reg(reg_zr, 0, 0);
            break;
        }
    }

    void apply(ast::binary_operation const & node)
    {
        // TODO: floating-point
        switch (node.type)
        {
        case ast::binary_operation_type::addition:
            evaluate_operands(node);
            builder.add_reg(1, 0, 0);
            extend(0, node.inferred_type);
            break;
        case ast::binary_operation_type::subtraction:
            evaluate_operands(node);
            builder.sub_reg(1, 0, 0);
            extend(0, node.inferred_type);
            break;
        case ast::binary_operation_type::multiplication:
            evaluate_operands(node);
            builder.mul_reg(1, 0, 0);
            extend(0, node.inferred_type);
            break;
        case ast::binary_operation_type::division:
            evaluate_operands(node);
            if (types::is_signed_integer_type(*node.inferred_type))
                builder.sdiv_reg(1, 0, 0);
            else
                builder.udiv_reg(1, 0, 0);
            extend(0, node.inferred_type);
            break;
        case ast::binary_operation_type::remainder:
            // TODO: implement via div & mul & sub
            throw std::runtime_error("Not implemented");
        case ast::binary_operation_type::logical_and:
            evaluate_operands(node);
            builder.and_reg(1, 0, 0);
            break;
        case ast::binary_operation_type::logical_or:
            evaluate_operands(node);
            builder.or_reg(1, 0, 0);
            break;
        case ast::binary_operation_type::logical_xor:
            evaluate_operands(node);
            builder.xor_reg(1, 0, 0);
            break;
        case ast::binary_operation_type::equals:
            evaluate_operands(node);
            builder.cmp_reg(1, 0);
            builder.csetm(0, cond_eq);
            break;
        case ast::binary_operation_type::not_equals:
            evaluate_operands(node);
            builder.cmp_reg(1, 0);
            builder.csetm(0, cond_ne);
            break;
        // The ordered comparisons come in mirrored pairs: each pair shares
        // its condition codes and differs only in the cmp operand order
        case ast::binary_operation_type::less:
            evaluate_operands(node);
            builder.cmp_reg(0, 1);
            builder.csetm(0, compares_unsigned(node) ? cond_hi : cond_gt);
            break;
        case ast::binary_operation_type::greater:
            evaluate_operands(node);
            // This compare was missing, so csetm tested whatever flags the
            // operand code happened to leave behind
            builder.cmp_reg(1, 0);
            builder.csetm(0, compares_unsigned(node) ? cond_hi : cond_gt);
            break;
        case ast::binary_operation_type::less_equals:
            evaluate_operands(node);
            builder.cmp_reg(1, 0);
            builder.csetm(0, compares_unsigned(node) ? cond_ls : cond_le);
            break;
        case ast::binary_operation_type::greater_equals:
            evaluate_operands(node);
            builder.cmp_reg(0, 1);
            builder.csetm(0, compares_unsigned(node) ? cond_ls : cond_le);
            break;
        default:
            throw std::runtime_error("Not implemented");
        }
    }

    // Evaluates the right-hand side and stores it into the variable's slot.
    // Only plain identifiers are supported on the left-hand side.
    // Throws std::runtime_error if no enclosing scope declares the name.
    void apply(ast::assignment const & node)
    {
        auto identifier = std::get_if<ast::identifier>(node.lhs.get());
        if (!identifier)
            throw std::runtime_error("Not implemented");

        apply(*node.rhs);
        builder.str(0, reg_sp, slot_index(find_variable(identifier->name)));
    }

    // The pushed initializer value becomes the variable's storage
    void apply(ast::variable_declaration const & node)
    {
        apply(*node.initializer);
        push(0);
        scopes.back().variables[node.name] = {.frame_offset = stack_offset};
    }

    // Leaves the value in register 0 and releases everything this function
    // has pushed so far before returning
    void apply(ast::return_statement const & node)
    {
        apply(*node.value);
        if (stack_offset > 0)
            builder.add_imm(reg_sp, reg_sp, stack_offset);
        builder.ret();
    }

    void apply(ast::function_definition const &)
    {
        // Don't handle internal functions
    }

    // Compiles the statements in order, then emits the instruction releasing
    // the stack space recorded for the innermost scope. The tracked offsets
    // themselves are left unchanged.
    void apply(ast::statement_list const & node)
    {
        for (auto const & statement : node.statements)
            apply(*statement);
        builder.add_imm(reg_sp, reg_sp, scopes.back().stack_offset);
    }

    // Entry point: compiles the body of @node. Argument i is taken from
    // register i, extended to its declared type and spilled to the stack so
    // that it can be addressed like a local variable.
    // Throws std::runtime_error for more than max_register_arguments
    // arguments, which would otherwise be read from non-argument registers.
    void do_apply(ast::function_definition const & node)
    {
        // TODO: floating-point / struct arguments
        // TODO: arguments passed on the stack
        if (node.arguments.size() > max_register_arguments)
            throw std::runtime_error("Not implemented");

        scopes.emplace_back();
        for (std::size_t i = 0; i < node.arguments.size(); ++i)
        {
            auto const reg = static_cast<std::uint8_t>(i);
            extend(reg, ast::get_type(*node.arguments[i].type));
            push(reg);
            scopes.back().variables[node.arguments[i].name] = {.frame_offset = stack_offset};
        }
        apply(*node.statements);
        scopes.pop_back();
    }

private:
    // Reserves a 16-byte stack slot and stores @reg at its lowest address
    void push(std::uint8_t reg)
    {
        builder.sub_imm(reg_sp, reg_sp, 16);
        builder.str(reg, reg_sp, 0);
        stack_offset += 16;
        scopes.back().stack_offset += 16;
    }

    // Loads the topmost stack slot into @reg and releases it
    void pop(std::uint8_t reg)
    {
        builder.ldr(reg, reg_sp, 0);
        builder.add_imm(reg_sp, reg_sp, 16);
        stack_offset -= 16;
        scopes.back().stack_offset -= 16;
    }

    // Set register @reg to -1 (all bits = 1)
    void set_m1(std::uint8_t reg)
    {
        builder.or_not_reg(reg_zr, reg_zr, reg);
    }

    // Sign- or zero-extend the register depending on the exact type
    void extend(std::uint8_t reg, types::type_ptr const & type)
    {
        reg_extend_visitor{{}, builder, reg}.apply(*type);
    }

    // Evaluates both operands of @node: the left one ends up in register 1,
    // the right one in register 0
    void evaluate_operands(ast::binary_operation const & node)
    {
        apply(*node.arg1);
        push(0);
        apply(*node.arg2);
        pop(1);
    }

    // Booleans and unsigned integers are ordered with the unsigned condition
    // codes, everything else with the signed ones
    bool compares_unsigned(ast::binary_operation const & node) const
    {
        auto const type = ast::get_type(*node.arg1);
        return types::is_bool_type(*type) || types::is_unsigned_integer_type(*type);
    }

    // Looks @name up from the innermost scope outwards.
    // Throws std::runtime_error if no scope declares it.
    template <typename Name>
    variable_data const & find_variable(Name const & name) const
    {
        for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
        {
            if (auto jt = it->variables.find(name); jt != it->variables.end())
                return jt->second;
        }
        throw std::runtime_error("Unknown variable");
    }

    // Offset of the variable's slot from the current stack pointer,
    // in 8-byte units
    std::uint32_t slot_index(variable_data const & variable) const
    {
        return (stack_offset - variable.frame_offset) / 8;
    }
};

// Compiles top-level statements; only function definitions are supported
// NOTE(review): the base class template argument was lost; CRTP is assumed
struct compile_visitor : ast::const_statement_visitor<compile_visitor>
{
    using const_statement_visitor::apply;

    context & context;
    instruction_builder builder{context.code};

    template <typename Statement>
    void apply(Statement const &)
    {
        throw std::runtime_error("Not implemented");
    }

    // Records where the function's code starts, then compiles it with a
    // fresh per-function visitor
    void apply(ast::function_definition const & node)
    {
        context.code_symbol_table[node.name] = context.code.size();
        compile_function_visitor visitor{{}, {}, context};
        visitor.do_apply(node);
    }
};

}

// Compiles @statements into a module: the generated code is copied into
// memory obtained from allocate() and returned together with the symbol
// table. Throws std::runtime_error for unsupported constructs.
compiled_module compile(ast::statement_list_ptr const & statements)
{
    context context;
    compile_visitor visitor{{}, context};
    visitor.apply(*statements);

    auto code = allocate(context.code.size());
    std::copy(context.code.begin(), context.code.end(), code.data.get());

    return compiled_module
    {
        .data = {},
        .code =
        {
            .memory = std::move(code),
            .symbol_table = std::move(context.code_symbol_table),
        },
        .entry_point = 0,
        .abi = abi::armv8,
    };
}

}