#include #include #include #include #include #include namespace pslang::ir { namespace { struct local_context { std::unordered_map> functions; std::unordered_map variables; struct scope { std::string label_prefix; }; std::vector scopes; struct resolve_address_data { node_ref address; ast::function_definition const * target; }; struct resolve_call_data { node_ref call; ast::function_definition const * target; }; std::vector resolve_address; std::vector resolve_call; }; struct zero_literal_visitor : types::const_visitor { module_context & mcontext; using const_visitor::apply; template void apply(types::primitive_type_base) { mcontext.nodes->emplace_back(literal{ast::literal{ast::primitive_literal_base{.value = {}}}}, std::make_shared(types::primitive_type{types::primitive_type_base{}})); } template void apply(T const &) { throw std::runtime_error("Invalid type for zero_literal_visitor"); } }; // Compile a single function and store the entry point node_ref // in local_context struct compile_function_visitor : ast::const_statement_visitor , ast::const_expression_visitor { using const_statement_visitor::apply; using const_expression_visitor::apply; module_context & mcontext; local_context & lcontext; template node_ref apply(Node const & node) { throw std::runtime_error(std::string("IR compile visitor not implemented for ") + typeid(Node).name()); } // Expressions template node_ref apply(ast::primitive_literal_base const & node) { mcontext.nodes->emplace_back(literal{node}, ast::get_type(node)); return last(); } node_ref apply(ast::identifier const & node) { if (node.constant_node) { // Temporary hack: replace reading constants with their ininitialization code return apply(*node.constant_node->initializer); } else if (node.variable_node) { return lcontext.variables.at(node.variable_node); } else if (node.function_node) { mcontext.nodes->emplace_back(instruction_address{}, node.inferred_type); lcontext.resolve_address.emplace_back(last(), node.function_node); return last(); } else if (node.foreign_function_node) { mcontext.nodes->emplace_back(extern_symbol{node.name}, node.inferred_type); return last(); } else throw std::runtime_error("Unknown identifier \"" + node.name + "\""); } node_ref apply(ast::unary_operation const & node) { auto arg1 = apply(*node.arg1); if (node.type == ast::unary_operation_type::dereference) { mcontext.nodes->emplace_back(load{arg1}, node.inferred_type); return last(); } mcontext.nodes->emplace_back(unary_operation{node.type, arg1}, node.inferred_type); return last(); } node_ref apply(ast::binary_operation const & node) { auto arg1 = apply(*node.arg1); // Short-circuit operators if (node.type == ast::binary_operation_type::logical_and) { mcontext.nodes->emplace_back(copy{arg1}, ast::get_type(*node.arg1)); auto result = last(); mcontext.nodes->emplace_back(jump_if_zero{result}); auto jump = last(); auto arg2 = apply(*node.arg2); mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::binary_and, result, arg2}, node.inferred_type); mcontext.nodes->emplace_back(assignment{result, last()}); mcontext.nodes->emplace_back(label{}); std::get(jump->instruction).target = last(); return result; } if (node.type == ast::binary_operation_type::logical_or) { mcontext.nodes->emplace_back(copy{arg1}, ast::get_type(*node.arg1)); auto result = last(); mcontext.nodes->emplace_back(unary_operation{ast::unary_operation_type::logical_not, result}, node.inferred_type); mcontext.nodes->emplace_back(jump_if_zero{last()}); auto jump = last(); auto arg2 = apply(*node.arg2); mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::binary_or, result, arg2}, node.inferred_type); mcontext.nodes->emplace_back(assignment{result, last()}); mcontext.nodes->emplace_back(label{}); std::get(jump->instruction).target = last(); return result; } // Pointer arithmetic auto arg2 = apply(*node.arg2); auto arg1_type = get_type(*node.arg1); auto arg2_type = get_type(*node.arg2); auto arg1_is_pointer = types::is_pointer_type(*arg1_type); auto arg2_is_pointer = types::is_pointer_type(*arg2_type); if ((node.type == ast::binary_operation_type::addition || node.type == ast::binary_operation_type::subtraction) && (arg1_is_pointer || arg2_is_pointer)) { // Pointer types are equal and referenced types are non-empty - guaranteed by type checker std::int64_t element_size = 0; if (arg1_is_pointer) element_size = ast::type_size(*std::get(*arg1_type).referenced_type); else element_size = ast::type_size(*std::get(*arg2_type).referenced_type); auto i64_type = std::make_shared(types::primitive_type{types::i64_type{}}); mcontext.nodes->emplace_back(literal{ast::i64_literal{element_size}}, i64_type); auto element_size_node = last(); if (node.type == ast::binary_operation_type::addition) { if (arg1_is_pointer) { mcontext.nodes->emplace_back(cast_operation{arg2, i64_type}, i64_type); mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::multiplication, last(), element_size_node}, i64_type); mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::addition, arg1, last()}, node.inferred_type); } else // if (arg2_is_pointer) { mcontext.nodes->emplace_back(cast_operation{arg1, i64_type}, i64_type); mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::multiplication, last(), element_size_node}, i64_type); mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::addition, arg2, last()}, node.inferred_type); } } else if (node.type == ast::binary_operation_type::subtraction) { mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::subtraction, arg1, arg2}, node.inferred_type); mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::division, last(), element_size_node}, i64_type); } return last(); } // Different-type integer comparison if (!types::equal(*arg1_type, *arg2_type) && types::is_integer_type(*arg1_type) && types::is_integer_type(*arg2_type) && ast::is_comparison(node.type)) { bool arg1_unsigned = types::is_unsigned_integer_type(*arg1_type); bool arg2_unsigned = types::is_unsigned_integer_type(*arg2_type); std::size_t arg1_size = types::type_size(*arg1_type); std::size_t arg2_size = types::type_size(*arg2_type); std::size_t max_size = std::max(arg1_size, arg2_size); types::type_ptr max_type = (arg1_size > arg2_size) ? arg1_type : arg2_type; if ((arg1_unsigned && arg2_unsigned) || (!arg1_unsigned && !arg2_unsigned)) { // Both signed or both unsigned: just cast the smaller one to the larger type if (arg1_size < arg2_size) { mcontext.nodes->emplace_back(cast_operation{arg1, max_type}, max_type); mcontext.nodes->emplace_back(binary_operation{node.type, last(), arg2}, node.inferred_type); return last(); } else { mcontext.nodes->emplace_back(cast_operation{arg2, max_type}, max_type); mcontext.nodes->emplace_back(binary_operation{node.type, arg1, last()}, node.inferred_type); return last(); } } else { // Different signedness // Swap arg1 and arg2 if arg1 is unsigned, and reverse the operation auto type = node.type; if (arg1_unsigned) { std::swap(arg1, arg2); std::swap(arg1_type, arg2_type); std::swap(arg1_size, arg2_size); std::swap(arg1_is_pointer, arg2_is_pointer); if (type == ast::binary_operation_type::less) type = ast::binary_operation_type::greater; else if (type == ast::binary_operation_type::greater) type = ast::binary_operation_type::less; else if (type == ast::binary_operation_type::less_equals) type = ast::binary_operation_type::greater_equals; else if (type == ast::binary_operation_type::greater_equals) type = ast::binary_operation_type::less_equals; } // Compare with zero first zero_literal_visitor{{}, mcontext}.apply(*arg1_type); switch (type) { case ast::binary_operation_type::equals: mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::greater_equals, arg1, last()}, node.inferred_type); mcontext.nodes->emplace_back(jump_if_zero{last(), {}}); break; case ast::binary_operation_type::not_equals: mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::less, arg1, last()}, node.inferred_type); mcontext.nodes->emplace_back(jump_if_nonzero{last(), {}}); break; case ast::binary_operation_type::less: mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::less, arg1, last()}, node.inferred_type); mcontext.nodes->emplace_back(jump_if_nonzero{last(), {}}); break; case ast::binary_operation_type::greater: mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::greater, arg1, last()}, node.inferred_type); mcontext.nodes->emplace_back(jump_if_zero{last(), {}}); break; case ast::binary_operation_type::less_equals: mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::less_equals, arg1, last()}, node.inferred_type); mcontext.nodes->emplace_back(jump_if_nonzero{last(), {}}); break; case ast::binary_operation_type::greater_equals: mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::greater_equals, arg1, last()}, node.inferred_type); mcontext.nodes->emplace_back(jump_if_zero{last(), {}}); break; default: break; } auto result = std::prev(last()); auto jump_node = last(); // Less-than-zero case handled, here arg1 is nonnegative - cast both to the largest unsigned type types::type_ptr max_unsigned_type; if (max_size == 1) max_unsigned_type = std::make_unique(types::primitive_type{types::u8_type{}}); else if (max_size == 2) max_unsigned_type = std::make_unique(types::primitive_type{types::u16_type{}}); else if (max_size == 4) max_unsigned_type = std::make_unique(types::primitive_type{types::u32_type{}}); else if (max_size == 8) max_unsigned_type = std::make_unique(types::primitive_type{types::u64_type{}}); mcontext.nodes->emplace_back(cast_operation{arg1, max_unsigned_type}, max_unsigned_type); auto new_arg1 = last(); mcontext.nodes->emplace_back(cast_operation{arg2, max_unsigned_type}, max_unsigned_type); auto new_arg2 = last(); mcontext.nodes->emplace_back(binary_operation{type, new_arg1, new_arg2}, node.inferred_type); mcontext.nodes->emplace_back(assignment{result, last()}); mcontext.nodes->emplace_back(label{}); if (auto jump_if_zero = std::get_if(&jump_node->instruction)) jump_if_zero->target = last(); else if (auto jump_if_nonzero = std::get_if(&jump_node->instruction)) jump_if_nonzero->target = last(); return result; } } // General case mcontext.nodes->emplace_back(binary_operation{node.type, arg1, arg2}, node.inferred_type); return last(); } node_ref apply(ast::cast_operation const & node) { auto arg = apply(*node.expression); mcontext.nodes->emplace_back(cast_operation{arg, node.inferred_type}, node.inferred_type); return last(); } node_ref apply(ast::function_call const & node) { if (node.function) { std::vector arguments; for (auto const & argument : node.arguments) arguments.push_back(apply(*argument)); if (auto identifier = std::get_if(node.function.get()); identifier && identifier->function_node) { mcontext.nodes->emplace_back(call{{}, std::move(arguments)}, node.inferred_type); lcontext.resolve_call.emplace_back(last(), identifier->function_node); return last(); } auto function = apply(*node.function); mcontext.nodes->emplace_back(call_pointer{function, std::move(arguments)}, node.inferred_type); return last(); } else // if (node.type) { auto type = ast::get_type(*node.type); if (auto struct_type = std::get_if(type.get())) { mcontext.nodes->emplace_back(alloc{}, node.inferred_type); auto result = last(); for (std::size_t i = 0; i < node.arguments.size(); ++i) { auto const & field = struct_type->node->fields[i]; auto arg = apply(*node.arguments[i]); mcontext.nodes->emplace_back(assignment{result, arg, {i}}, field.inferred_type); } return result; } throw std::runtime_error("Type constructors are not implemented for non-struct types"); } } node_ref apply(ast::array const & node) { mcontext.nodes->emplace_back(alloc{}, node.inferred_type); auto array = last(); for (std::size_t i = 0; i < node.elements.size(); ++i) { auto element = apply(node.elements[i]); mcontext.nodes->emplace_back(assignment{array, element, {i}}); } return array; } node_ref apply(ast::array_access const & node) { auto object_type = ast::get_type(*node.array); auto array_type = std::get_if(object_type.get()); auto pointer_type = std::get_if(object_type.get()); if (array_type || pointer_type) { types::type_ptr element_type; if (array_type) element_type = array_type->element_type; else // if (pointer_type) element_type = pointer_type->referenced_type; auto target_pointer_type = std::make_shared(types::pointer_type{element_type, true}); node_ref base_ptr; if (array_type) { auto array = apply(*node.array); mcontext.nodes->emplace_back(cast_operation{array, target_pointer_type}, target_pointer_type); base_ptr = last(); } else // if (pointer_type) { base_ptr = apply(*node.array); } auto i64_type = std::make_shared(types::primitive_type{types::i64_type{}}); std::int64_t element_size = ast::type_size(*element_type); auto index = apply(*node.index); mcontext.nodes->emplace_back(cast_operation{index, i64_type}, i64_type); auto index_cast = last(); mcontext.nodes->emplace_back(literal{ast::i64_literal{element_size}}, i64_type); mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::multiplication, index_cast, last()}, i64_type); mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::addition, base_ptr, last()}, target_pointer_type); mcontext.nodes->emplace_back(load{last()}, node.inferred_type); return last(); } else throw std::runtime_error("Unknown array access left-hand side"); } node_ref apply(ast::field_access const & node) { auto object = apply(*node.object); auto object_type = ast::get_type(*node.object); auto struct_node = std::get_if(object_type.get())->node; for (std::size_t i = 0; i < struct_node->fields.size(); ++i) { if (struct_node->fields[i].name == node.field_name) { mcontext.nodes->emplace_back(copy{object, {i}}, node.inferred_type); return last(); } } throw std::runtime_error("Unknown field name"); } // TODO: array, array_access, field_access // Statements node_ref apply(ast::expression_ptr const & node) { return apply(*node); } std::optional apply_field_chain_assignment(ast::expression_ptr const & lhs_node, node_ref rhs, std::vector path) { if (auto identifier = std::get_if(lhs_node.get())) { auto lhs = apply(*lhs_node); mcontext.nodes->emplace_back(assignment{lhs, rhs, std::move(path)}, identifier->inferred_type); return last(); } else if (auto field_access = std::get_if(lhs_node.get())) { auto object_type = ast::get_type(*field_access->object); auto struct_node = std::get_if(object_type.get())->node; for (std::size_t i = 0; i < struct_node->fields.size(); ++i) { auto const & field = struct_node->fields[i]; if (field.name == field_access->field_name) { path.push_back(i); return apply_field_chain_assignment(field_access->object, rhs, std::move(path)); } } } return std::nullopt; } std::optional apply_get_address(ast::expression_ptr const & node) { auto result_type = std::make_shared(types::pointer_type{ast::get_type(*node), true}); if (auto identifier = std::get_if(node.get())) { auto object = apply(*node); mcontext.nodes->emplace_back(unary_operation{ast::unary_operation_type::address_of, object}, result_type); return last(); } if (auto unary_operation = std::get_if(node.get())) { if (unary_operation->type == ast::unary_operation_type::dereference) { return apply(*unary_operation->arg1); } } if (auto array_access = std::get_if(node.get())) { auto object_type = get_type(*array_access->array); auto array_type = std::get_if(object_type.get()); auto pointer_type = std::get_if(object_type.get()); if (array_type || pointer_type) { types::type_ptr element_type; if (array_type) element_type = array_type->element_type; else // if (pointer_type) element_type = pointer_type->referenced_type; auto target_pointer_type = std::make_shared(types::pointer_type{element_type, true}); node_ref base_ptr; if (array_type) { auto array = apply(*array_access->array); mcontext.nodes->emplace_back(cast_operation{array, target_pointer_type}, target_pointer_type); base_ptr = last(); } else // if (pointer_type) { base_ptr = apply(*array_access->array); } auto i64_type = std::make_shared(types::primitive_type{types::i64_type{}}); std::int64_t element_size = ast::type_size(*element_type); auto index = apply(*array_access->index); mcontext.nodes->emplace_back(cast_operation{index, i64_type}, i64_type); auto index_cast = last(); mcontext.nodes->emplace_back(literal{ast::i64_literal{element_size}}, i64_type); mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::multiplication, index_cast, last()}, i64_type); mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::addition, base_ptr, last()}, target_pointer_type); return last(); } } if (auto field_access = std::get_if(node.get())) { auto object_type = ast::get_type(*field_access->object); auto struct_node = std::get_if(object_type.get())->node; if (auto object_ptr = apply_get_address(field_access->object)) { for (std::size_t i = 0; i < struct_node->fields.size(); ++i) { auto const & field = struct_node->fields[i]; if (field.name == field_access->field_name) { mcontext.nodes->emplace_back(literal{ast::literal{ast::u64_literal{field.layout.offset}}}, std::make_shared(types::primitive_type{types::u64_type{}})); mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::addition, *object_ptr, last()}, result_type); return last(); } } throw std::runtime_error("Unknown field name"); } } return std::nullopt; } node_ref apply(ast::assignment const & node) { auto rhs = apply(*node.rhs); // Detect compound field access (like a.b.c = 1) - a sequence of field_access nodes // terminating in an identifier node. if (auto result = apply_field_chain_assignment(node.lhs, rhs, {})) return *result; // Otherwise, compile into explicit memory store if (auto lhs_ptr = apply_get_address(node.lhs)) { mcontext.nodes->emplace_back(store{*lhs_ptr, rhs}, ast::get_type(*node.rhs)); return last(); } throw std::runtime_error("Unknown assignment left-hand side"); } node_ref apply(ast::variable_declaration const & node) { auto before = last(); auto result = apply(*node.initializer); if (result == before) { // Evaluating variable initializer didn't produce any nodes // It must have been just a reference to another variable or smth like that // Introduce a copy node to prevent accidental variable coalescing mcontext.nodes->emplace_back(copy{result}); result = last(); } lcontext.variables[&node] = result; return result; } node_ref apply(ast::if_chain const & node) { std::vector jumps_to_end; for (auto const & block : node.blocks) { std::optional jump_to_next; if (block.condition) { auto condition = apply(*block.condition); mcontext.nodes->emplace_back(jump_if_zero{condition, {}}); jump_to_next = last(); } lcontext.scopes.emplace_back(); apply(*block.statements); lcontext.scopes.pop_back(); mcontext.nodes->emplace_back(jump{}); jumps_to_end.push_back(last()); mcontext.nodes->emplace_back(label{}); if (jump_to_next) std::get((*jump_to_next)->instruction).target = last(); } auto end = last(); for (auto const & jump_to_end : jumps_to_end) std::get(jump_to_end->instruction).target = end; return end; } node_ref apply(ast::while_block const & node) { mcontext.nodes->emplace_back(label{}); auto begin = last(); auto condition = apply(*node.condition); mcontext.nodes->emplace_back(jump_if_zero{condition, {}}); auto jump1 = last(); lcontext.scopes.emplace_back(); apply(*node.statements); lcontext.scopes.pop_back(); mcontext.nodes->emplace_back(jump{begin}); mcontext.nodes->emplace_back(label{}); std::get(jump1->instruction).target = last(); return last(); } node_ref apply(ast::function_definition const & node) { return last(); } node_ref apply(ast::foreign_function_declaration const & node) { return last(); } node_ref apply(ast::return_statement const & node) { if (node.value) { auto value = apply(*node.value); mcontext.nodes->emplace_back(return_value{value}, value->inferred_type); } else mcontext.nodes->emplace_back(return_value{}, std::make_shared(types::unit_type{})); return last(); } node_ref apply(ast::struct_definition const & node) { return last(); } // Statement list void do_apply(ast::function_definition const & node) { mcontext.nodes->emplace_back(label{}); auto begin = last(); lcontext.scopes.emplace_back(); for (std::size_t i = 0; i < node.arguments.size(); ++i) { mcontext.nodes->emplace_back(argument{i}, ast::get_type(*node.arguments[i].type)); lcontext.variables[&node.arguments[i]] = last(); } apply(*node.statements); if (types::equal(*ast::get_type(*node.return_type), types::unit_type{})) mcontext.nodes->emplace_back(return_value{}, std::make_shared(types::unit_type{})); auto end = last(); lcontext.functions[&node] = {begin, end}; mcontext.labels[&(*begin)] = lcontext.scopes.back().label_prefix + node.name; lcontext.scopes.pop_back(); } private: node_ref last() { return std::prev(mcontext.nodes->end()); } }; // Main compilation visitor struct compile_visitor : ast::const_statement_visitor { module_context & mcontext; local_context & lcontext; using const_statement_visitor::apply; void apply(ast::expression_ptr const &) {} void apply(ast::assignment const &) {} void apply(ast::variable_declaration const &) {} void apply(ast::if_chain const & node) { for (auto const & block : node.blocks) { lcontext.scopes.emplace_back(); apply(*block.statements); lcontext.scopes.pop_back(); } } void apply(ast::while_block const & node) { lcontext.scopes.emplace_back(); apply(*node.statements); lcontext.scopes.pop_back(); } void apply(ast::function_definition const & node) { compile_function_visitor{{}, {}, mcontext, lcontext}.do_apply(node); std::string label_prefix; if (!lcontext.scopes.empty()) label_prefix = lcontext.scopes.back().label_prefix + node.name + "."; lcontext.scopes.emplace_back(std::move(label_prefix)); apply(*node.statements); // Don't pop_back entry point scope if (lcontext.scopes.size() > 1) lcontext.scopes.pop_back(); } void apply(ast::foreign_function_declaration const &) {} void apply(ast::return_statement const &) {} void apply(ast::struct_definition const &) {} }; } void compile(module_context & mcontext, ast::statement_ptr const & root) { if (!mcontext.nodes) mcontext.nodes = std::make_shared(); local_context lcontext; mcontext.nodes->emplace_back(label{}); auto extra_label = std::prev(mcontext.nodes->end()); compile_visitor{{}, mcontext, lcontext}.apply(*root); mcontext.nodes->erase(extra_label); for (auto & symbol : lcontext.functions) symbol.second.second++; for (auto const & resolve : lcontext.resolve_address) std::get(resolve.address->instruction).target = lcontext.functions.at(resolve.target).first; for (auto const & resolve : lcontext.resolve_call) std::get(resolve.call->instruction).target = lcontext.functions.at(resolve.target).first; for (auto const & symbol : lcontext.functions) mcontext.symbols[symbol.first] = {.begin = symbol.second.first, .end = symbol.second.second}; mcontext.entry_point = lcontext.functions.at(std::get_if(root.get())).first; } }