pslang/libs/ast/source/resolve_identifiers.cpp
2025-12-20 13:20:44 +03:00

271 lines
7 KiB
C++

#include <pslang/ast/preprocess.hpp>
#include <pslang/ast/statement.hpp>
#include <pslang/ast/type_visitor.hpp>
#include <pslang/ast/expression_visitor.hpp>
#include <pslang/ast/statement_visitor.hpp>
#include <pslang/ast/error.hpp>
#include <unordered_set>
#include <vector>
namespace pslang::ast
{
namespace
{
/// The set of names declared directly inside a single lexical scope,
/// grouped by the kind of entity each name refers to.
struct scope
{
    /// Names of functions declared in this scope.
    std::unordered_set<std::string> functions;
    /// Names of structs declared in this scope.
    std::unordered_set<std::string> structs;
    /// Names of variables declared in this scope.
    std::unordered_set<std::string> variables;

    /// Marks a scope opened for a function body; identifier lookup treats
    /// it as a boundary when walking outwards.
    bool is_function_scope = false;
    /// Marks the outermost scope.
    bool is_global_scope = false;

    /// True if `name` is a struct declared in this scope.
    bool contains_type(std::string const & name) const
    {
        return structs.find(name) != structs.end();
    }

    /// True if `name` is a function or a struct declared in this scope.
    /// Variables are deliberately ignored by this query.
    bool contains_transitive(std::string const & name) const
    {
        if (functions.find(name) != functions.end())
            return true;
        return contains_type(name);
    }

    /// True if `name` is declared in this scope as a function, a struct
    /// or a variable.
    bool contains(std::string const & name) const
    {
        if (contains_transitive(name))
            return true;
        return variables.find(name) != variables.end();
    }

    /// Lookup that depends on whether the search has already crossed a
    /// function boundary: if it has, only functions and structs are
    /// considered; otherwise variables are considered as well.
    bool contains(std::string const & name, bool crossed_function_scope) const
    {
        return crossed_function_scope ? contains_transitive(name) : contains(name);
    }
};
/// Visitor over types, expressions and statements that checks every name
/// against the stack of open scopes. For identifiers and type identifiers
/// it records in `level` the index (within `scopes`) of the scope that
/// declares the name; declarations register their name in the innermost
/// scope. Unknown or duplicate names are reported via `parse_error`.
struct resolve_identifiers_visitor
{
    /// Stack of currently open scopes; `back()` is the innermost one.
    std::vector<scope> scopes;

    // ------------------------------------------------------------- types

    void operator()(types::unit_type const &)
    {}

    void operator()(types::primitive_type const &)
    {}

    void operator()(array_type const & node)
    {
        apply(*this, *node.element_type);
    }

    void operator()(function_type const & node)
    {
        for (auto const & argument_type : node.arguments)
            apply(*this, *argument_type);
        apply(*this, *node.result);
    }

    /// Finds the innermost scope declaring a struct with this name and
    /// stores that scope's index in `level`; throws if there is none.
    void operator()(type_identifier & node)
    {
        for (auto current = scopes.rbegin(); current != scopes.rend(); ++current)
        {
            if (!current->contains_type(node.name))
                continue;

            // base() points one past the element the reverse iterator
            // designates, hence the extra -1 to obtain its index.
            node.level = current.base() - scopes.begin() - 1;
            return;
        }

        throw parse_error("Identifier \"" + node.name + "\" not found", node.location);
    }

    // ------------------------------------------------------- expressions

    void operator()(literal const &)
    {}

    /// Finds the innermost scope in which the name is visible and stores
    /// that scope's index in `level`; throws if there is none.
    ///
    /// Once the search has walked past a function scope, variables of the
    /// remaining non-global scopes are ignored (only their functions and
    /// structs are considered); the global scope is always searched in full.
    void operator()(identifier & node)
    {
        bool left_function = false;
        for (auto current = scopes.rbegin(); current != scopes.rend(); ++current)
        {
            bool const skip_variables = left_function && !current->is_global_scope;
            if (current->contains(node.name, skip_variables))
            {
                node.level = current.base() - scopes.begin() - 1;
                return;
            }

            // The function scope itself is searched in full; only scopes
            // further out are restricted.
            if (current->is_function_scope)
                left_function = true;
        }

        throw parse_error("Identifier \"" + node.name + "\" not found", node.location);
    }

    void operator()(unary_operation const & node)
    {
        apply(*this, *node.arg1);
    }

    void operator()(binary_operation const & node)
    {
        apply(*this, *node.arg1);
        apply(*this, *node.arg2);
    }

    void operator()(cast_operation const & node)
    {
        apply(*this, *node.expression);
        apply(*this, *node.type);
    }

