#include #include #include #include #include #include #include #include #include #include namespace pslang::jit::aarch64 { namespace { struct local_context { std::unordered_map f16_constants; std::unordered_map f32_constants; std::unordered_map f64_constants; }; std::uint8_t fp_mode_for(types::type const & type) { if (types::equal(type, types::primitive_type(types::f16_type{}))) return 1; if (types::equal(type, types::primitive_type(types::f32_type{}))) return 2; return 3; } struct populate_constants_visitor : ast::const_expression_visitor , ast::const_statement_visitor { using const_expression_visitor::apply; using const_statement_visitor::apply; program_context & pcontext; local_context & lcontext; template requires(!std::is_floating_point_v) void apply(ast::primitive_literal_base const &) {} void apply(ast::f16_literal const & node) { if (!lcontext.f16_constants.contains(node.value.repr)) { lcontext.f16_constants[node.value.repr] = pcontext.code.size(); push_bytes(node.value.repr); } } void apply(ast::f32_literal const & node) { if (!lcontext.f32_constants.contains(node.value)) { lcontext.f32_constants[node.value] = pcontext.code.size(); push_bytes(node.value); } } void apply(ast::f64_literal const & node) { if (!lcontext.f64_constants.contains(node.value)) { lcontext.f64_constants[node.value] = pcontext.code.size(); push_bytes(node.value); } } void apply(ast::identifier const &) {} void apply(ast::unary_operation const & node) { apply(*node.arg1); } void apply(ast::binary_operation const & node) { apply(*node.arg1); apply(*node.arg2); } void apply(ast::cast_operation const & node) { apply(*node.expression); } void apply(ast::function_call const & node) { apply(*node.function); for (auto const & argument : node.arguments) apply(*argument); } void apply(ast::array const & node) { for (auto const & element : node.elements) apply(*element); } void apply(ast::array_access const & node) { apply(*node.array); apply(*node.index); } void apply(ast::field_access const & node) { apply(*node.object); } void apply(ast::expression_ptr const & node) { apply(*node); } void apply(ast::assignment const & node) { apply(*node.lhs); apply(*node.rhs); } void apply(ast::variable_declaration const & node) { apply(*node.initializer); } void apply(ast::if_block const &) {} void apply(ast::else_block const &) {} void apply(ast::else_if_block const &) {} void apply(ast::if_chain const & node) { for (auto const & block : node.blocks) { apply(*block.condition); apply(*block.statements); } } void apply(ast::while_block const & node) { apply(*node.condition); apply(*node.statements); } void apply(ast::function_definition const & node) { apply(*node.statements); } void apply(ast::return_statement const & node) { apply(*node.value); } void apply(ast::field_definition const &) {} void apply(ast::struct_definition const &) {} private: template void push_bytes(T const & value) { auto begin = (std::uint8_t const *)(&value); auto end = begin + sizeof(value); pcontext.code.insert(pcontext.code.end(), begin, end); } }; struct reg_extend_visitor : types::const_visitor { using const_visitor::apply; instruction_builder & builder; std::uint8_t reg; void apply(types::bool_type const &) {} void apply(types::f32_type const &) {} void apply(types::f64_type const &) {} template void apply(types::primitive_type_base const &) { if constexpr (sizeof(T) == 8) { return; } if constexpr (std::is_signed_v) { builder.sbfm(reg, reg, sizeof(T) * 8); } if constexpr (std::is_unsigned_v) { builder.ubfm(reg, reg, sizeof(T) * 8); } } template void apply(T const &) { throw std::runtime_error("Not implemented"); } }; struct compile_function_visitor : ast::const_statement_visitor , ast::const_expression_visitor { using const_statement_visitor::apply; using const_expression_visitor::apply; program_context & pcontext; local_context & lcontext; instruction_builder builder{pcontext.code}; // Difference between initial stack pointer at function enter // and current virtual stack pointer value. The actual stack pointer // value is rounded down to a multiple of 16 std::uint32_t stack_offset = 0; struct variable_data { // Difference between initial stack pointer at scope enter // and the variable address // Must be a multiple of 8 std::uint32_t frame_offset; }; struct scope { std::unordered_map variables = {}; // Difference between initial virtual stack pointer at scope enter // and current virtual stack pointer value std::uint32_t stack_offset = 0; }; std::vector scopes; template void apply(Node const &) { throw std::runtime_error("Not implemented"); } void apply(ast::bool_literal const & node) { if (node.value) set_m1(0); else builder.movz(0, 0); } template requires(std::is_integral_v && !std::is_same_v) void apply(ast::primitive_literal_base const & node) { for (std::size_t i = 0; i < sizeof(T); i += 2) { if (i == 0) { builder.movz(0, std::uint64_t(node.value)); } else { auto val = std::uint16_t(std::uint64_t(node.value) >> (i * 8)); if (val != 0) builder.movk(0, val, i / 2); } } if (sizeof(T) < 8) { if constexpr (std::is_signed_v) { if (node.value < 0) builder.sbfm(0, 0, sizeof(T) * 8); } } } void apply(ast::f16_literal const & node) { auto offset = lcontext.f16_constants.at(node.value.repr); std::int32_t current = pcontext.code.size(); builder.ldr_fp_pc(0, 0, (offset - current) / 4); builder.fcvt(0, 0b10, 0, 0b01); } void apply(ast::f32_literal const & node) { auto offset = lcontext.f32_constants.at(node.value); std::int32_t current = pcontext.code.size(); builder.ldr_fp_pc(0, 0, (offset - current) / 4); } void apply(ast::f64_literal const & node) { auto offset = lcontext.f64_constants.at(node.value); std::int32_t current = pcontext.code.size(); builder.ldr_fp_pc(0, 1, (offset - current) / 4); } void apply(ast::identifier const & node) { for (auto it = scopes.rbegin(); it != scopes.rend(); ++it) { if (auto jt = it->variables.find(node.name); jt != it->variables.end()) { if (types::is_floating_point_type(*node.inferred_type)) builder.ldr_fp(0, fp_mode_for(*node.inferred_type), 31, (stack_offset - jt->second.frame_offset) / builtin_type_size(*node.inferred_type)); else builder.ldr(0, 31, (stack_offset - jt->second.frame_offset) / 8); break; } } } void apply(ast::unary_operation const & node) { switch (node.type) { case ast::unary_operation_type::negation: apply(*node.arg1); if (types::is_integer_type(*node.inferred_type)) { builder.sub_reg(31, 0, 0); extend(0, node.inferred_type); } else if (types::is_floating_point_type(*node.inferred_type)) { builder.fneg(0, fp_mode_for(*node.inferred_type), 0); } break; case ast::unary_operation_type::logical_not: apply(*node.arg1); builder.or_not_reg(31, 0, 0); break; } } void apply(ast::binary_operation const & node) { auto arg1_type = ast::get_type(*node.arg1); bool const is_fp = types::is_floating_point_type(*arg1_type); std::uint8_t const fp_mode = fp_mode_for(*arg1_type); if (is_fp) { apply(*node.arg1); push_fp(0, fp_mode); apply(*node.arg2); pop_fp(1, fp_mode); } else { apply(*node.arg1); push(0); apply(*node.arg2); pop(1); } switch (node.type) { case ast::binary_operation_type::addition: if (is_fp) builder.fadd(1, 0, fp_mode, 0); else { builder.add_reg(1, 0, 0); extend(0, node.inferred_type); } break; case ast::binary_operation_type::subtraction: if (is_fp) builder.fsub(1, 0, fp_mode, 0); else { builder.sub_reg(1, 0, 0); extend(0, node.inferred_type); } break; case ast::binary_operation_type::multiplication: if (is_fp) builder.fmul(1, 0, fp_mode, 0); else { builder.mul_reg(1, 0, 0); extend(0, node.inferred_type); } break; case ast::binary_operation_type::division: if (is_fp) builder.fdiv(1, 0, fp_mode, 0); else { if (types::is_signed_integer_type(*node.inferred_type)) builder.sdiv_reg(1, 0, 0); else builder.udiv_reg(1, 0, 0); extend(0, node.inferred_type); } break; case ast::binary_operation_type::remainder: // TODO: implement via div & mul & sub throw std::runtime_error("Not implemented"); case ast::binary_operation_type::logical_and: builder.and_reg(1, 0, 0); break; case ast::binary_operation_type::logical_or: builder.or_reg(1, 0, 0); break; case ast::binary_operation_type::logical_xor: builder.xor_reg(1, 0, 0); break; case ast::binary_operation_type::equals: if (is_fp) { builder.fcmp(1, 0, fp_mode); builder.csetm(0, 0b0000); } else { builder.cmp_reg(1, 0); builder.csetm(0, 0b0000); } break; case ast::binary_operation_type::not_equals: if (is_fp) { builder.fcmp(1, 0, fp_mode); builder.csetm(0, 0b0001); } else { builder.cmp_reg(1, 0); builder.csetm(0, 0b0001); } break; case ast::binary_operation_type::less: if (is_fp) { builder.fcmp(1, 0, fp_mode); builder.csetm(0, 0b0100); } else { builder.cmp_reg(0, 1); if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1))) builder.csetm(0, 0b1000); else builder.csetm(0, 0b1100); } break; case ast::binary_operation_type::greater: if (is_fp) { builder.fcmp(0, 1, fp_mode); builder.csetm(0, 0b0100); } else { builder.cmp_reg(1, 0); if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1))) builder.csetm(0, 0b1000); else builder.csetm(0, 0b1100); } break; case ast::binary_operation_type::less_equals: if (is_fp) { builder.fcmp(1, 0, fp_mode); builder.csetm(0, 0b1001); } else { builder.cmp_reg(1, 0); if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1))) builder.csetm(0, 0b1001); else builder.csetm(0, 0b1101); } break; case ast::binary_operation_type::greater_equals: if (is_fp) { builder.fcmp(0, 1, fp_mode); builder.csetm(0, 0b1001); } else { builder.cmp_reg(0, 1); if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1))) builder.csetm(0, 0b1001); else builder.csetm(0, 0b1101); } break; default: throw std::runtime_error("Not implemented"); } } void apply(ast::cast_operation const & node) { auto src_type = ast::get_type(*node.expression); auto dst_type = node.inferred_type; apply(*node.expression); if (types::equal(*src_type, *dst_type)) return; if (types::is_integer_type(*src_type)) { if (types::is_integer_type(*dst_type)) { extend(0, dst_type); } else if (types::is_floating_point_type(*dst_type)) { auto dst_mode = fp_mode_for(*dst_type); if (types::is_signed_integer_type(*src_type)) { builder.fmov(0, 0, 3, 1); builder.scvtf(0, 0, 3); if (dst_mode != 3) builder.fcvt(0, 3, 0, dst_mode); } else if (types::is_unsigned_integer_type(*src_type)) { builder.fmov(0, 0, 3, 1); builder.ucvtf(0, 0, 3); if (dst_mode != 3) builder.fcvt(0, 3, 0, dst_mode); } } } else if (types::is_floating_point_type(*src_type)) { auto src_mode = fp_mode_for(*src_type); if (types::is_integer_type(*dst_type)) { if (types::is_signed_integer_type(*dst_type)) { builder.fcvtns(0, 0, src_mode); extend(0, dst_type); } else if (types::is_unsigned_integer_type(*dst_type)) { builder.fcvtnu(0, 0, src_mode); extend(0, dst_type); } } else if (types::is_floating_point_type(*dst_type)) { auto dst_mode = fp_mode_for(*dst_type); builder.fcvt(0, src_mode, 0, dst_mode); } } } void apply(ast::assignment const & node) { auto identifier = std::get_if(node.lhs.get()); if (!identifier) throw std::runtime_error("Not implemented"); apply(*node.rhs); for (auto it = scopes.rbegin(); it != scopes.rend(); ++it) { if (auto jt = it->variables.find(identifier->name); jt != it->variables.end()) { auto type = ast::get_type(*node.rhs); if (types::is_floating_point_type(*type)) builder.str_fp(0, fp_mode_for(*type), 31, (stack_offset - jt->second.frame_offset) / builtin_type_size(*type)); else builder.str(0, 31, (stack_offset - jt->second.frame_offset) / 8); break; } } } void apply(ast::variable_declaration const & node) { apply(*node.initializer); auto type = ast::get_type(*node.initializer); if (types::is_floating_point_type(*type)) push_fp(0, fp_mode_for(*type)); else push(0); scopes.back().variables[node.name] = {.frame_offset = stack_offset}; } void apply(ast::if_chain const & node) { std::vector branch_to_end; for (std::size_t i = 0; i < node.blocks.size(); ++i) { auto const & block = node.blocks[i]; std::optional branch_skip; if (block.condition) { apply(*block.condition); branch_skip = pcontext.code.size(); builder.cbz(0, 0); } scopes.emplace_back(); apply(*block.statements); scope_cleanup(); scopes.pop_back(); if (i + 1 < node.blocks.size()) { branch_to_end.push_back(pcontext.code.size()); builder.b(0); } if (branch_skip) { auto branch_offset = pcontext.code.size() - *branch_skip; builder.cb_inject(pcontext.code.data() + *branch_skip, branch_offset / 4); } } auto end = pcontext.code.size(); for (auto instruction : branch_to_end) { auto delta = end - instruction; builder.b_inject(pcontext.code.data() + instruction, delta / 4); } } void apply(ast::while_block const & node) { std::int32_t start = pcontext.code.size(); apply(*node.condition); std::int32_t skip = pcontext.code.size(); builder.cbz(0, 0); scopes.emplace_back(); apply(*node.statements); scope_cleanup(); scopes.pop_back(); std::int32_t loop = pcontext.code.size(); builder.b(0); std::int32_t end = pcontext.code.size(); builder.cb_inject(pcontext.code.data() + skip, (end - skip) / 4); builder.b_inject(pcontext.code.data() + loop, (start - loop) / 4); } void apply(ast::return_statement const & node) { apply(*node.value); if (stack_offset > 0) builder.add_imm(31, 31, stack_offset); builder.ret(); } void apply(ast::function_definition const & node) { // Don't handle internal functions } void apply(ast::statement_list const & node) { for (auto const & statement : node.statements) apply(*statement); } void do_apply(ast::function_definition const & node) { // TODO: floating-point / struct arguments scopes.emplace_back(); std::uint8_t reg = 0; std::uint8_t fp_reg = 0; for (auto const & argument : node.arguments) { auto type = ast::get_type(*argument.type); if (types::is_bool_type(*type)) { builder.tst(reg, reg); builder.csetm(reg, 0b0001); push(reg); ++reg; } else if (types::is_integer_type(*type)) { extend(reg, type); push(reg); ++reg; } else if (types::is_floating_point_type(*type)) { auto mode = fp_mode_for(*type); push_fp(fp_reg, mode); ++fp_reg; } scopes.back().variables[argument.name] = {.frame_offset = stack_offset}; } apply(*node.statements); scopes.pop_back(); } private: void push(std::uint8_t reg) { builder.sub_imm(31, 31, 16); builder.str(reg, 31, 0); stack_offset += 16; scopes.back().stack_offset += 16; } void pop(std::uint8_t reg) { builder.ldr(reg, 31, 0); builder.add_imm(31, 31, 16); stack_offset -= 16; scopes.back().stack_offset -= 16; } void push_fp(std::uint8_t reg, std::uint8_t mode) { builder.sub_imm(31, 31, 16); builder.str_fp(0, mode, 31, 0); stack_offset += 16; scopes.back().stack_offset += 16; } void pop_fp(std::uint8_t reg, std::uint8_t mode) { builder.ldr_fp(reg, mode, 31, 0); builder.add_imm(31, 31, 16); stack_offset -= 16; scopes.back().stack_offset -= 16; } // Set register @reg to -1 (all bits = 1) void set_m1(std::uint8_t reg) { builder.or_not_reg(31, 31, reg); } // Sign- or zero-extend the register depending on the exact type void extend(std::uint8_t reg, types::type_ptr const & type) { reg_extend_visitor{{}, builder, reg}.apply(*type); } void scope_cleanup() { if (scopes.back().stack_offset > 0) builder.add_imm(31, 31, scopes.back().stack_offset); } }; struct compile_visitor : ast::const_statement_visitor { using const_statement_visitor::apply; program_context & pcontext; local_context & lcontext; instruction_builder builder{pcontext.code}; template void apply(Statement const &) { throw std::runtime_error("Not implemented"); } void apply(ast::function_definition const & node) { pcontext.symbols[node.name] = pcontext.code.size(); compile_function_visitor visitor{{}, {}, pcontext, lcontext}; visitor.do_apply(node); } }; } void compile(program_context & pcontext, ast::statement_list_ptr const & statements) { local_context lcontext; populate_constants_visitor{{}, {}, pcontext, lcontext}.apply(*statements); compile_visitor{{}, pcontext, lcontext}.apply(*statements); } }