pslang/libs/ir/source/compiler.cpp
2026-03-23 00:26:58 +03:00

394 lines
11 KiB
C++

#include <pslang/ir/compiler.hpp>
#include <pslang/ir/node.hpp>
#include <pslang/ast/statement_visitor.hpp>
#include <pslang/ast/expression_visitor.hpp>
#include <unordered_map>
namespace pslang::ir
{
namespace
{
// Bookkeeping shared by the visitors while compiling one module: maps from AST
// definitions to the IR nodes that represent them, the stack of open lexical
// scopes, and references to functions that are patched in compile() once every
// function has an entry node.
struct local_context
{
    // One entry per open lexical scope (function body, if/while block).
    struct scope
    {
        // Prepended to the name of a function defined in this scope to form its label.
        std::string label_prefix;
    };

    // An `address` node whose target function may not have been compiled yet.
    struct resolve_address_data
    {
        node_ref address;
        ast::function_definition const * target;
    };

    // A direct `call` node whose callee may not have been compiled yet.
    struct resolve_call_data
    {
        node_ref call;
        ast::function_definition const * target;
    };

    // Entry node of every compiled function.
    std::unordered_map<ast::function_definition const *, node_ref> functions;
    // NOTE(review): not populated anywhere in this file.
    std::unordered_map<ast::foreign_function_declaration const *, node_ref> foreign_functions;
    // Node holding the value of each declared variable and function argument.
    std::unordered_map<ast::variable_base const *, node_ref> variables;

    // Stack of currently open scopes (innermost at the back).
    std::vector<scope> scopes;

