Refactor IR compiler: get rid of useless scoping

This commit is contained in:
Nikita Lisitsa 2026-03-23 00:11:33 +03:00
parent ebc19fad20
commit e084b48fd3
2 changed files with 44 additions and 106 deletions

View file

@ -8,23 +8,23 @@
namespace pslang::ast
{
struct function_definition;
struct function_definition;
}
namespace pslang::ir
{
struct module_context
{
node_list nodes;
struct module_context
{
node_list nodes;
std::unordered_map<node const *, std::string> labels;
std::unordered_map<node const *, std::string> labels;
std::unordered_map<ast::function_definition const *, node_ref> symbols;
node_ref entry_point;
};
std::unordered_map<ast::function_definition const *, node_ref> symbols;
node_ref entry_point;
};
void compile(module_context & context, ast::statement_list_ptr const & statements);
void compile(module_context & context, ast::statement_list_ptr const & statements);
}

View file

@ -11,18 +11,16 @@ namespace pslang::ir
namespace
{
struct local_context
{
std::unordered_map<ast::function_definition const *, node_ref> functions;
std::unordered_map<ast::foreign_function_declaration const *, node_ref> foreign_functions;
std::unordered_map<ast::variable_base const *, node_ref> variables;
struct scope
{
std::string label_prefix;
std::unordered_map<std::string, node_ref> variables;
std::unordered_set<std::string> foreign_functions;
std::unordered_map<std::string, ast::function_definition const *> functions;
std::unordered_map<std::string, ast::struct_definition const *> structs;
};
std::vector<scope> scopes;
@ -43,43 +41,6 @@ namespace pslang::ir
std::vector<resolve_call_data> resolve_call;
};
// Iterate over a single scope (i.e. not visiting subscopes recursively)
// and add all defined functions & foreign functions to the current scope
struct populate_symbols_visitor
: ast::const_statement_visitor<populate_symbols_visitor>
{
local_context & lcontext;
using const_statement_visitor::apply;
void apply(ast::expression_ptr const &) {}
void apply(ast::assignment const &) {}
void apply(ast::variable_declaration const &) {}
void apply(ast::if_chain const &) {}
void apply(ast::while_block const &) {}
void apply(ast::function_definition const & node)
{
lcontext.scopes.back().functions[node.name] = &node;
}
void apply(ast::foreign_function_declaration const & node)
{
lcontext.scopes.back().foreign_functions.insert(node.name);
}
void apply(ast::return_statement const &) {}
void apply(ast::struct_definition const & node)
{
lcontext.scopes.back().structs[node.name] = &node;
}
};
// Compile a single function and store the entry point node_ref
// in local_context
struct compile_function_visitor
@ -109,26 +70,23 @@ namespace pslang::ir
node_ref apply(ast::identifier const & node)
{
for (auto it = lcontext.scopes.rbegin(); it != lcontext.scopes.rend(); ++it)
if (node.variable_node)
{
if (auto jt = it->variables.find(node.name); jt != it->variables.end())
return jt->second;
if (auto jt = it->foreign_functions.find(node.name); jt != it->foreign_functions.end())
{
mcontext.nodes.emplace_back(extern_symbol{node.name}, node.inferred_type);
return last();
}
if (auto jt = it->functions.find(node.name); jt != it->functions.end())
{
mcontext.nodes.emplace_back(address{}, node.inferred_type);
lcontext.resolve_address.emplace_back(last(), jt->second);
return last();
}
return lcontext.variables.at(node.variable_node);
}
throw std::runtime_error("Unknown identifier \"" + node.name + "\"");
else if (node.function_node)
{
mcontext.nodes.emplace_back(address{}, node.inferred_type);
lcontext.resolve_address.emplace_back(last(), node.function_node);
return last();
}
else if (node.foreign_function_node)
{
mcontext.nodes.emplace_back(extern_symbol{node.name}, node.inferred_type);
return last();
}
else
throw std::runtime_error("Unknown identifier \"" + node.name + "\"");
}
node_ref apply(ast::unary_operation const & node)
@ -188,17 +146,11 @@ namespace pslang::ir
for (auto const & argument : node.arguments)
arguments.push_back(apply(*argument));
if (auto identifier = std::get_if<ast::identifier>(node.function.get()))
if (auto identifier = std::get_if<ast::identifier>(node.function.get()); identifier && identifier->function_node)
{
for (auto it = lcontext.scopes.rbegin(); it != lcontext.scopes.rend(); ++it)
{
if (auto jt = it->functions.find(identifier->name); jt != it->functions.end())
{
mcontext.nodes.emplace_back(call{{}, std::move(arguments)}, node.inferred_type);
lcontext.resolve_call.emplace_back(last(), jt->second);
return last();
}
}
mcontext.nodes.emplace_back(call{{}, std::move(arguments)}, node.inferred_type);
lcontext.resolve_call.emplace_back(last(), identifier->function_node);
return last();
}
auto function = apply(*node.function);
@ -240,7 +192,7 @@ namespace pslang::ir
mcontext.nodes.emplace_back(copy{result});
result = last();
}
lcontext.scopes.back().variables[node.name] = result;
lcontext.variables[&node] = result;
return result;
}
@ -320,13 +272,6 @@ namespace pslang::ir
// Statement list
void apply(ast::statement_list const & list)
{
populate_symbols_visitor{{}, lcontext}.apply(list);
for (auto const & statement : list.statements)
apply(*statement);
}
void do_apply(ast::function_definition const & node)
{
auto before = last();
@ -335,7 +280,7 @@ namespace pslang::ir
for (std::size_t i = 0; i < node.arguments.size(); ++i)
{
mcontext.nodes.emplace_back(argument{i}, ast::get_type(*node.arguments[i].type));
lcontext.scopes.back().variables[node.arguments[i].name] = last();
lcontext.variables[&node.arguments[i]] = last();
}
apply(*node.statements);
@ -407,13 +352,6 @@ namespace pslang::ir
void apply(ast::return_statement const &) {}
void apply(ast::struct_definition const &) {}
void apply(ast::statement_list const & list)
{
populate_symbols_visitor{{}, lcontext}.apply(list);
for (auto const & statement : list.statements)
apply(*statement);
}
};
}
@ -422,15 +360,15 @@ namespace pslang::ir
{
// Add a fake AST node for the entry point
auto root = std::make_shared<ast::statement>(ast::function_definition{
{
{},
{},
std::make_shared<ast::type>(types::unit_type{}),
{},
},
statements,
{
{},
});
{},
std::make_shared<ast::type>(types::unit_type{}),
{},
},
statements,
{},
});
mcontext.nodes.emplace_back(nop{});
auto extra_nop = std::prev(mcontext.nodes.end());
@ -445,8 +383,8 @@ namespace pslang::ir
for (auto const & resolve : lcontext.resolve_call)
std::get<call>(resolve.call->instruction).target = lcontext.functions.at(resolve.target);
for (auto const & symbol : lcontext.scopes.front().functions)
mcontext.symbols[symbol.second] = lcontext.functions.at(symbol.second);
for (auto const & symbol : lcontext.functions)
mcontext.symbols[symbol.first] = symbol.second;
mcontext.entry_point = lcontext.functions.at(std::get_if<ast::function_definition>(root.get()));
}