#include #include #include #include #include #include #include #include #include #include #include #include namespace pslang::jit::aarch64 { namespace { // Homogeneous floating-point aggregate: up to 4 floating-point members // of the same type (after struct flattening) struct hfa_data { types::type_ptr type; std::size_t count; }; struct local_context { std::unordered_map f16_constants; std::unordered_map f32_constants; std::unordered_map f64_constants; std::unordered_map foreign_address; std::unordered_map functions; struct struct_data { std::optional hfa = {}; }; std::unordered_map structs; struct scope { std::unordered_set foreign_functions; std::unordered_map functions; std::unordered_map structs; }; std::vector scopes; struct resolve_info { ast::function_definition const * node; // Must be 'adr' instruction std::int32_t instruction_offset; }; std::vector resolve; bool is_foreign(std::string const & name) { for (auto it = scopes.rbegin(); it != scopes.rend(); ++it) { if (it->foreign_functions.contains(name)) return true; if (it->functions.contains(name)) return false; if (it->structs.contains(name)) return false; } return false; } ast::function_definition const * is_function(std::string const & name) { for (auto it = scopes.rbegin(); it != scopes.rend(); ++it) { if (auto jt = it->functions.find(name); jt != it->functions.end()) return jt->second; if (it->foreign_functions.contains(name)) return nullptr; if (it->structs.contains(name)) return nullptr; } return nullptr; } ast::struct_definition const * is_struct(std::string const & name) { for (auto it = scopes.rbegin(); it != scopes.rend(); ++it) { if (auto jt = it->structs.find(name); jt != it->structs.end()) return jt->second; if (it->foreign_functions.contains(name)) return nullptr; if (it->functions.contains(name)) return nullptr; } return nullptr; } }; std::uint8_t fp_mode_for(types::type const & type) { if (types::equal(type, types::primitive_type(types::f16_type{}))) return 1; if (types::equal(type, types::primitive_type(types::f32_type{}))) return 2; return 3; } bool is_short_circuiting(ast::binary_operation_type type) { switch (type) { case ast::binary_operation_type::logical_and: case ast::binary_operation_type::logical_or: return true; default: return false; } } // Add all f16, f32 and f64 constants as read-only data entries // Add extern pointers for all foreign functions as read-only data entries struct populate_constants_visitor : ast::const_expression_visitor , ast::const_statement_visitor { using const_expression_visitor::apply; using const_statement_visitor::apply; program_context & pcontext; local_context & lcontext; template requires(!std::is_floating_point_v) void apply(ast::primitive_literal_base const &) {} void apply(ast::f16_literal const & node) { if (!lcontext.f16_constants.contains(node.value.repr)) { lcontext.f16_constants[node.value.repr] = pcontext.code.size(); push_bytes(node.value.repr); } } void apply(ast::f32_literal const & node) { if (!lcontext.f32_constants.contains(node.value)) { lcontext.f32_constants[node.value] = pcontext.code.size(); push_bytes(node.value); } } void apply(ast::f64_literal const & node) { if (!lcontext.f64_constants.contains(node.value)) { lcontext.f64_constants[node.value] = pcontext.code.size(); push_bytes(node.value); } } void apply(ast::identifier const &) {} void apply(ast::unary_operation const & node) { apply(*node.arg1); } void apply(ast::binary_operation const & node) { apply(*node.arg1); apply(*node.arg2); } void apply(ast::cast_operation const & node) { apply(*node.expression); } void apply(ast::function_call const & node) { if (node.function) apply(*node.function); for (auto const & argument : node.arguments) apply(*argument); } void apply(ast::array const & node) { for (auto const & element : node.elements) apply(*element); } void apply(ast::array_access const & node) { apply(*node.array); apply(*node.index); } void apply(ast::field_access const & node) { apply(*node.object); } void apply(ast::expression_ptr const & node) { apply(*node); } void apply(ast::assignment const & node) { apply(*node.lhs); apply(*node.rhs); } void apply(ast::variable_declaration const & node) { apply(*node.initializer); } void apply(ast::if_chain const & node) { for (auto const & block : node.blocks) { if (block.condition) apply(*block.condition); apply(*block.statements); } } void apply(ast::while_block const & node) { apply(*node.condition); apply(*node.statements); } void apply(ast::function_definition const & node) { apply(*node.statements); } void apply(ast::foreign_function_declaration const & foreign_function_declaration) { if (!lcontext.foreign_address.contains(foreign_function_declaration.name)) { lcontext.foreign_address[foreign_function_declaration.name] = pcontext.code.size(); push_bytes(nullptr); } } void apply(ast::return_statement const & node) { if (node.value) apply(*node.value); } void apply(ast::struct_definition const &) {} private: template void push_bytes(T const & value) { auto begin = (std::uint8_t const *)(&value); auto end = begin + sizeof(value); pcontext.code.insert(pcontext.code.end(), begin, end); } }; std::optional get_hfa_data(ast::struct_definition const & node, local_context & lcontext) { if (auto it = lcontext.structs.find(&node); it != lcontext.structs.end()) return it->second.hfa; types::type_ptr type = nullptr; std::size_t count = 0; for (auto const & field : node.fields) { if (types::is_builtin_type(*field.inferred_type)) { if (!types::is_floating_point_type(*field.inferred_type)) return std::nullopt; if (type && !types::equal(*type, *field.inferred_type)) return std::nullopt; type = field.inferred_type; ++count; } else if (auto struct_type = std::get_if(field.inferred_type.get())) { // NB: recursion must be impossible due to prior checks in type checker if (auto subdata = get_hfa_data(*struct_type->node, lcontext)) { if (type && !types::equal(*type, *subdata->type)) return std::nullopt; type = subdata->type; count += subdata->count; } else return std::nullopt; } else return std::nullopt; } if (count <= 4) return hfa_data{type, count}; return std::nullopt; } // Iterate over a single scope (i.e. not visiting subscopes recursively) // and add all defined functions & foreign functions to the current scope struct populate_symbols_visitor : ast::const_statement_visitor { local_context & lcontext; using const_statement_visitor::apply; void apply(ast::expression_ptr const &) {} void apply(ast::assignment const &) {} void apply(ast::variable_declaration const &) {} void apply(ast::if_chain const &) {} void apply(ast::while_block const &) {} void apply(ast::function_definition const & node) { lcontext.scopes.back().functions[node.name] = &node; } void apply(ast::foreign_function_declaration const & node) { lcontext.scopes.back().foreign_functions.insert(node.name); } void apply(ast::return_statement const &) {} void apply(ast::struct_definition const & node) { lcontext.scopes.back().structs[node.name] = &node; if (!lcontext.structs.contains(&node)) { // NB: make sure not to add struct to lcontext.structs before computing hfa data auto hfa = get_hfa_data(node, lcontext); lcontext.structs[&node].hfa = hfa; } } }; struct reg_extend_visitor : types::const_visitor { using const_visitor::apply; instruction_builder & builder; std::uint8_t reg; void apply(types::bool_type const &) {} void apply(types::f32_type const &) {} void apply(types::f64_type const &) {} template void apply(types::primitive_type_base const &) { if constexpr (sizeof(T) == 8) { return; } if constexpr (std::is_signed_v) { builder.sbfm(reg, reg, sizeof(T) * 8); } if constexpr (std::is_unsigned_v) { builder.ubfm(reg, reg, sizeof(T) * 8); } } template void apply(T const &) { throw std::runtime_error(std::string("reg_extend_visitor is not implemented for ") + typeid(T).name()); } }; // Compile a single function and store the entry point offset // in local_context struct compile_function_visitor : ast::const_statement_visitor , ast::const_expression_visitor { using const_statement_visitor::apply; using const_expression_visitor::apply; program_context & pcontext; local_context & lcontext; instruction_builder builder{pcontext.code}; // Difference between initial stack pointer at function enter // and current virtual stack pointer value. The actual stack pointer // value is rounded down to a multiple of 16 std::uint32_t stack_offset = 0; struct variable_data { // Difference between initial stack pointer at function enter // and the variable address // Must be a multiple of 16 std::uint32_t frame_offset; }; struct scope { std::unordered_map variables = {}; // Difference between initial virtual stack pointer at scope enter // and current virtual stack pointer value std::uint32_t stack_offset = 0; }; std::vector scopes; template void apply(Node const &) { throw std::runtime_error(std::string("compile_function_visitor is not implemented for ") + typeid(Node).name()); } void apply(ast::bool_literal const & node) { if (node.value) set_m1(0); else builder.movz(0, 0); } template requires(std::is_integral_v && !std::is_same_v) void apply(ast::primitive_literal_base const & node) { for (std::size_t i = 0; i < sizeof(T); i += 2) { if (i == 0) { builder.movz(0, std::uint64_t(node.value)); } else { auto val = std::uint16_t(std::uint64_t(node.value) >> (i * 8)); if (val != 0) builder.movk(0, val, i / 2); } } if (sizeof(T) < 8) { if constexpr (std::is_signed_v) { if (node.value < 0) builder.sbfm(0, 0, sizeof(T) * 8); } } } void apply(ast::f16_literal const & node) { auto offset = lcontext.f16_constants.at(node.value.repr); std::int32_t current = pcontext.code.size(); builder.ldr_fp_pc(0, 0, (offset - current) / 4); builder.fcvt(0, 0b10, 0, 0b01); } void apply(ast::f32_literal const & node) { auto offset = lcontext.f32_constants.at(node.value); std::int32_t current = pcontext.code.size(); builder.ldr_fp_pc(0, 0, (offset - current) / 4); } void apply(ast::f64_literal const & node) { auto offset = lcontext.f64_constants.at(node.value); std::int32_t current = pcontext.code.size(); builder.ldr_fp_pc(0, 1, (offset - current) / 4); } void apply(ast::identifier const & node) { for (auto it = scopes.rbegin(); it != scopes.rend(); ++it) { if (node.variable_node && it->variables.contains(node.variable_node)) { auto jt = it->variables.find(node.variable_node); if (auto struct_type = std::get_if(node.inferred_type.get())) { std::size_t stack_size = ((struct_type->node->layout.size + 15) / 16) * 16; builder.sub_imm(31, 31, stack_size); stack_offset += stack_size; scopes.back().stack_offset += stack_size; std::size_t variable_offset = stack_offset - jt->second.frame_offset; for (std::size_t offset = 0; offset < stack_size; offset += 16) { builder.ldr(0, 31, (variable_offset + offset) / 8); builder.ldr(1, 31, (variable_offset + offset) / 8 + 1); builder.str(0, 31, offset / 8); builder.str(1, 31, offset / 8 + 1); } } else if (types::is_unit_type(*node.inferred_type)) {} else if (types::is_floating_point_type(*node.inferred_type)) builder.ldr_fp(0, fp_mode_for(*node.inferred_type), 31, (stack_offset - jt->second.frame_offset) / type_size(*node.inferred_type)); else builder.ldr(0, 31, (stack_offset - jt->second.frame_offset) / 8); return; } } if (lcontext.is_foreign(node.name)) { builder.ldr_pc(0, (lcontext.foreign_address.at(node.name) - (std::int32_t)pcontext.code.size()) / 4); } else if (auto function_node = lcontext.is_function(node.name)) { lcontext.resolve.push_back({function_node, (std::int32_t)pcontext.code.size()}); builder.adr(0, 0); } else { throw std::runtime_error("unknown identifier \"" + node.name + "\""); } } void apply(ast::unary_operation const & node) { switch (node.type) { case ast::unary_operation_type::negation: apply(*node.arg1); if (types::is_integer_type(*node.inferred_type)) { builder.sub_reg(31, 0, 0); extend(0, node.inferred_type); } else if (types::is_floating_point_type(*node.inferred_type)) { builder.fneg(0, fp_mode_for(*node.inferred_type), 0); } break; case ast::unary_operation_type::logical_not: apply(*node.arg1); builder.or_not_reg(31, 0, 0); if (types::is_integer_type(*node.inferred_type)) extend(0, node.inferred_type); break; } } void apply(ast::binary_operation const & node) { auto arg1_type = ast::get_type(*node.arg1); bool const is_fp = types::is_floating_point_type(*arg1_type); std::uint8_t const fp_mode = fp_mode_for(*arg1_type); apply(*node.arg1); if (!is_short_circuiting(node.type)) { if (is_fp) { push_fp(0, fp_mode); apply(*node.arg2); pop_fp(1, fp_mode); } else { push(0); apply(*node.arg2); pop(1); } } switch (node.type) { case ast::binary_operation_type::addition: if (is_fp) builder.fadd(1, 0, fp_mode, 0); else { builder.add_reg(1, 0, 0); extend(0, node.inferred_type); } break; case ast::binary_operation_type::subtraction: if (is_fp) builder.fsub(1, 0, fp_mode, 0); else { builder.sub_reg(1, 0, 0); extend(0, node.inferred_type); } break; case ast::binary_operation_type::multiplication: if (is_fp) builder.fmul(1, 0, fp_mode, 0); else { builder.mul_reg(1, 0, 0); extend(0, node.inferred_type); } break; case ast::binary_operation_type::division: if (is_fp) builder.fdiv(1, 0, fp_mode, 0); else { if (types::is_signed_integer_type(*node.inferred_type)) builder.sdiv_reg(1, 0, 0); else builder.udiv_reg(1, 0, 0); extend(0, node.inferred_type); } break; case ast::binary_operation_type::remainder: if (types::is_signed_integer_type(*node.inferred_type)) { builder.sdiv_reg(1, 0, 2); builder.mul_reg(0, 2, 0); builder.sub_reg(1, 0, 0); } else if (types::is_unsigned_integer_type(*node.inferred_type)) { builder.udiv_reg(1, 0, 2); builder.mul_reg(0, 2, 0); builder.sub_reg(1, 0, 0); } break; case ast::binary_operation_type::binary_and: builder.and_reg(1, 0, 0); break; case ast::binary_operation_type::logical_and: { std::int32_t start = pcontext.code.size(); builder.cbz(0, 0); push(0); apply(*node.arg2); pop(1); builder.and_reg(1, 0, 0); std::int32_t end = pcontext.code.size(); builder.cb_inject(pcontext.code.data() + start, (end - start) / 4); } break; case ast::binary_operation_type::binary_or: builder.or_reg(1, 0, 0); break; case ast::binary_operation_type::logical_or: { set_m1(1); extend(1, arg1_type); builder.xor_reg(0, 1, 1); std::int32_t start = pcontext.code.size(); builder.cbz(1, 0); push(0); apply(*node.arg2); pop(1); builder.or_reg(1, 0, 0); std::int32_t end = pcontext.code.size(); builder.cb_inject(pcontext.code.data() + start, (end - start) / 4); } break; case ast::binary_operation_type::logical_xor: builder.xor_reg(1, 0, 0); break; case ast::binary_operation_type::equals: if (is_fp) { builder.fcmp(1, 0, fp_mode); builder.csetm(0, 0b0000); } else { builder.cmp_reg(1, 0); builder.csetm(0, 0b0000); } break; case ast::binary_operation_type::not_equals: if (is_fp) { builder.fcmp(1, 0, fp_mode); builder.csetm(0, 0b0001); } else { builder.cmp_reg(1, 0); builder.csetm(0, 0b0001); } break; case ast::binary_operation_type::less: if (is_fp) { builder.fcmp(1, 0, fp_mode); builder.csetm(0, 0b0100); } else { builder.cmp_reg(0, 1); if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1))) builder.csetm(0, 0b1000); else builder.csetm(0, 0b1100); } break; case ast::binary_operation_type::greater: if (is_fp) { builder.fcmp(0, 1, fp_mode); builder.csetm(0, 0b0100); } else { builder.cmp_reg(1, 0); if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1))) builder.csetm(0, 0b1000); else builder.csetm(0, 0b1100); } break; case ast::binary_operation_type::less_equals: if (is_fp) { builder.fcmp(1, 0, fp_mode); builder.csetm(0, 0b1001); } else { builder.cmp_reg(1, 0); if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1))) builder.csetm(0, 0b1001); else builder.csetm(0, 0b1101); } break; case ast::binary_operation_type::greater_equals: if (is_fp) { builder.fcmp(0, 1, fp_mode); builder.csetm(0, 0b1001); } else { builder.cmp_reg(0, 1); if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1))) builder.csetm(0, 0b1001); else builder.csetm(0, 0b1101); } break; default: { std::ostringstream os; os << "binary operation " << node.type << " is not implemented"; throw std::runtime_error(os.str()); } } } void apply(ast::cast_operation const & node) { auto src_type = ast::get_type(*node.expression); auto dst_type = node.inferred_type; apply(*node.expression); if (types::equal(*src_type, *dst_type)) return; if (types::is_integer_type(*src_type)) { if (types::is_integer_type(*dst_type)) { extend(0, dst_type); } else if (types::is_floating_point_type(*dst_type)) { auto dst_mode = fp_mode_for(*dst_type); if (types::is_signed_integer_type(*src_type)) { builder.fmov(0, 0, 3, 1); builder.scvtf(0, 0, 3); if (dst_mode != 3) builder.fcvt(0, 3, 0, dst_mode); } else if (types::is_unsigned_integer_type(*src_type)) { builder.fmov(0, 0, 3, 1); builder.ucvtf(0, 0, 3); if (dst_mode != 3) builder.fcvt(0, 3, 0, dst_mode); } } } else if (types::is_floating_point_type(*src_type)) { auto src_mode = fp_mode_for(*src_type); if (types::is_integer_type(*dst_type)) { if (types::is_signed_integer_type(*dst_type)) { builder.fcvtns(0, 0, src_mode); extend(0, dst_type); } else if (types::is_unsigned_integer_type(*dst_type)) { builder.fcvtnu(0, 0, src_mode); extend(0, dst_type); } } else if (types::is_floating_point_type(*dst_type)) { auto dst_mode = fp_mode_for(*dst_type); builder.fcvt(0, src_mode, 0, dst_mode); } } } void apply(ast::function_call const & node) { if (node.function) { apply(*node.function); push(0); for (std::size_t i = node.arguments.size(); i --> 0;) { auto const & arg = node.arguments[i]; apply(*arg); auto type = ast::get_type(*arg); if (types::is_bool_type(*type) || types::is_integer_type(*type)) { push(0); } else if (types::is_floating_point_type(*type)) { push_fp(0, fp_mode_for(*type)); } } std::uint8_t reg = 0; std::uint8_t fp_reg = 0; for (auto const & arg : node.arguments) { auto type = ast::get_type(*arg); if (types::is_bool_type(*type) || types::is_integer_type(*type)) { pop(reg); ++reg; } else if (types::is_floating_point_type(*type)) { pop_fp(fp_reg, fp_mode_for(*type)); ++fp_reg; } } pop(reg); push(30); builder.b_reg(reg); pop(30); } else // if (node.type) { if (types::is_unit_type(*node.inferred_type)) { // Do nothing } else if (types::is_bool_type(*node.inferred_type) || types::is_integer_type(*node.inferred_type) || types::is_function_type(*node.inferred_type)) { builder.xor_reg(0, 0, 0); } else if (types::is_floating_point_type(*node.inferred_type)) { builder.xor_reg(0, 0, 0); builder.fmov(0, 0, fp_mode_for(*node.inferred_type), 1); } else if (auto struct_type = std::get_if(node.inferred_type.get())) { auto & struct_node = *struct_type->node; // Allocate stack space for the struct std::size_t stack_size = ((struct_node.layout.size + 15) / 16) * 16; auto offset = stack_offset; stack_offset += stack_size; scopes.back().stack_offset += stack_size; builder.sub_imm(31, 31, stack_size); // Evaluate each field of the struct (i.e. each constructor argument) // and copy it to the corresponding place in the struct for (std::size_t i = 0; i < node.arguments.size(); ++i) { auto type = ast::get_type(*node.arguments[i]); apply(*node.arguments[i]); if (std::get_if(type.get())) { // TODO: struct field throw std::runtime_error("Not implemented"); } else if (types::is_floating_point_type(*type)) { builder.stur_fp(0, fp_mode_for(*type), 31, struct_node.fields[i].layout.offset); } else { auto size = types::type_size(*type); if (size == 1) builder.sturb(0, 31, struct_node.fields[i].layout.offset); else if (size == 2) builder.sturh(0, 31, struct_node.fields[i].layout.offset); else if (size == 4) builder.sturw(0, 31, struct_node.fields[i].layout.offset); else if (size == 8) builder.stur(0, 31, struct_node.fields[i].layout.offset); } } } } } void apply(ast::field_access const & node) { auto object_type = get_type(*node.object); if (auto struct_type = std::get_if(object_type.get())) { auto & struct_node = *struct_type->node; std::optional field_id; for (std::size_t i = 0; i < struct_node.fields.size(); ++i) { if (struct_node.fields[i].name == node.field_name) { field_id = i; break; } } if (!field_id) throw std::runtime_error("Unknown field \"" + node.field_name + "\" in struct \"" + struct_node.name + "\""); apply(*node.object); auto stack_size = ((struct_node.layout.size + 15) / 16) * 16; auto const & field = struct_node.fields[*field_id]; if (types::is_unit_type(*field.inferred_type)) {} else if (types::is_floating_point_type(*field.inferred_type)) { builder.ldur_fp(0, fp_mode_for(*field.inferred_type), 31, field.layout.offset); builder.add_imm(31, 31, stack_size); stack_offset -= stack_size; scopes.back().stack_offset -= stack_size; } else if (types::is_bool_type(*field.inferred_type) || types::is_integer_type(*field.inferred_type) || types::is_function_type(*field.inferred_type)) { auto size = types::type_size(*field.inferred_type); if (size == 1) builder.ldurb(0, 31, field.layout.offset); else if (size == 2) builder.ldurh(0, 31, field.layout.offset); else if (size == 4) builder.ldurw(0, 31, field.layout.offset); else if (size == 8) builder.ldur(0, 31, field.layout.offset); builder.add_imm(31, 31, stack_size); stack_offset -= stack_size; scopes.back().stack_offset -= stack_size; } else if (auto struct_type = std::get_if(field.inferred_type.get())) { // TODO: copy the struct-typed field on stack, overriding // the struct itself, and update the stack offset throw std::runtime_error("Not implemented"); } return; } throw std::runtime_error("Unknown object in field access"); } void apply(ast::expression_ptr const & node) { auto stack_offset_before = stack_offset; apply(*node); // Restore stack offset in case the expression evaluated to a struct // (in which case the struct would be placed on the stack) auto stack_delta = stack_offset - stack_offset_before; if (stack_delta > 0) { builder.add_imm(31, 31, stack_delta); stack_offset -= stack_delta; scopes.back().stack_offset -= stack_delta; } } void apply(ast::assignment const & node) { auto frame_offset = lvalue_offset(node.lhs); apply(*node.rhs); auto type = ast::get_type(*node.rhs); if (types::is_unit_type(*type)) {} else if (types::is_floating_point_type(*type)) builder.str_fp(0, fp_mode_for(*type), 31, (stack_offset - frame_offset) / type_size(*type)); else if (types::is_bool_type(*type) || types::is_integer_type(*type) || types::is_function_type(*type)) { auto size = types::type_size(*type); if (size == 1) builder.sturb(0, 31, stack_offset - frame_offset); else if (size == 2) builder.sturh(0, 31, stack_offset - frame_offset); else if (size == 4) builder.sturw(0, 31, stack_offset - frame_offset); else if (size == 8) builder.stur(0, 31, stack_offset - frame_offset); } else if (auto struct_type = std::get_if(type.get())) { // TODO: whole-struct assignment throw std::runtime_error("Not implemented"); } } void apply(ast::variable_declaration const & node) { apply(*node.initializer); auto type = ast::get_type(*node.initializer); if (std::get_if(type.get())) { // Nothing to be done: the struct is already on the stack // Just record the stack offset as variable location } else if (types::is_floating_point_type(*type)) push_fp(0, fp_mode_for(*type)); else if (types::is_unit_type(*type)) { // Nothing to be done: unit type has zero size // Its stack position is recorded but de facto unused } else push(0); scopes.back().variables[&node] = {.frame_offset = stack_offset}; } void apply(ast::if_chain const & node) { std::vector branch_to_end; for (std::size_t i = 0; i < node.blocks.size(); ++i) { auto const & block = node.blocks[i]; std::optional branch_skip; if (block.condition) { apply(*block.condition); branch_skip = pcontext.code.size(); builder.cbz(0, 0); } scopes.emplace_back(); apply(*block.statements); scope_cleanup(); scopes.pop_back(); if (i + 1 < node.blocks.size()) { branch_to_end.push_back(pcontext.code.size()); builder.b(0); } if (branch_skip) { auto branch_offset = pcontext.code.size() - *branch_skip; builder.cb_inject(pcontext.code.data() + *branch_skip, branch_offset / 4); } } auto end = pcontext.code.size(); for (auto instruction : branch_to_end) { auto delta = end - instruction; builder.b_inject(pcontext.code.data() + instruction, delta / 4); } } void apply(ast::while_block const & node) { std::int32_t start = pcontext.code.size(); apply(*node.condition); std::int32_t skip = pcontext.code.size(); builder.cbz(0, 0); scopes.emplace_back(); apply(*node.statements); scope_cleanup(); scopes.pop_back(); std::int32_t loop = pcontext.code.size(); builder.b(0); std::int32_t end = pcontext.code.size(); builder.cb_inject(pcontext.code.data() + skip, (end - skip) / 4); builder.b_inject(pcontext.code.data() + loop, (start - loop) / 4); } void apply(ast::return_statement const & node) { // TODO: struct return value if (node.value) apply(*node.value); do_return(); } void apply(ast::function_definition const &) { // Must be handled prior to that in populate_symbols_visitor } void apply(ast::foreign_function_declaration const &) { // Must be handled prior to that in populate_symbols_visitor } void apply(ast::struct_definition const &) { // Must be handled prior to that in populate_symbols_visitor } void apply(ast::statement_list const & node) { lcontext.scopes.emplace_back(); populate_symbols_visitor{{}, lcontext}.apply(node); for (auto const & statement : node.statements) apply(*statement); lcontext.scopes.pop_back(); } void do_apply(ast::function_definition const & node) { lcontext.functions[&node] = pcontext.code.size(); // TODO: struct arguments scopes.emplace_back(); std::uint8_t reg = 0; std::uint8_t fp_reg = 0; for (auto const & argument : node.arguments) { auto type = ast::get_type(*argument.type); if (types::is_bool_type(*type)) { builder.tst(reg, reg); builder.csetm(reg, 0b0001); push(reg); ++reg; } else if (types::is_integer_type(*type)) { extend(reg, type); push(reg); ++reg; } else if (types::is_floating_point_type(*type)) { auto mode = fp_mode_for(*type); push_fp(fp_reg, mode); ++fp_reg; } scopes.back().variables[&argument] = {.frame_offset = stack_offset}; } apply(*node.statements); if (node.statements->statements.empty() || !std::get_if(node.statements->statements.back().get())) if (types::equal(*ast::get_type(*node.return_type), types::unit_type{})) do_return(); scopes.pop_back(); } void do_return() { if (stack_offset > 0) builder.add_imm(31, 31, stack_offset); builder.ret(); } private: void push(std::uint8_t reg) { builder.sub_imm(31, 31, 16); builder.str(reg, 31, 0); stack_offset += 16; scopes.back().stack_offset += 16; } void pop(std::uint8_t reg) { builder.ldr(reg, 31, 0); builder.add_imm(31, 31, 16); stack_offset -= 16; scopes.back().stack_offset -= 16; } void push_fp(std::uint8_t reg, std::uint8_t mode) { builder.sub_imm(31, 31, 16); builder.str_fp(0, mode, 31, 0); stack_offset += 16; scopes.back().stack_offset += 16; } void pop_fp(std::uint8_t reg, std::uint8_t mode) { builder.ldr_fp(reg, mode, 31, 0); builder.add_imm(31, 31, 16); stack_offset -= 16; scopes.back().stack_offset -= 16; } // Set register @reg to -1 (all bits = 1) void set_m1(std::uint8_t reg) { builder.or_not_reg(31, 31, reg); } // Sign- or zero-extend the register depending on the exact type void extend(std::uint8_t reg, types::type_ptr const & type) { reg_extend_visitor{{}, builder, reg}.apply(*type); } void scope_cleanup() { if (scopes.back().stack_offset > 0) { builder.add_imm(31, 31, scopes.back().stack_offset); stack_offset -= scopes.back().stack_offset; } } // Returns offset from function entry stack frame std::size_t lvalue_offset(ast::expression_ptr const & node) { if (auto identifier = std::get_if(node.get())) { if (identifier->variable_node) for (auto it = scopes.rbegin(); it != scopes.rend(); ++it) if (auto jt = it->variables.find(identifier->variable_node); jt != it->variables.end()) return jt->second.frame_offset; throw std::runtime_error("Non-lvalue identifier: \"" + identifier->name + "\""); } else if (auto field_access = std::get_if(node.get())) { auto base_offset = lvalue_offset(field_access->object); auto type = ast::get_type(*field_access->object); if (auto struct_type = std::get_if(type.get())) { auto & struct_node = *struct_type->node; std::optional field_id; for (std::size_t i = 0; i < struct_node.fields.size(); ++i) if (struct_node.fields[i].name == field_access->field_name) { field_id = i; break; } if (!field_id) throw std::runtime_error("Invalid field \"" + field_access->field_name + "\""); return base_offset - struct_node.fields[*field_id].layout.offset; } else throw std::runtime_error("Invalid field access node"); } else if (auto array_access = std::get_if(node.get())) { throw std::runtime_error("Not implemented"); } throw std::runtime_error("Unknown lvalue node"); } }; // Main compilation visitor struct compile_visitor : ast::const_statement_visitor { program_context & pcontext; local_context & lcontext; instruction_builder & builder; using const_statement_visitor::apply; void apply(ast::expression_ptr const &) {} void apply(ast::assignment const &) {} void apply(ast::variable_declaration const &) {} void apply(ast::if_chain const & node) { for (auto const & block : node.blocks) apply(*block.statements); } void apply(ast::while_block const & node) { apply(*node.statements); } void apply(ast::function_definition const & node) { compile_function_visitor{{}, {}, pcontext, lcontext}.do_apply(node); apply(*node.statements); } void apply(ast::foreign_function_declaration const &) {} void apply(ast::return_statement const &) {} void apply(ast::struct_definition const &) {} void apply(ast::statement_list const & node) { lcontext.scopes.emplace_back(); populate_symbols_visitor{{}, lcontext}.apply(node); for (auto const & statement : node.statements) apply(*statement); // Don't pop_back entry point scope if (lcontext.scopes.size() > 1) lcontext.scopes.pop_back(); } }; } void compile(program_context & pcontext, ast::statement_list_ptr const & statements) { // Add a fake AST node for the entry point auto root = std::make_shared(ast::function_definition{ { {}, {}, std::make_shared(types::unit_type{}), {}, }, statements, {}, }); local_context lcontext; instruction_builder builder{pcontext.code}; populate_constants_visitor{{}, {}, pcontext, lcontext}.apply(*statements); compile_visitor{{}, pcontext, lcontext, builder}.apply(*root); for (auto const & resolve : lcontext.resolve) builder.adr_inject(pcontext.code.data() + resolve.instruction_offset, lcontext.functions.at(resolve.node) - resolve.instruction_offset); for (auto const & foreign : lcontext.foreign_address) pcontext.foreign_resolve.push_back({foreign.first, foreign.second}); for (auto const & function : lcontext.scopes.front().functions) pcontext.symbols[function.first] = lcontext.functions.at(function.second); pcontext.entry_point = lcontext.functions.at(std::get_if(root.get())); } }