295 lines
9 KiB
C++
295 lines
9 KiB
C++
#include <pslang/ast/preprocess.hpp>
|
|
#include <pslang/ast/statement.hpp>
|
|
#include <pslang/ast/error.hpp>
|
|
|
|
#include <unordered_set>
|
|
#include <vector>
|
|
|
|
namespace pslang::ast
|
|
{
|
|
|
|
namespace
|
|
{
|
|
|
|
// The names declared directly in a single scope, grouped by the kind of
// declaration, together with two flags describing what sort of scope it is.
struct scope
{
    std::unordered_set<std::string> functions;
    std::unordered_set<std::string> structs;
    std::unordered_set<std::string> variables;

    // True when `name` is a struct declared in this scope.
    bool contains_type(std::string const & name) const
    {
        return structs.find(name) != structs.end();
    }

    // True when `name` is a function or a struct declared in this scope.
    // Variables are not taken into account.
    bool contains_transitive(std::string const & name) const
    {
        if (functions.find(name) != functions.end())
            return true;
        return contains_type(name);
    }

    // True when `name` is declared in this scope as a function, a struct
    // or a variable.
    bool contains(std::string const & name) const
    {
        if (contains_transitive(name))
            return true;
        return variables.find(name) != variables.end();
    }

    // Lookup used while walking outwards through nested scopes: when
    // `crossed_function_scope` is set, variables of this scope are ignored
    // and only functions and structs can match.
    bool contains(std::string const & name, bool crossed_function_scope) const
    {
        return crossed_function_scope ? contains_transitive(name) : contains(name);
    }

    bool is_function_scope = false;
    bool is_global_scope = false;
};
|
|
|
|
struct context
|
|
{
|
|
std::vector<scope> scopes;
|
|
};
|
|
|
|
void resolve_identifiers(context & context, type & type);
|
|
void resolve_identifiers(context & context, expression & expression);
|
|
void resolve_identifiers(context & context, statement_list const & statements);
|
|
|
|
// The unit type refers to no names: nothing to resolve.
void resolve_identifiers_impl(context &, types::unit_type const &)
{
}
|
|
|
|
// Primitive types refer to no names: nothing to resolve.
void resolve_identifiers_impl(context &, types::primitive_type const &)
{
}
|
|
|
|
void resolve_identifiers_impl(context & context, array_type const & array_type)
|
|
{
|
|
resolve_identifiers(context, *array_type.element_type);
|
|
}
|
|
|
|
void resolve_identifiers_impl(context & context, function_type const & function_type)
|
|
{
|
|
for (auto const & argument : function_type.arguments)
|
|
resolve_identifiers(context, *argument);
|
|
resolve_identifiers(context, *function_type.result);
|
|
}
|
|
|
|
void resolve_identifiers_impl(context & context, type_identifier & identifier)
|
|
{
|
|
for (auto it = context.scopes.rbegin(); it != context.scopes.rend(); ++it)
|
|
{
|
|
if (it->contains_type(identifier.name))
|
|
{
|
|
identifier.level = it.base() - context.scopes.begin() - 1;
|
|
return;
|
|
}
|
|
}
|
|
|
|
// TODO location
|
|
throw parse_error("Identifier \"" + identifier.name + "\" not found", {});
|
|
}
|
|
|
|
void resolve_identifiers(context & context, type & type)
|
|
{
|
|
return std::visit([&](auto & type){ return resolve_identifiers_impl(context, type); }, type);
|
|
}
|
|
|
|
// Literals refer to no names: nothing to resolve.
void resolve_identifiers_impl(context &, literal const &)
{
}
|
|
|
|
void resolve_identifiers_impl(context & context, identifier & identifier)
|
|
{
|
|
bool crossed_function_scope = false;
|
|
for (auto it = context.scopes.rbegin(); it != context.scopes.rend(); ++it)
|
|
{
|
|
if (it->contains(identifier.name, crossed_function_scope && !it->is_global_scope))
|
|
{
|
|
identifier.level = it.base() - context.scopes.begin() - 1;
|
|
return;
|
|
}
|
|
|
|
crossed_function_scope |= it->is_function_scope;
|
|
}
|
|
|
|
throw parse_error("Identifier \"" + identifier.name + "\" not found", identifier.location);
|
|
}
|
|
|
|
void resolve_identifiers_impl(context & context, unary_operation const & unary_operation)
|
|
{
|
|
resolve_identifiers(context, *unary_operation.arg1);
|
|
}
|
|
|
|
void resolve_identifiers_impl(context & context, binary_operation const & binary_operation)
|
|
{
|
|
resolve_identifiers(context, *binary_operation.arg1);
|
|
resolve_identifiers(context, *binary_operation.arg2);
|
|
}
|
|
|
|
void resolve_identifiers_impl(context & context, cast_operation const & cast_operation)
|
|
{
|
|
resolve_identifiers(context, *cast_operation.expression);
|
|
resolve_identifiers(context, *cast_operation.type);
|
|
}
|
|
|
|
void resolve_identifiers_impl(context & context, function_call const & function_call)
|
|
{
|
|
resolve_identifiers(context, *function_call.function);
|
|
for (auto const & argument : function_call.arguments)
|
|
resolve_identifiers(context, *argument);
|
|
}
|
|
|
|
void resolve_identifiers_impl(context & context, array const & array)
|
|
{
|
|
for (auto const & element : array.elements)
|
|
resolve_identifiers(context, *element);
|
|
}
|
|
|
|
void resolve_identifiers_impl(context & context, array_access const & array_access)
|
|
{
|
|
resolve_identifiers(context, *array_access.array);
|
|
resolve_identifiers(context, *array_access.index);
|
|
}
|
|
|
|
void resolve_identifiers_impl(context & context, field_access const & field_access)
|
|
{
|
|
resolve_identifiers(context, *field_access.object);
|
|
}
|
|
|
|
void resolve_identifiers(context & context, expression & expression)
|
|
{
|
|
return std::visit([&](auto & expression){ return resolve_identifiers_impl(context, expression); }, expression);
|
|
}
|
|
|
|
void resolve_identifiers_impl(context & context, expression_ptr const & expression_ptr)
|
|
{
|
|
resolve_identifiers(context, *expression_ptr);
|
|
}
|
|
|
|
void resolve_identifiers_impl(context & context, assignment const & assignment)
|
|
{
|
|
resolve_identifiers(context, *assignment.lhs);
|
|
resolve_identifiers(context, *assignment.rhs);
|
|
}
|
|
|
|
void resolve_identifiers_impl(context & context, variable_declaration const & variable_declaration)
|
|
{
|
|
if (context.scopes.back().contains(variable_declaration.name))
|
|
throw parse_error("Identifier \"" + variable_declaration.name + "\" is already defined at this scope", variable_declaration.location);
|
|
|
|
if (variable_declaration.type)
|
|
resolve_identifiers(context, *variable_declaration.type);
|
|
resolve_identifiers(context, *variable_declaration.initializer);
|
|
|
|
context.scopes.back().variables.insert(variable_declaration.name);
|
|
}
|
|
|
|
// A standalone `if` block is not accepted by this pass: encountering one is
// reported as an invalid AST at the block's location.
void resolve_identifiers_impl(context &, if_block const & block)
{
    throw invalid_ast_error("if blocks cannot be present in the final AST", block.location);
}
|
|
|
|
// A standalone `else if` block is not accepted by this pass: encountering
// one is reported as an invalid AST at the block's location.
void resolve_identifiers_impl(context &, else_if_block const & block)
{
    throw invalid_ast_error("else if blocks cannot be present in the final AST", block.location);
}
|
|
|
|
// A standalone `else` block is not accepted by this pass: encountering one
// is reported as an invalid AST at the block's location.
void resolve_identifiers_impl(context &, else_block const & block)
{
    throw invalid_ast_error("else blocks cannot be present in the final AST", block.location);
}
|
|
|
|
void resolve_identifiers_impl(context & context, if_chain const & if_chain)
|
|
{
|
|
for (auto const & block : if_chain.blocks)
|
|
{
|
|
if (block.condition)
|
|
resolve_identifiers(context, *block.condition);
|
|
context.scopes.emplace_back();
|
|
resolve_identifiers(context, *block.statements);
|
|
context.scopes.pop_back();
|
|
}
|
|
}
|
|
|
|
void resolve_identifiers_impl(context & context, while_block const & while_block)
|
|
{
|
|
resolve_identifiers(context, *while_block.condition);
|
|
context.scopes.emplace_back();
|
|
resolve_identifiers(context, *while_block.statements);
|
|
context.scopes.pop_back();
|
|
}
|
|
|
|
void resolve_identifiers_impl(context & context, function_definition const & function_definition)
|
|
{
|
|
if (context.scopes.back().contains(function_definition.name))
|
|
throw parse_error("Identifier \"" + function_definition.name + "\" is already defined at this scope", function_definition.location);
|
|
|
|
std::unordered_set<std::string> argument_names;
|
|
for (auto const & argument : function_definition.arguments)
|
|
{
|
|
if (argument_names.count(argument.name) > 0)
|
|
throw parse_error("Duplicate argument name \"" + argument.name + "\" in function \"" + function_definition.name + "\"", argument.location);
|
|
argument_names.insert(argument.name);
|
|
resolve_identifiers(context, *argument.type);
|
|
}
|
|
|
|
resolve_identifiers(context, *function_definition.return_type);
|
|
|
|
context.scopes.back().functions.insert(function_definition.name);
|
|
|
|
auto & scope = context.scopes.emplace_back();
|
|
scope.is_function_scope = true;
|
|
scope.variables = std::move(argument_names);
|
|
resolve_identifiers(context, *function_definition.statements);
|
|
context.scopes.pop_back();
|
|
}
|
|
|
|
void resolve_identifiers_impl(context & context, return_statement const & return_statement)
|
|
{
|
|
if (return_statement.value)
|
|
resolve_identifiers(context, *return_statement.value);
|
|
}
|
|
|
|
void resolve_identifiers_impl(context & context, field_definition const & field_definition)
|
|
{
|
|
resolve_identifiers(context, *field_definition.type);
|
|
}
|
|
|
|
void resolve_identifiers_impl(context & context, struct_definition const & struct_definition)
|
|
{
|
|
if (context.scopes.back().contains(struct_definition.name))
|
|
throw parse_error("Identifier \"" + struct_definition.name + "\" is already defined at this scope", struct_definition.location);
|
|
|
|
for (auto const & field : struct_definition.fields)
|
|
resolve_identifiers_impl(context, field);
|
|
|
|
context.scopes.back().structs.insert(struct_definition.name);
|
|
}
|
|
|
|
void resolve_identifiers(context & context, statement const & statement)
|
|
{
|
|
return std::visit([&](auto const & statement){ return resolve_identifiers_impl(context, statement); }, statement);
|
|
}
|
|
|
|
void resolve_identifiers(context & context, statement_list const & statements)
|
|
{
|
|
for (auto const & statement : statements.statements)
|
|
{
|
|
resolve_identifiers(context, *statement);
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
// Entry point of the pass: resolves the given statement list starting from
// a fresh context that holds a single scope flagged as the global scope.
// Errors are reported by exceptions thrown from the per-node handlers.
void resolve_identifiers(statement_list_ptr & statements)
{
    context ctx;

    auto & global = ctx.scopes.emplace_back();
    global.is_global_scope = true;

    resolve_identifiers(ctx, *statements);
}
|
|
|
|
}
|