#include #include #include #include #include #include namespace pslang::ir { namespace { struct local_context { std::unordered_map functions; struct scope { std::string label_prefix; std::unordered_map variables; std::unordered_set foreign_functions; std::unordered_map functions; std::unordered_map structs; }; std::vector scopes; struct resolve_address_data { node_ref address; ast::function_definition const * target; }; struct resolve_call_data { node_ref call; ast::function_definition const * target; }; std::vector resolve_address; std::vector resolve_call; }; // Iterate over a single scope (i.e. not visiting subscopes recursively) // and add all defined functions & foreign functions to the current scope struct populate_symbols_visitor : ast::const_statement_visitor { local_context & lcontext; using const_statement_visitor::apply; void apply(ast::expression_ptr const &) {} void apply(ast::assignment const &) {} void apply(ast::variable_declaration const &) {} void apply(ast::if_block const &) {} void apply(ast::else_block const &) {} void apply(ast::else_if_block const &) {} void apply(ast::if_chain const &) {} void apply(ast::while_block const &) {} void apply(ast::function_definition const & node) { lcontext.scopes.back().functions[node.name] = &node; } void apply(ast::foreign_function_declaration const & node) { lcontext.scopes.back().foreign_functions.insert(node.name); } void apply(ast::return_statement const &) {} void apply(ast::field_definition const &) {} void apply(ast::struct_definition const & node) { lcontext.scopes.back().structs[node.name] = &node; } }; // Compile a single function and store the entry point node_ref // in local_context struct compile_function_visitor : ast::const_statement_visitor , ast::const_expression_visitor { using const_statement_visitor::apply; using const_expression_visitor::apply; module_context & mcontext; local_context & lcontext; template node_ref apply(Node const & node) { throw std::runtime_error(std::string("IR compile visitor not implemented for ") + typeid(Node).name()); } // Expressions template node_ref apply(ast::primitive_literal_base const & node) { mcontext.nodes.emplace_back(literal{node}, ast::get_type(node)); return last(); } node_ref apply(ast::identifier const & node) { for (auto it = lcontext.scopes.rbegin(); it != lcontext.scopes.rend(); ++it) { if (auto jt = it->variables.find(node.name); jt != it->variables.end()) return jt->second; if (auto jt = it->foreign_functions.find(node.name); jt != it->foreign_functions.end()) { mcontext.nodes.emplace_back(extern_symbol{node.name}, node.inferred_type); return last(); } if (auto jt = it->functions.find(node.name); jt != it->functions.end()) { mcontext.nodes.emplace_back(address{}, node.inferred_type); lcontext.resolve_address.emplace_back(last(), jt->second); return last(); } } throw std::runtime_error("Unknown identifier \"" + node.name + "\""); } node_ref apply(ast::unary_operation const & node) { auto arg1 = apply(*node.arg1); mcontext.nodes.emplace_back(unary_operation{node.type, arg1}, node.inferred_type); return last(); } node_ref apply(ast::binary_operation const & node) { auto arg1 = apply(*node.arg1); if (node.type == ast::binary_operation_type::logical_and) { mcontext.nodes.emplace_back(jump_if_zero{arg1}); auto jump = last(); auto arg2 = apply(*node.arg2); mcontext.nodes.emplace_back(binary_operation{ast::binary_operation_type::binary_and, arg1, arg2}, node.inferred_type); mcontext.nodes.emplace_back(assignment{arg1, last()}); mcontext.nodes.emplace_back(nop{}); std::get(jump->instruction).target = last(); return arg1; } if (node.type == ast::binary_operation_type::logical_or) { mcontext.nodes.emplace_back(unary_operation{ast::unary_operation_type::logical_not, arg1}, node.inferred_type); mcontext.nodes.emplace_back(jump_if_zero{last()}); auto jump = last(); auto arg2 = apply(*node.arg2); mcontext.nodes.emplace_back(binary_operation{ast::binary_operation_type::binary_or, arg1, arg2}, node.inferred_type); mcontext.nodes.emplace_back(assignment{arg1, last()}); mcontext.nodes.emplace_back(nop{}); std::get(jump->instruction).target = last(); return arg1; } auto arg2 = apply(*node.arg2); mcontext.nodes.emplace_back(binary_operation{node.type, arg1, arg2}, node.inferred_type); return last(); } node_ref apply(ast::cast_operation const & node) { auto arg = apply(*node.expression); mcontext.nodes.emplace_back(cast_operation{arg, node.inferred_type}, node.inferred_type); return last(); } node_ref apply(ast::function_call const & node) { if (node.function) { std::vector arguments; for (auto const & argument : node.arguments) arguments.push_back(apply(*argument)); if (auto identifier = std::get_if(node.function.get())) { for (auto it = lcontext.scopes.rbegin(); it != lcontext.scopes.rend(); ++it) { if (auto jt = it->functions.find(identifier->name); jt != it->functions.end()) { mcontext.nodes.emplace_back(call{{}, std::move(arguments)}, node.inferred_type); lcontext.resolve_call.emplace_back(last(), jt->second); return last(); } } } auto function = apply(*node.function); mcontext.nodes.emplace_back(call_pointer{function, std::move(arguments)}, node.inferred_type); return last(); } else // if (node.type) { throw std::runtime_error("IR compile visitor not implemented for type constructors"); } } // TODO: array, array_access, field_access // Statements node_ref apply(ast::expression_ptr const & node) { apply(*node); return last(); } node_ref apply(ast::assignment const & node) { auto rhs = apply(*node.rhs); auto lhs = apply(*node.lhs); mcontext.nodes.emplace_back(assignment{lhs, rhs}, ast::get_type(*node.rhs)); return last(); } node_ref apply(ast::variable_declaration const & node) { apply(*node.initializer); lcontext.scopes.back().variables[node.name] = last(); return last(); } node_ref apply(ast::if_chain const & node) { std::vector jumps_to_end; for (auto const & block : node.blocks) { std::optional jump_to_next; if (block.condition) { auto condition = apply(*block.condition); mcontext.nodes.emplace_back(jump_if_zero{condition, {}}); jump_to_next = last(); } lcontext.scopes.emplace_back(); apply(*block.statements); lcontext.scopes.pop_back(); mcontext.nodes.emplace_back(jump{}); jumps_to_end.push_back(last()); mcontext.nodes.emplace_back(nop{}); if (jump_to_next) std::get((*jump_to_next)->instruction).target = last(); } auto end = last(); for (auto const & jump_to_end : jumps_to_end) std::get(jump_to_end->instruction).target = end; return end; } node_ref apply(ast::while_block const & node) { auto before = last(); auto condition = apply(*node.condition); mcontext.nodes.emplace_back(jump_if_zero{condition, {}}); auto jump1 = last(); lcontext.scopes.emplace_back(); apply(*node.statements); lcontext.scopes.pop_back(); mcontext.nodes.emplace_back(jump{std::next(before)}); mcontext.nodes.emplace_back(nop{}); std::get(jump1->instruction).target = last(); return last(); } node_ref apply(ast::function_definition const & node) { return last(); } node_ref apply(ast::foreign_function_declaration const & node) { return last(); } node_ref apply(ast::return_statement const & node) { if (node.value) { auto value = apply(*node.value); mcontext.nodes.emplace_back(return_value{value}, value->inferred_type); } else mcontext.nodes.emplace_back(return_value{}, std::make_shared(types::unit_type{})); return last(); } // TODO: struct_definition, field_definition // Statement list void apply(ast::statement_list const & list) { populate_symbols_visitor{{}, lcontext}.apply(list); for (auto const & statement : list.statements) apply(*statement); } void do_apply(ast::function_definition const & node) { auto before = last(); lcontext.scopes.emplace_back(); for (std::size_t i = 0; i < node.arguments.size(); ++i) { mcontext.nodes.emplace_back(argument{i}, ast::get_type(*node.arguments[i].type)); lcontext.scopes.back().variables[node.arguments[i].name] = last(); } apply(*node.statements); if (types::equal(*ast::get_type(*node.return_type), types::unit_type{})) mcontext.nodes.emplace_back(return_value{}, std::make_shared(types::unit_type{})); lcontext.scopes.pop_back(); auto entry = std::next(before); lcontext.functions[&node] = entry; mcontext.labels[&(*entry)] = lcontext.scopes.back().label_prefix + node.name; } private: node_ref last() { return std::prev(mcontext.nodes.end()); } }; // Main compilation visitor struct compile_visitor : ast::const_statement_visitor { module_context & mcontext; local_context & lcontext; using const_statement_visitor::apply; void apply(ast::expression_ptr const &) {} void apply(ast::assignment const &) {} void apply(ast::variable_declaration const &) {} void apply(ast::if_block const &) {} void apply(ast::else_block const &) {} void apply(ast::else_if_block const &) {} void apply(ast::if_chain const & node) { for (auto const & block : node.blocks) { lcontext.scopes.emplace_back(); apply(*block.statements); lcontext.scopes.pop_back(); } } void apply(ast::while_block const & node) { lcontext.scopes.emplace_back(); apply(*node.statements); lcontext.scopes.pop_back(); } void apply(ast::function_definition const & node) { compile_function_visitor{{}, {}, mcontext, lcontext}.do_apply(node); std::string label_prefix; if (!lcontext.scopes.empty()) label_prefix = lcontext.scopes.back().label_prefix + node.name + "."; lcontext.scopes.emplace_back(std::move(label_prefix)); apply(*node.statements); // Don't pop_back entry point scope if (lcontext.scopes.size() > 1) lcontext.scopes.pop_back(); } void apply(ast::foreign_function_declaration const &) {} void apply(ast::return_statement const &) {} void apply(ast::field_definition const &) {} void apply(ast::struct_definition const &) {} void apply(ast::statement_list const & list) { populate_symbols_visitor{{}, lcontext}.apply(list); for (auto const & statement : list.statements) apply(*statement); } }; } void compile(module_context & mcontext, ast::statement_list_ptr const & statements) { // Add a fake AST node for the entry point auto root = std::make_shared(ast::function_definition{ { {}, {}, std::make_shared(types::unit_type{}), {}, }, statements, {}, }); mcontext.nodes.emplace_back(nop{}); auto extra_nop = std::prev(mcontext.nodes.end()); local_context lcontext; compile_visitor{{}, mcontext, lcontext}.apply(*root); mcontext.labels[&(*std::next(extra_nop))] = "[entry point]"; mcontext.nodes.erase(extra_nop); for (auto const & resolve : lcontext.resolve_address) std::get
(resolve.address->instruction).target = lcontext.functions.at(resolve.target); for (auto const & resolve : lcontext.resolve_call) std::get(resolve.call->instruction).target = lcontext.functions.at(resolve.target); for (auto const & symbol : lcontext.scopes.front().functions) mcontext.symbols[symbol.second] = lcontext.functions.at(symbol.second); mcontext.entry_point = lcontext.functions.at(std::get_if(root.get())); } }