    // Pending fix-ups, resolved after all functions are compiled.
    std::vector<resolve_address_data> resolve_address;
    std::vector<resolve_call_data> resolve_call;
};
// Compiles the body of a single function into a linear run of IR nodes appended
// to mcontext.nodes, and stores the function's first node (its entry point) in
// lcontext.functions. Nested function definitions are not descended into; they
// are compiled separately by compile_visitor.
//
// Expression overloads return the node that holds the expression's value.
struct compile_function_visitor
    : ast::const_statement_visitor<compile_function_visitor>
    , ast::const_expression_visitor<compile_function_visitor>
{
    using const_statement_visitor::apply;
    using const_expression_visitor::apply;

    module_context & mcontext;
    local_context & lcontext;

    // Fallback for AST node kinds that have no IR lowering yet.
    template <typename Node>
    node_ref apply(Node const &)
    {
        throw std::runtime_error(std::string("IR compile visitor not implemented for ") + typeid(Node).name());
    }

    // ---------------------------------------------------------------- Expressions

    template <typename T>
    node_ref apply(ast::primitive_literal_base<T> const & node)
    {
        mcontext.nodes.emplace_back(literal{node}, ast::get_type(node));
        return last();
    }

    // An identifier evaluates to:
    //  * for a variable: the node already holding its value (nothing is emitted);
    //  * for a function: an `address` node whose target is patched by compile();
    //  * for a foreign function: an `extern_symbol` node carrying its name.
    node_ref apply(ast::identifier const & node)
    {
        if (node.variable_node)
            return lcontext.variables.at(node.variable_node);

        if (node.function_node)
        {
            mcontext.nodes.emplace_back(address{}, node.inferred_type);
            lcontext.resolve_address.emplace_back(last(), node.function_node);
            return last();
        }

        if (node.foreign_function_node)
        {
            mcontext.nodes.emplace_back(extern_symbol{node.name}, node.inferred_type);
            return last();
        }

        throw std::runtime_error("Unknown identifier \"" + node.name + "\"");
    }

    node_ref apply(ast::unary_operation const & node)
    {
        auto arg1 = apply(*node.arg1);
        mcontext.nodes.emplace_back(unary_operation{node.type, arg1}, node.inferred_type);
        return last();
    }

    // `&&` and `||` short-circuit: the right operand is only evaluated when the
    // left one does not already decide the result.
    node_ref apply(ast::binary_operation const & node)
    {
        auto arg1 = apply(*node.arg1);
        if (node.type == ast::binary_operation_type::logical_and)
        {
            // result = a; if (result == 0) goto end; result = result & b; end:
            // The result lives in a fresh copy: arg1 may be a variable's own node,
            // and assigning to it would silently overwrite that variable.
            mcontext.nodes.emplace_back(copy{arg1}, node.inferred_type);
            auto result = last();
            mcontext.nodes.emplace_back(jump_if_zero{result});
            auto jump = last();
            auto arg2 = apply(*node.arg2);
            mcontext.nodes.emplace_back(binary_operation{ast::binary_operation_type::binary_and, result, arg2}, node.inferred_type);
            mcontext.nodes.emplace_back(assignment{result, last()});
            mcontext.nodes.emplace_back(nop{});
            std::get<jump_if_zero>(jump->instruction).target = last();
            return result;
        }
        if (node.type == ast::binary_operation_type::logical_or)
        {
            // result = a; if (!result == 0) goto end; result = result | b; end:
            // Same reasoning as above for the fresh copy.
            mcontext.nodes.emplace_back(copy{arg1}, node.inferred_type);
            auto result = last();
            mcontext.nodes.emplace_back(unary_operation{ast::unary_operation_type::logical_not, result}, node.inferred_type);
            mcontext.nodes.emplace_back(jump_if_zero{last()});
            auto jump = last();
            auto arg2 = apply(*node.arg2);
            mcontext.nodes.emplace_back(binary_operation{ast::binary_operation_type::binary_or, result, arg2}, node.inferred_type);
            mcontext.nodes.emplace_back(assignment{result, last()});
            mcontext.nodes.emplace_back(nop{});
            std::get<jump_if_zero>(jump->instruction).target = last();
            return result;
        }
        auto arg2 = apply(*node.arg2);
        mcontext.nodes.emplace_back(binary_operation{node.type, arg1, arg2}, node.inferred_type);
        return last();
    }

    node_ref apply(ast::cast_operation const & node)
    {
        auto arg = apply(*node.expression);
        mcontext.nodes.emplace_back(cast_operation{arg, node.inferred_type}, node.inferred_type);
        return last();
    }

    // Arguments are evaluated left to right first. A call to a named function
    // becomes a direct `call` (target patched by compile()); any other callee is
    // evaluated afterwards and invoked through `call_pointer`.
    node_ref apply(ast::function_call const & node)
    {
        if (!node.function) // type constructor call
            throw std::runtime_error("IR compile visitor not implemented for type constructors");

        std::vector<node_ref> arguments;
        for (auto const & argument : node.arguments)
            arguments.push_back(apply(*argument));

        if (auto identifier = std::get_if<ast::identifier>(node.function.get()); identifier && identifier->function_node)
        {
            mcontext.nodes.emplace_back(call{{}, std::move(arguments)}, node.inferred_type);
            lcontext.resolve_call.emplace_back(last(), identifier->function_node);
            return last();
        }

        auto function = apply(*node.function);
        mcontext.nodes.emplace_back(call_pointer{function, std::move(arguments)}, node.inferred_type);
        return last();
    }

    // TODO: array, array_access, field_access

    // ----------------------------------------------------------------- Statements

    node_ref apply(ast::expression_ptr const & node)
    {
        return apply(*node);
    }

    // The right-hand side is evaluated before the left-hand side.
    node_ref apply(ast::assignment const & node)
    {
        auto rhs = apply(*node.rhs);
        auto lhs = apply(*node.lhs);
        mcontext.nodes.emplace_back(assignment{lhs, rhs}, ast::get_type(*node.rhs));
        return last();
    }

    // The variable is bound to the node holding its initial value.
    node_ref apply(ast::variable_declaration const & node)
    {
        auto before = last();
        auto result = apply(*node.initializer);
        // If evaluating the initializer emitted no nodes, `result` already belongs
        // to something else (e.g. another variable). Binding to it directly would
        // make both variables share one node, so introduce a copy. Comparing
        // `result` against `before` is not enough: it only catches a reference to
        // the most recently emitted node.
        if (last() == before)
        {
            mcontext.nodes.emplace_back(copy{result});
            result = last();
        }
        lcontext.variables[&node] = result;
        return result;
    }

    // Layout per block: [condition, jump_if_zero -> next] statements, jump -> end, nop (next).
    // Every `jump` to the end is patched to the final nop.
    node_ref apply(ast::if_chain const & node)
    {
        std::vector<node_ref> jumps_to_end;
        for (auto const & block : node.blocks)
        {
            std::optional<node_ref> jump_to_next;
            if (block.condition)
            {
                auto condition = apply(*block.condition);
                mcontext.nodes.emplace_back(jump_if_zero{condition, {}});
                jump_to_next = last();
            }
            lcontext.scopes.emplace_back();
            apply(*block.statements);
            lcontext.scopes.pop_back();
            mcontext.nodes.emplace_back(jump{});
            jumps_to_end.push_back(last());
            mcontext.nodes.emplace_back(nop{});
            if (jump_to_next)
                std::get<jump_if_zero>((*jump_to_next)->instruction).target = last();
        }
        auto end = last();
        for (auto const & jump_to_end : jumps_to_end)
            std::get<jump>(jump_to_end->instruction).target = end;
        return end;
    }

    // Layout: condition, jump_if_zero -> exit, body, jump -> first condition node, nop (exit).
    node_ref apply(ast::while_block const & node)
    {
        auto before = last();
        auto condition = apply(*node.condition);
        mcontext.nodes.emplace_back(jump_if_zero{condition, {}});
        auto jump1 = last();
        lcontext.scopes.emplace_back();
        apply(*node.statements);
        lcontext.scopes.pop_back();
        mcontext.nodes.emplace_back(jump{std::next(before)});
        mcontext.nodes.emplace_back(nop{});
        std::get<jump_if_zero>(jump1->instruction).target = last();
        return last();
    }

    // Nested functions are compiled by compile_visitor; nothing to emit here.
    node_ref apply(ast::function_definition const &)
    {
        return last();
    }

    node_ref apply(ast::foreign_function_declaration const &)
    {
        return last();
    }

    node_ref apply(ast::return_statement const & node)
    {
        if (node.value)
        {
            auto value = apply(*node.value);
            mcontext.nodes.emplace_back(return_value{value}, value->inferred_type);
        }
        else
            mcontext.nodes.emplace_back(return_value{}, std::make_shared<types::type>(types::unit_type{}));
        return last();
    }

    node_ref apply(ast::struct_definition const &)
    {
        return last();
    }

    // ------------------------------------------------------------- Function entry

    // Emits one `argument` node per parameter, then the body, then an implicit
    // `return` for unit-returning functions. The first emitted node is recorded
    // as the entry point and labelled with the enclosing scope's prefix + name.
    void do_apply(ast::function_definition const & node)
    {
        auto before = last();
        lcontext.scopes.emplace_back();
        for (std::size_t i = 0; i < node.arguments.size(); ++i)
        {
            mcontext.nodes.emplace_back(argument{i}, ast::get_type(*node.arguments[i].type));
            lcontext.variables[&node.arguments[i]] = last();
        }
        apply(*node.statements);
        if (types::equal(*ast::get_type(*node.return_type), types::unit_type{}))
            mcontext.nodes.emplace_back(return_value{}, std::make_shared<types::type>(types::unit_type{}));
        lcontext.scopes.pop_back();

        auto entry = std::next(before);
        lcontext.functions[&node] = entry;
        // The entry-point function is compiled with no enclosing scope; calling
        // back() on the empty vector would be undefined behavior.
        std::string const prefix = lcontext.scopes.empty() ? std::string{} : lcontext.scopes.back().label_prefix;
        mcontext.labels[&(*entry)] = prefix + node.name;
    }

private:
    node_ref last()
    {
        return std::prev(mcontext.nodes.end());
    }
};
// Main compilation visitor: walks the statement tree looking for function
// definitions. Each one is handed to compile_function_visitor, which emits its
// body. Its own statements are then searched for nested definitions, so
// functions are laid out one after another in mcontext.nodes.
struct compile_visitor
    : ast::const_statement_visitor<compile_visitor>
{
    module_context & mcontext;
    local_context & lcontext;
    using const_statement_visitor::apply;

    // These statements cannot contain function definitions.
    void apply(ast::expression_ptr const &) {}
    void apply(ast::assignment const &) {}
    void apply(ast::variable_declaration const &) {}

    // Block scopes inherit the enclosing label prefix, so a function defined in
    // an `if`/`while` inside `f` is still labelled `f.<name>` rather than `<name>`.
    void apply(ast::if_chain const & node)
    {
        for (auto const & block : node.blocks)
        {
            lcontext.scopes.push_back(local_context::scope{current_label_prefix()});
            apply(*block.statements);
            lcontext.scopes.pop_back();
        }
    }

    void apply(ast::while_block const & node)
    {
        lcontext.scopes.push_back(local_context::scope{current_label_prefix()});
        apply(*node.statements);
        lcontext.scopes.pop_back();
    }

    void apply(ast::function_definition const & node)
    {
        compile_function_visitor{{}, {}, mcontext, lcontext}.do_apply(node);
        // Functions nested in this one get the prefix "<enclosing prefix><name>.".
        // The outermost (entry point) function contributes no prefix.
        std::string label_prefix;
        if (!lcontext.scopes.empty())
            label_prefix = lcontext.scopes.back().label_prefix + node.name + ".";
        lcontext.scopes.emplace_back(std::move(label_prefix));
        apply(*node.statements);
        // Don't pop_back entry point scope
        if (lcontext.scopes.size() > 1)
            lcontext.scopes.pop_back();
    }

    void apply(ast::foreign_function_declaration const &) {}
    void apply(ast::return_statement const &) {}
    void apply(ast::struct_definition const &) {}

private:
    // Label prefix of the innermost open scope, or "" when no scope is open.
    std::string current_label_prefix() const
    {
        return lcontext.scopes.empty() ? std::string{} : lcontext.scopes.back().label_prefix;
    }
};
}
// Compiles a whole program into mcontext.
//
// The top-level statements are wrapped in a synthetic, nameless, unit-returning
// function that becomes mcontext.entry_point. After all functions are emitted,
// forward references (address/call targets) are patched and the user-defined
// functions are published in mcontext.symbols.
void compile(module_context & mcontext, ast::statement_list_ptr const & statements)
{
    // Add a fake AST node for the entry point
    auto root = std::make_shared<ast::statement>(ast::function_definition{
        {
            {},
            {},
            std::make_shared<ast::type>(types::unit_type{}),
            {},
        },
        statements,
        {},
    });
    auto const * root_definition = std::get_if<ast::function_definition>(root.get());

    // Sentinel so that "the node before the first emitted one" always exists;
    // it is removed once compilation is done.
    mcontext.nodes.emplace_back(nop{});
    auto extra_nop = std::prev(mcontext.nodes.end());

    local_context lcontext;
    compile_visitor{{}, mcontext, lcontext}.apply(*root);
    mcontext.labels[&(*std::next(extra_nop))] = "[entry point]";
    mcontext.nodes.erase(extra_nop);

    for (auto const & resolve : lcontext.resolve_address)
        std::get<address>(resolve.address->instruction).target = lcontext.functions.at(resolve.target);
    for (auto const & resolve : lcontext.resolve_call)
        std::get<call>(resolve.call->instruction).target = lcontext.functions.at(resolve.target);

    // `root` is destroyed when this function returns, so its address must not
    // escape as a key into mcontext.symbols; the entry point is published via
    // mcontext.entry_point instead.
    for (auto const & [definition, entry] : lcontext.functions)
        if (definition != root_definition)
            mcontext.symbols[definition] = entry;

    mcontext.entry_point = lcontext.functions.at(root_definition);
}
}