#include #include #include #include #include #include #include #include #include namespace pslang::jit::aarch64 { namespace { struct local_context { bool use_frame_pointer = true; std::unordered_map extern_symbols; std::unordered_map nodes; std::unordered_map f16_constants; std::unordered_map f32_constants; std::unordered_map f64_constants; struct resolve_data { std::int32_t offset; ir::node_ref target; }; std::vector branch_resolve; std::vector cbranch_resolve; std::vector adr_resolve; }; std::uint8_t fp_mode_for(types::type const & type) { if (types::equal(type, types::primitive_type(types::f16_type{}))) return 1; if (types::equal(type, types::primitive_type(types::f32_type{}))) return 2; return 3; } std::int32_t fp_size(std::uint8_t mode) { return 1 << mode; } struct populate_const_data_visitor { program_context & pcontext; local_context & lcontext; template void apply(Node const & node, types::type_ptr const &) {} void apply(ir::literal const & node, types::type_ptr const &) { if (auto f16_literal = std::get_if(&node.value)) { lcontext.f16_constants[f16_literal->value.repr] = pcontext.code.size(); push_bytes(f16_literal->value.repr); } else if (auto f32_literal = std::get_if(&node.value)) { lcontext.f32_constants[f32_literal->value] = pcontext.code.size(); push_bytes(f32_literal->value); } else if (auto f64_literal = std::get_if(&node.value)) { lcontext.f32_constants[f64_literal->value] = pcontext.code.size(); push_bytes(f64_literal->value); } } void apply(ir::extern_symbol const & node, types::type_ptr const &) { std::int32_t offset = pcontext.code.size(); lcontext.extern_symbols[node.name] = offset; pcontext.foreign_resolve.push_back({node.name, offset}); push_bytes(nullptr); } private: template void push_bytes(T const & value) { auto begin = (std::uint8_t const *)(&value); auto end = begin + sizeof(value); pcontext.code.insert(pcontext.code.end(), begin, end); } }; // Set register @reg to -1 (all bits = 1) void set_m1(instruction_builder & builder, std::uint8_t reg) { builder.or_not_reg(31, 31, reg); } struct literal_visitor { program_context & pcontext; local_context & lcontext; instruction_builder & builder; void operator()(ast::bool_literal const & node) { if (node.value) set_m1(builder, 0); else builder.movz(0, 0); } template requires(std::is_integral_v && !std::is_same_v) void operator()(ast::primitive_literal_base const & node) { for (std::size_t i = 0; i < sizeof(T); i += 2) { if (i == 0) { builder.movz(0, std::uint64_t(node.value)); } else { auto val = std::uint16_t(std::uint64_t(node.value) >> (i * 8)); if (val != 0) builder.movk(0, val, i / 2); } } if (sizeof(T) < 8) { if constexpr (std::is_signed_v) { if (node.value < 0) builder.sbfm(0, 0, sizeof(T) * 8); } } } void operator()(ast::f16_literal const & node) { auto offset = lcontext.f16_constants.at(node.value.repr); std::int32_t current = pcontext.code.size(); builder.ldr_fp_pc(0, 0, (offset - current) / 4); builder.fcvt(0, 0b10, 0, 0b01); } void operator()(ast::f32_literal const & node) { auto offset = lcontext.f32_constants.at(node.value); std::int32_t current = pcontext.code.size(); builder.ldr_fp_pc(0, 0, (offset - current) / 4); } void operator()(ast::f64_literal const & node) { auto offset = lcontext.f64_constants.at(node.value); std::int32_t current = pcontext.code.size(); builder.ldr_fp_pc(0, 1, (offset - current) / 4); } }; struct reg_extend_visitor : types::const_visitor { using const_visitor::apply; instruction_builder & builder; std::uint8_t reg; void apply(types::bool_type const &) { builder.ubfm(reg, reg, 8); } void apply(types::f16_type const &) {} void apply(types::f32_type const &) {} void apply(types::f64_type const &) {} template void apply(types::primitive_type_base const &) { if constexpr (sizeof(T) == 8) { return; } if constexpr (std::is_signed_v) { builder.sbfm(reg, reg, sizeof(T) * 8); } if constexpr (std::is_unsigned_v) { builder.ubfm(reg, reg, sizeof(T) * 8); } } template void apply(T const &) { throw std::runtime_error(std::string("reg_extend_visitor is not implemented for ") + typeid(T).name()); } }; struct compile_visitor { program_context & pcontext; ir::module_context const & mcontext; local_context & lcontext; instruction_builder & builder; std::vector argument_position; std::unordered_map stack_position; std::int32_t stack_size = 0; void apply(ir::node_ref, ir::label const &, types::type_ptr const &) {} void apply(ir::node_ref it, ir::literal const & node, types::type_ptr const & type) { std::visit(literal_visitor{pcontext, lcontext, builder}, node.value); if (types::is_integer_like_type(*type)) store(it, 0); else if (types::is_floating_point_type(*type)) store_fp(it, 0, fp_mode_for(*type)); } void apply(ir::node_ref it, ir::alloc const & node, types::type_ptr const &) { // Nothing to do: alloc just allocates a node of a struct type, // but we already allocated stack space for it } void apply(ir::node_ref it, ir::copy const & node, types::type_ptr const & type) { // TODO: array/array element copy? auto size = ast::type_size(*type); auto dst_offset = stack_size - stack_position.at(it); auto src_type = node.source->inferred_type; auto src_offset = stack_size - stack_position.at(node.source); for (auto field_id : node.path) { auto const & field = std::get(*src_type).node->fields[field_id]; src_type = field.inferred_type; src_offset += field.layout.offset; } copy_memory(31, src_offset, 31, dst_offset, size, 0); } void apply(ir::node_ref it, ir::load const & node, types::type_ptr const & type) { // TODO: array/array element load? load(node.ptr, 0); auto size = ast::type_size(*type); auto dst_offset = stack_size - stack_position.at(it); copy_memory(0, 0, 31, dst_offset, size, 1); } void apply(ir::node_ref, ir::store const & node, types::type_ptr const & type) { // TODO: array/array element store? load(node.ptr, 0); auto size = ast::type_size(*type); std::int32_t src_offset = stack_size - stack_position.at(node.value); copy_memory(31, src_offset, 0, 0, size, 1); } void apply(ir::node_ref it, ir::unary_operation const & node, types::type_ptr const & type) { switch (node.type) { case ast::unary_operation_type::negation: if (types::is_integer_type(*type)) { load(node.arg1, 0); builder.sub_reg(31, 0, 0); store(it, 0); } else if (types::is_floating_point_type(*type)) { auto mode = fp_mode_for(*type); load_fp(it, 0, mode); builder.fneg(0, mode, 0); store_fp(it, 0, mode); } break; case ast::unary_operation_type::logical_not: load(node.arg1, 0); builder.or_not_reg(31, 0, 0); store(it, 0); break; case ast::unary_operation_type::address_of: case ast::unary_operation_type::mutable_address_of: builder.add_imm(31, 0, stack_size - stack_position.at(node.arg1)); store(it, 0); break; case ast::unary_operation_type::dereference: throw std::runtime_error("Dereference operator mush not be present in compiled IR"); } } void apply(ir::node_ref it, ir::binary_operation const & node, types::type_ptr const & type) { auto arg1_type = node.arg1->inferred_type; bool const is_fp = types::is_floating_point_type(*arg1_type); std::uint8_t const fp_mode = fp_mode_for(*arg1_type); if (is_fp) { load_fp(node.arg1, 0, fp_mode); load_fp(node.arg2, 1, fp_mode); } else { load(node.arg1, 0); load(node.arg2, 1); } switch (node.type) { case ast::binary_operation_type::addition: if (is_fp) builder.fadd(0, 1, fp_mode, 0); else { builder.add_reg(0, 1, 0); } break; case ast::binary_operation_type::subtraction: if (is_fp) builder.fsub(0, 1, fp_mode, 0); else { builder.sub_reg(0, 1, 0); } break; case ast::binary_operation_type::multiplication: if (is_fp) builder.fmul(0, 1, fp_mode, 0); else { builder.mul_reg(0, 1, 0); } break; case ast::binary_operation_type::division: if (is_fp) builder.fdiv(0, 1, fp_mode, 0); else { extend(0, type); extend(1, type); if (types::is_signed_integer_type(*type)) builder.sdiv_reg(0, 1, 0); else builder.udiv_reg(0, 1, 0); } break; case ast::binary_operation_type::remainder: extend(0, type); extend(1, type); if (types::is_signed_integer_type(*type)) { builder.sdiv_reg(0, 1, 2); builder.mul_reg(1, 2, 1); builder.sub_reg(0, 1, 0); } else if (types::is_unsigned_integer_type(*type)) { builder.udiv_reg(0, 1, 2); builder.mul_reg(1, 2, 1); builder.sub_reg(0, 1, 0); } break; case ast::binary_operation_type::binary_and: builder.and_reg(0, 1, 0); break; case ast::binary_operation_type::logical_and: throw std::runtime_error("Short-circuiting operators must have been unwrapped in IR compiler"); case ast::binary_operation_type::binary_or: builder.or_reg(0, 1, 0); break; case ast::binary_operation_type::logical_or: throw std::runtime_error("Short-circuiting operators must have been unwrapped in IR compiler"); case ast::binary_operation_type::logical_xor: builder.xor_reg(0, 1, 0); break; case ast::binary_operation_type::equals: if (is_fp) { builder.fcmp(0, 1, fp_mode); builder.csetm(0, 0b0000); } else { extend(0, node.arg1->inferred_type); extend(1, node.arg2->inferred_type); builder.cmp_reg(0, 1); builder.csetm(0, 0b0000); } break; case ast::binary_operation_type::not_equals: if (is_fp) { builder.fcmp(0, 1, fp_mode); builder.csetm(0, 0b0001); } else { extend(0, node.arg1->inferred_type); extend(1, node.arg2->inferred_type); builder.cmp_reg(0, 1); builder.csetm(0, 0b0001); } break; case ast::binary_operation_type::less: if (is_fp) { builder.fcmp(0, 1, fp_mode); builder.csetm(0, 0b0100); } else { extend(0, node.arg1->inferred_type); extend(1, node.arg2->inferred_type); builder.cmp_reg(1, 0); if (types::is_bool_type(*node.arg1->inferred_type) || types::is_unsigned_integer_type(*node.arg1->inferred_type)) builder.csetm(0, 0b1000); else builder.csetm(0, 0b1100); } break; case ast::binary_operation_type::greater: if (is_fp) { builder.fcmp(1, 0, fp_mode); builder.csetm(0, 0b0100); } else { extend(0, node.arg1->inferred_type); extend(1, node.arg2->inferred_type); builder.cmp_reg(0, 1); if (types::is_bool_type(*node.arg1->inferred_type) || types::is_unsigned_integer_type(*node.arg1->inferred_type)) builder.csetm(0, 0b1000); else builder.csetm(0, 0b1100); } break; case ast::binary_operation_type::less_equals: if (is_fp) { builder.fcmp(0, 1, fp_mode); builder.csetm(0, 0b1001); } else { extend(0, node.arg1->inferred_type); extend(1, node.arg2->inferred_type); builder.cmp_reg(0, 1); if (types::is_bool_type(*node.arg1->inferred_type) || types::is_unsigned_integer_type(*node.arg1->inferred_type)) builder.csetm(0, 0b1001); else builder.csetm(0, 0b1101); } break; case ast::binary_operation_type::greater_equals: if (is_fp) { builder.fcmp(1, 0, fp_mode); builder.csetm(0, 0b1001); } else { extend(0, node.arg1->inferred_type); extend(1, node.arg2->inferred_type); builder.cmp_reg(1, 0); if (types::is_bool_type(*node.arg1->inferred_type) || types::is_unsigned_integer_type(*node.arg1->inferred_type)) builder.csetm(0, 0b1001); else builder.csetm(0, 0b1101); } break; default: { std::ostringstream os; os << "binary operation " << node.type << " is not implemented"; throw std::runtime_error(os.str()); } } if (is_fp) store_fp(it, 0, fp_mode); else store(it, 0); } void apply(ir::node_ref it, ir::cast_operation const & node, types::type_ptr const &) { auto src_type = node.arg1->inferred_type; auto dst_type = node.target_type; if (types::equal(*src_type, *dst_type) || (types::is_pointer_type(*src_type) && types::is_pointer_type(*dst_type))) { load(node.arg1, 0); store(it, 0); return; } if (types::is_integer_type(*src_type)) { load(node.arg1, 0); if (types::is_integer_type(*dst_type)) { extend(0, dst_type); } else if (types::is_floating_point_type(*dst_type)) { auto dst_mode = fp_mode_for(*dst_type); if (types::is_signed_integer_type(*src_type)) { builder.fmov(0, 0, 3, 1); builder.scvtf(0, 0, 3); if (dst_mode != 3) builder.fcvt(0, 3, 0, dst_mode); } else if (types::is_unsigned_integer_type(*src_type)) { builder.fmov(0, 0, 3, 1); builder.ucvtf(0, 0, 3); if (dst_mode != 3) builder.fcvt(0, 3, 0, dst_mode); } } } else if (types::is_floating_point_type(*src_type)) { auto src_mode = fp_mode_for(*src_type); if (types::is_integer_type(*dst_type)) { if (types::is_signed_integer_type(*dst_type)) { builder.fcvtns(0, 0, src_mode); extend(0, dst_type); } else if (types::is_unsigned_integer_type(*dst_type)) { builder.fcvtnu(0, 0, src_mode); extend(0, dst_type); } } else if (types::is_floating_point_type(*dst_type)) { auto dst_mode = fp_mode_for(*dst_type); builder.fcvt(0, src_mode, 0, dst_mode); } } if (types::is_integer_type(*dst_type)) { store(it, 0); } else if (types::is_floating_point_type(*dst_type)) { store_fp(it, 0, fp_mode_for(*dst_type)); } } void apply(ir::node_ref it, ir::argument const & node, types::type_ptr const & type) { // Nothing to do: arguments already pushed on stack in function preamble } void apply(ir::node_ref it, ir::instruction_address const & node, types::type_ptr const &) { lcontext.adr_resolve.emplace_back(pcontext.code.size(), node.target); builder.adr(0, 0); store(it, 0); } void apply(ir::node_ref it, ir::extern_symbol const & node, types::type_ptr const &) { builder.ldr_pc(0, (lcontext.extern_symbols[node.name] - (std::int32_t)pcontext.code.size()) / 4); store(it, 0); } void apply(ir::node_ref, ir::assignment const & node, types::type_ptr const & type) { // TODO: array/array element assignment? std::size_t src_offset = stack_size - stack_position.at(node.rhs); auto dst_type = node.lhs->inferred_type; std::size_t dst_offset = stack_size - stack_position.at(node.lhs); for (auto field_id : node.path) { auto struct_node = std::get(*dst_type).node; dst_type = struct_node->fields[field_id].inferred_type; dst_offset = dst_offset + struct_node->fields[field_id].layout.offset; } copy_memory(31, src_offset, 31, dst_offset, ast::type_size(*dst_type), 0); } void apply(ir::node_ref, ir::jump const & node, types::type_ptr const &) { lcontext.branch_resolve.emplace_back(pcontext.code.size(), node.target); builder.b(0); } void apply(ir::node_ref, ir::jump_if_zero const & node, types::type_ptr const &) { load(node.condition, 0); lcontext.cbranch_resolve.emplace_back(pcontext.code.size(), node.target); builder.cbz(0, 0); } void apply(ir::node_ref, ir::jump_if_nonzero const & node, types::type_ptr const &) { load(node.condition, 0); lcontext.cbranch_resolve.emplace_back(pcontext.code.size(), node.target); builder.cbnz(0, 0); } void apply(ir::node_ref it, ir::call const & node, types::type_ptr const & type) { // TODO: struct/array arguments? std::uint8_t reg = 0; std::uint8_t fp_reg = 0; for (auto const & argument : node.arguments) { if (types::is_integer_like_type(*argument->inferred_type)) load(argument, reg++); else if (types::is_floating_point_type(*argument->inferred_type)) load_fp(argument, fp_reg++, fp_mode_for(*argument->inferred_type)); else throw std::runtime_error("Unsupported function argument type"); } if (!lcontext.use_frame_pointer) { builder.sub_imm(31, 31, 16); builder.str(30, 31, 0); } lcontext.branch_resolve.emplace_back(pcontext.code.size(), node.target); builder.bl(0); if (!lcontext.use_frame_pointer) { builder.ldr(30, 31, 0); builder.add_imm(31, 31, 16); } // TODO: struct/array return value? if (types::is_unit_type(*type)) {} else if (types::is_integer_like_type(*type)) store(it, 0); else if (types::is_floating_point_type(*type)) store_fp(it, 0, fp_mode_for(*type)); else throw std::runtime_error("Unsupported return value type"); } void apply(ir::node_ref it, ir::call_pointer const & node, types::type_ptr const & type) { // TODO: struct/array arguments? std::uint8_t reg = 0; std::uint8_t fp_reg = 0; for (auto const & argument : node.arguments) { if (types::is_integer_like_type(*argument->inferred_type)) load(argument, reg++); else if (types::is_floating_point_type(*argument->inferred_type)) load_fp(argument, fp_reg++, fp_mode_for(*argument->inferred_type)); else throw std::runtime_error("Unsupported function argument type"); } load(node.pointer, reg); builder.sub_imm(31, 31, 16); builder.str(30, 31, 0); builder.bl_reg(reg); builder.ldr(30, 31, 0); builder.add_imm(31, 31, 16); // TODO: struct/array return value? if (types::is_unit_type(*type)) {} else if (types::is_integer_like_type(*type)) store(it, 0); else if (types::is_floating_point_type(*type)) store_fp(it, 0, fp_mode_for(*type)); else throw std::runtime_error("Unsupported return value type"); } void apply(ir::node_ref, ir::return_value const & node, types::type_ptr const &) { // TODO: struct/array return value? if (node.value) { auto type = (*node.value)->inferred_type; if (types::is_integer_like_type(*type)) load(*node.value, 0); else if (types::is_floating_point_type(*type)) load_fp(*node.value, 0, fp_mode_for(*type)); else throw std::runtime_error("Unsupported return value type"); } if (lcontext.use_frame_pointer) { builder.ldr(29, 31, (stack_size - 16) / 8); builder.ldr(30, 31, (stack_size - 8) / 8); } if (stack_size > 0) builder.add_imm(31, 31, stack_size); builder.ret(); } void compile(ast::function_definition const * function_definition, ir::node_ref begin, ir::node_ref end) { stack_size = 0; if (lcontext.use_frame_pointer) stack_size += 16; for (auto const & argument : function_definition->arguments) { auto size = ast::type_size(*argument.inferred_type); stack_size += ((size + 7) / 8) * 8; argument_position.push_back(stack_size); } for (auto it = begin; it != end; ++it) { if (auto argument = std::get_if(&it->instruction)) { stack_position[it] = argument_position[argument->index]; } else if (ir::is_value_instruction(it->instruction)) { auto size = ast::type_size(*it->inferred_type); if (size == 0) continue; // TODO: inefficient for small types, maybe only round up to type alignment? // Need to make sure all read/write arm64 instructions used can handle offsets that // are not a multiple of 8 stack_size += ((size + 7) / 8) * 8; stack_position[it] = stack_size; } } stack_size = ((stack_size + 15) / 16) * 16; if (!std::holds_alternative(begin->instruction)) throw std::runtime_error("First IR node of a function must be a label"); auto it = begin; lcontext.nodes[it] = pcontext.code.size(); if (stack_size > 0) builder.sub_imm(31, 31, stack_size); if (lcontext.use_frame_pointer) { builder.str(29, 31, (stack_size - 16) / 8); builder.str(30, 31, (stack_size - 8) / 8); builder.add_imm(31, 29, stack_size - 16); } std::uint8_t reg = 0; std::uint8_t fp_reg = 0; for (std::size_t i = 0; i < function_definition->arguments.size(); ++i) { auto const & argument = function_definition->arguments[i]; auto size = ast::type_size(*argument.inferred_type); auto fp_mode = fp_mode_for(*argument.inferred_type); if (size == 0) continue; if (types::is_integer_like_type(*argument.inferred_type)) builder.str(reg++, 31, (stack_size - argument_position[i]) / 8); else if (types::is_floating_point_type(*argument.inferred_type)) builder.str_fp(fp_reg++, fp_mode, 31, (stack_size - argument_position[i]) / fp_size(fp_mode)); else throw std::runtime_error("Unknown argument type"); } ++it; for (; it != end; ++it) { // Uncomment to debug per-node instruction generation: builder.nop(); lcontext.nodes[it] = pcontext.code.size(); std::visit([&](auto const & instruction){ apply(it, instruction, it->inferred_type); }, it->instruction); } } private: void load(ir::node_ref ir, std::uint8_t reg) { std::int32_t offset = stack_size - stack_position.at(ir); builder.ldr(reg, 31, offset / 8); } void load_fp(ir::node_ref ir, std::uint8_t reg, std::uint8_t mode) { std::int32_t offset = stack_size - stack_position.at(ir); builder.ldr_fp(reg, mode, 31, offset / fp_size(mode)); } void store(ir::node_ref ir, std::uint8_t reg) { std::int32_t offset = stack_size - stack_position.at(ir); builder.str(reg, 31, offset / 8); } void store_fp(ir::node_ref ir, std::uint8_t reg, std::uint8_t mode) { std::int32_t offset = stack_size - stack_position.at(ir); builder.str_fp(reg, mode, 31, offset / fp_size(mode)); } // Sign- or zero-extend the register depending on the exact type void extend(std::uint8_t reg, types::type_ptr const & type) { reg_extend_visitor{{}, builder, reg}.apply(*type); } void copy_memory(std::uint8_t reg_src_addr, std::size_t src_offset, std::uint8_t reg_dst_addr, std::size_t dst_offset, std::size_t size, std::uint8_t tmp_reg) { std::int32_t offset = 0; while (size > 0) { auto check_step = [&](std::size_t step) { return size >= step && (((src_offset + offset) % step) == 0) && (((dst_offset + offset) % step) == 0); }; if (check_step(8)) { builder.ldr(tmp_reg, reg_src_addr, (src_offset + offset) / 8); builder.str(tmp_reg, reg_dst_addr, (dst_offset + offset) / 8); size -= 8; offset += 8; } else if (check_step(4)) { builder.ldrw(tmp_reg, reg_src_addr, (src_offset + offset) / 4); builder.strw(tmp_reg, reg_dst_addr, (dst_offset + offset) / 4); size -= 4; offset += 4; } else if (check_step(2)) { builder.ldrh(tmp_reg, reg_src_addr, (src_offset + offset) / 2); builder.strh(tmp_reg, reg_dst_addr, (dst_offset + offset) / 2); size -= 2; offset += 2; } else { builder.ldrb(tmp_reg, reg_src_addr, src_offset + offset); builder.strb(tmp_reg, reg_dst_addr, dst_offset + offset); size -= 1; offset += 1; } } } }; } void compile(program_context & pcontext, ir::module_context const & mcontext) { local_context lcontext; instruction_builder builder{pcontext.code}; { populate_const_data_visitor visitor{pcontext, lcontext}; for (auto it = mcontext.nodes->begin(); it != mcontext.nodes->end(); ++it) std::visit([&](auto const & instruction){ visitor.apply(instruction, it->inferred_type); }, it->instruction); } for (auto const & symbol : mcontext.symbols) { pcontext.symbols[symbol.first] = pcontext.code.size(); compile_visitor visitor{pcontext, mcontext, lcontext, builder}; visitor.compile(symbol.first, symbol.second.begin, symbol.second.end); } pcontext.entry_point = lcontext.nodes.at(mcontext.entry_point); for (auto const & resolve : lcontext.branch_resolve) builder.b_inject(pcontext.code.data() + resolve.offset, (lcontext.nodes.at(resolve.target) - resolve.offset) / 4); for (auto const & resolve : lcontext.cbranch_resolve) builder.cb_inject(pcontext.code.data() + resolve.offset, (lcontext.nodes.at(resolve.target) - resolve.offset) / 4); for (auto const & resolve : lcontext.adr_resolve) builder.adr_inject(pcontext.code.data() + resolve.offset, (lcontext.nodes.at(resolve.target) - resolve.offset) / 4); } }