    void operator()(function_call const & node)
    {
        apply(*this, *node.function);
        for (auto const & argument : node.arguments)
            apply(*this, *argument);
    }

    void operator()(array const & node)
    {
        for (auto const & element : node.elements)
            apply(*this, *element);
    }

    void operator()(array_access const & node)
    {
        apply(*this, *node.array);
        apply(*this, *node.index);
    }

    /// Only the accessed object is resolved here; the field name itself
    /// is not looked up in any scope.
    void operator()(field_access const & node)
    {
        apply(*this, *node.object);
    }

    // -------------------------------------------------------- statements

    void operator()(expression_ptr const & ptr)
    {
        apply(*this, *ptr);
    }

    void operator()(assignment const & node)
    {
        ast::apply(*this, node.lhs);
        ast::apply(*this, node.rhs);
    }

    /// The variable's name is registered only after its type and
    /// initializer have been resolved, so the initializer cannot refer
    /// to the variable being declared.
    void operator()(variable_declaration const & node)
    {
        require_undeclared(node.name, node.location);

        if (node.type)
            apply(*this, *node.type);
        apply(*this, *node.initializer);

        scopes.back().variables.insert(node.name);
    }

    // Standalone if / else-if / else blocks are rejected outright; the
    // error messages state they cannot be present in the final AST.
    // Conditionals are handled through if_chain below.

    void operator()(if_block const & node)
    {
        throw invalid_ast_error("if blocks cannot be present in the final AST", node.location);
    }

    void operator()(else_if_block const & node)
    {
        throw invalid_ast_error("else if blocks cannot be present in the final AST", node.location);
    }

    void operator()(else_block const & node)
    {
        throw invalid_ast_error("else blocks cannot be present in the final AST", node.location);
    }

    /// Each branch condition (when present) is resolved in the enclosing
    /// scope; each branch body gets a fresh scope of its own.
    void operator()(if_chain const & node)
    {
        for (auto const & branch : node.blocks)
        {
            if (branch.condition)
                apply(*this, *branch.condition);

            scopes.emplace_back();
            apply(*this, *branch.statements);
            scopes.pop_back();
        }
    }

    /// The loop condition is resolved in the enclosing scope; the loop
    /// body gets a fresh scope of its own.
    void operator()(while_block const & node)
    {
        apply(*this, *node.condition);

        scopes.emplace_back();
        apply(*this, *node.statements);
        scopes.pop_back();
    }

    /// Validates the signature, registers the function in the enclosing
    /// scope, then resolves the body inside a new function scope whose
    /// variables are the argument names.
    void operator()(function_definition const & node)
    {
        require_undeclared(node.name, node.location);

        std::unordered_set<std::string> parameter_names;
        for (auto const & parameter : node.arguments)
        {
            // insert() reports through .second whether the name was new.
            if (!parameter_names.insert(parameter.name).second)
                throw parse_error("Duplicate argument name \"" + parameter.name + "\" in function \"" + node.name + "\"", parameter.location);
            apply(*this, *parameter.type);
        }
        apply(*this, *node.return_type);

        // Registered before the body is visited, so the body can refer
        // to the function's own name.
        scopes.back().functions.insert(node.name);

        auto & body_scope = scopes.emplace_back();
        body_scope.is_function_scope = true;
        body_scope.variables = std::move(parameter_names);
        apply(*this, *node.statements);
        scopes.pop_back();
    }

    void operator()(return_statement const & node)
    {
        if (node.value)
            apply(*this, *node.value);
    }

    void operator()(field_definition const & node)
    {
        apply(*this, *node.type);
    }

    /// Field types are resolved before the struct's own name is
    /// registered in the innermost scope.
    void operator()(struct_definition const & node)
    {
        require_undeclared(node.name, node.location);

        for (auto const & field : node.fields)
            apply(*this, field);

        scopes.back().structs.insert(node.name);
    }

private:
    /// Throws `parse_error` at `location` if `name` is already declared
    /// (as a function, struct or variable) in the innermost scope.
    template <typename Location>
    void require_undeclared(std::string const & name, Location const & location) const
    {
        if (scopes.back().contains(name))
            throw parse_error("Identifier \"" + name + "\" is already defined at this scope", location);
    }
};
}
/// Entry point of the identifier-resolution pass: runs
/// `resolve_identifiers_visitor` over `statements`, starting with a
/// single scope that is flagged as the global one. Any exception thrown
/// by the visitor propagates to the caller.
void resolve_identifiers(statement_list_ptr & statements)
{
    resolve_identifiers_visitor resolver;

    // Open the outermost scope and mark it as global before visiting.
    auto & global_scope = resolver.scopes.emplace_back();
    global_scope.is_global_scope = true;

    apply(resolver, *statements);
}
}