Refactor IR compiler: get rid of useless scoping

This commit is contained in:
Nikita Lisitsa 2026-03-23 00:11:33 +03:00
parent ebc19fad20
commit e084b48fd3
2 changed files with 44 additions and 106 deletions

View file

@ -11,18 +11,16 @@ namespace pslang::ir
namespace namespace
{ {
struct local_context struct local_context
{ {
std::unordered_map<ast::function_definition const *, node_ref> functions; std::unordered_map<ast::function_definition const *, node_ref> functions;
std::unordered_map<ast::foreign_function_declaration const *, node_ref> foreign_functions;
std::unordered_map<ast::variable_base const *, node_ref> variables;
struct scope struct scope
{ {
std::string label_prefix; std::string label_prefix;
std::unordered_map<std::string, node_ref> variables;
std::unordered_set<std::string> foreign_functions;
std::unordered_map<std::string, ast::function_definition const *> functions;
std::unordered_map<std::string, ast::struct_definition const *> structs;
}; };
std::vector<scope> scopes; std::vector<scope> scopes;
@ -43,43 +41,6 @@ namespace pslang::ir
std::vector<resolve_call_data> resolve_call; std::vector<resolve_call_data> resolve_call;
}; };
// Iterate over a single scope (i.e. not visiting subscopes recursively)
// and add all defined functions & foreign functions to the current scope
struct populate_symbols_visitor
: ast::const_statement_visitor<populate_symbols_visitor>
{
local_context & lcontext;
using const_statement_visitor::apply;
void apply(ast::expression_ptr const &) {}
void apply(ast::assignment const &) {}
void apply(ast::variable_declaration const &) {}
void apply(ast::if_chain const &) {}
void apply(ast::while_block const &) {}
void apply(ast::function_definition const & node)
{
lcontext.scopes.back().functions[node.name] = &node;
}
void apply(ast::foreign_function_declaration const & node)
{
lcontext.scopes.back().foreign_functions.insert(node.name);
}
void apply(ast::return_statement const &) {}
void apply(ast::struct_definition const & node)
{
lcontext.scopes.back().structs[node.name] = &node;
}
};
// Compile a single function and store the entry point node_ref // Compile a single function and store the entry point node_ref
// in local_context // in local_context
struct compile_function_visitor struct compile_function_visitor
@ -109,25 +70,22 @@ namespace pslang::ir
node_ref apply(ast::identifier const & node) node_ref apply(ast::identifier const & node)
{ {
for (auto it = lcontext.scopes.rbegin(); it != lcontext.scopes.rend(); ++it) if (node.variable_node)
{ {
if (auto jt = it->variables.find(node.name); jt != it->variables.end()) return lcontext.variables.at(node.variable_node);
return jt->second; }
else if (node.function_node)
if (auto jt = it->foreign_functions.find(node.name); jt != it->foreign_functions.end()) {
mcontext.nodes.emplace_back(address{}, node.inferred_type);
lcontext.resolve_address.emplace_back(last(), node.function_node);
return last();
}
else if (node.foreign_function_node)
{ {
mcontext.nodes.emplace_back(extern_symbol{node.name}, node.inferred_type); mcontext.nodes.emplace_back(extern_symbol{node.name}, node.inferred_type);
return last(); return last();
} }
else
if (auto jt = it->functions.find(node.name); jt != it->functions.end())
{
mcontext.nodes.emplace_back(address{}, node.inferred_type);
lcontext.resolve_address.emplace_back(last(), jt->second);
return last();
}
}
throw std::runtime_error("Unknown identifier \"" + node.name + "\""); throw std::runtime_error("Unknown identifier \"" + node.name + "\"");
} }
@ -188,18 +146,12 @@ namespace pslang::ir
for (auto const & argument : node.arguments) for (auto const & argument : node.arguments)
arguments.push_back(apply(*argument)); arguments.push_back(apply(*argument));
if (auto identifier = std::get_if<ast::identifier>(node.function.get())) if (auto identifier = std::get_if<ast::identifier>(node.function.get()); identifier && identifier->function_node)
{
for (auto it = lcontext.scopes.rbegin(); it != lcontext.scopes.rend(); ++it)
{
if (auto jt = it->functions.find(identifier->name); jt != it->functions.end())
{ {
mcontext.nodes.emplace_back(call{{}, std::move(arguments)}, node.inferred_type); mcontext.nodes.emplace_back(call{{}, std::move(arguments)}, node.inferred_type);
lcontext.resolve_call.emplace_back(last(), jt->second); lcontext.resolve_call.emplace_back(last(), identifier->function_node);
return last(); return last();
} }
}
}
auto function = apply(*node.function); auto function = apply(*node.function);
mcontext.nodes.emplace_back(call_pointer{function, std::move(arguments)}, node.inferred_type); mcontext.nodes.emplace_back(call_pointer{function, std::move(arguments)}, node.inferred_type);
@ -240,7 +192,7 @@ namespace pslang::ir
mcontext.nodes.emplace_back(copy{result}); mcontext.nodes.emplace_back(copy{result});
result = last(); result = last();
} }
lcontext.scopes.back().variables[node.name] = result; lcontext.variables[&node] = result;
return result; return result;
} }
@ -320,13 +272,6 @@ namespace pslang::ir
// Statement list // Statement list
void apply(ast::statement_list const & list)
{
populate_symbols_visitor{{}, lcontext}.apply(list);
for (auto const & statement : list.statements)
apply(*statement);
}
void do_apply(ast::function_definition const & node) void do_apply(ast::function_definition const & node)
{ {
auto before = last(); auto before = last();
@ -335,7 +280,7 @@ namespace pslang::ir
for (std::size_t i = 0; i < node.arguments.size(); ++i) for (std::size_t i = 0; i < node.arguments.size(); ++i)
{ {
mcontext.nodes.emplace_back(argument{i}, ast::get_type(*node.arguments[i].type)); mcontext.nodes.emplace_back(argument{i}, ast::get_type(*node.arguments[i].type));
lcontext.scopes.back().variables[node.arguments[i].name] = last(); lcontext.variables[&node.arguments[i]] = last();
} }
apply(*node.statements); apply(*node.statements);
@ -407,13 +352,6 @@ namespace pslang::ir
void apply(ast::return_statement const &) {} void apply(ast::return_statement const &) {}
void apply(ast::struct_definition const &) {} void apply(ast::struct_definition const &) {}
void apply(ast::statement_list const & list)
{
populate_symbols_visitor{{}, lcontext}.apply(list);
for (auto const & statement : list.statements)
apply(*statement);
}
}; };
} }
@ -445,8 +383,8 @@ namespace pslang::ir
for (auto const & resolve : lcontext.resolve_call) for (auto const & resolve : lcontext.resolve_call)
std::get<call>(resolve.call->instruction).target = lcontext.functions.at(resolve.target); std::get<call>(resolve.call->instruction).target = lcontext.functions.at(resolve.target);
for (auto const & symbol : lcontext.scopes.front().functions) for (auto const & symbol : lcontext.functions)
mcontext.symbols[symbol.second] = lcontext.functions.at(symbol.second); mcontext.symbols[symbol.first] = symbol.second;
mcontext.entry_point = lcontext.functions.at(std::get_if<ast::function_definition>(root.get())); mcontext.entry_point = lcontext.functions.at(std::get_if<ast::function_definition>(root.get()));
} }