398 lines
11 KiB
C++
398 lines
11 KiB
C++
#include <pslang/ast/preprocess.hpp>
|
|
#include <pslang/ast/statement.hpp>
|
|
#include <pslang/ast/type_visitor.hpp>
|
|
#include <pslang/ast/expression_visitor.hpp>
|
|
#include <pslang/ast/statement_visitor.hpp>
|
|
#include <pslang/ast/error.hpp>
|
|
#include <pslang/types/type.hpp>
|
|
|
|
#include <unordered_set>
|
|
#include <unordered_map>
|
|
#include <vector>
|
|
|
|
namespace pslang::ast
|
|
{
|
|
|
|
namespace
|
|
{
|
|
|
|
struct scope
|
|
{
|
|
std::unordered_map<std::string, function_definition *> functions;
|
|
std::unordered_map<std::string, foreign_function_declaration *> foreign_functions;
|
|
std::unordered_map<std::string, struct_definition *> structs;
|
|
std::unordered_map<std::string, variable_base *> variables;
|
|
std::unordered_map<std::string, variable_declaration *> constants;
|
|
|
|
bool contains_transitive(std::string const & name) const
|
|
{
|
|
return false
|
|
|| (functions.count(name) > 0)
|
|
|| (foreign_functions.count(name) > 0)
|
|
|| (structs.count(name) > 0)
|
|
|| (constants.count(name) > 0)
|
|
;
|
|
}
|
|
|
|
bool contains_type(std::string const & name) const
|
|
{
|
|
return false
|
|
|| (structs.count(name) > 0)
|
|
;
|
|
}
|
|
|
|
bool contains(std::string const & name) const
|
|
{
|
|
return false
|
|
|| (functions.count(name) > 0)
|
|
|| (foreign_functions.count(name) > 0)
|
|
|| (structs.count(name) > 0)
|
|
|| (variables.count(name) > 0)
|
|
|| (constants.count(name) > 0)
|
|
;
|
|
}
|
|
|
|
bool contains(std::string const & name, bool crossed_function_scope) const
|
|
{
|
|
if (crossed_function_scope)
|
|
return contains_transitive(name);
|
|
return contains(name);
|
|
}
|
|
|
|
bool is_function_scope = false;
|
|
};
|
|
|
|
struct populate_globals_visitor
|
|
: statement_visitor<populate_globals_visitor>
|
|
{
|
|
std::vector<scope> & scopes;
|
|
|
|
using statement_visitor::apply;
|
|
|
|
void apply(expression_ptr const &)
|
|
{}
|
|
|
|
void apply(assignment const &)
|
|
{}
|
|
|
|
void apply(variable_declaration const &)
|
|
{}
|
|
|
|
void apply(if_chain const &)
|
|
{}
|
|
|
|
void apply(while_block const &)
|
|
{}
|
|
|
|
void apply(function_definition & function_definition)
|
|
{
|
|
if (scopes.back().contains(function_definition.name))
|
|
throw parse_error("Identifier \"" + function_definition.name + "\" is already defined at this scope", function_definition.location);
|
|
|
|
scopes.back().functions[function_definition.name] = &function_definition;
|
|
}
|
|
|
|
void apply(foreign_function_declaration & foreign_function_declaration)
|
|
{
|
|
if (scopes.back().contains(foreign_function_declaration.name))
|
|
throw parse_error("Identifier \"" + foreign_function_declaration.name + "\" is already defined at this scope", foreign_function_declaration.location);
|
|
|
|
scopes.back().foreign_functions[foreign_function_declaration.name] = &foreign_function_declaration;
|
|
}
|
|
|
|
void apply(return_statement const &)
|
|
{}
|
|
|
|
void apply(struct_definition & struct_definition)
|
|
{
|
|
if (scopes.back().contains(struct_definition.name))
|
|
throw parse_error("Identifier \"" + struct_definition.name + "\" is already defined at this scope", struct_definition.location);
|
|
|
|
scopes.back().structs[struct_definition.name] = &struct_definition;
|
|
}
|
|
};
|
|
|
|
struct resolve_identifiers_visitor
|
|
: type_visitor<resolve_identifiers_visitor>
|
|
, expression_visitor<resolve_identifiers_visitor>
|
|
, statement_visitor<resolve_identifiers_visitor>
|
|
{
|
|
std::vector<scope> & scopes;
|
|
|
|
using type_visitor::apply;
|
|
using expression_visitor::apply;
|
|
using statement_visitor::apply;
|
|
|
|
void apply(types::unit_type const &)
|
|
{}
|
|
|
|
void apply(types::primitive_type const &)
|
|
{}
|
|
|
|
void apply(array_type const & array_type)
|
|
{
|
|
apply(*array_type.element_type);
|
|
}
|
|
|
|
void apply(function_type const & function_type)
|
|
{
|
|
for (auto const & argument : function_type.arguments)
|
|
apply(*argument);
|
|
apply(*function_type.result);
|
|
}
|
|
|
|
void apply(pointer_type const & pointer_type)
|
|
{
|
|
apply(*pointer_type.referenced_type);
|
|
}
|
|
|
|
void apply(type_identifier & identifier)
|
|
{
|
|
for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
|
|
{
|
|
if (auto jt = it->structs.find(identifier.name); jt != it->structs.end())
|
|
{
|
|
identifier.node = jt->second;
|
|
return;
|
|
}
|
|
}
|
|
|
|
throw parse_error("Identifier \"" + identifier.name + "\" not found", identifier.location);
|
|
}
|
|
|
|
void apply(literal const &)
|
|
{}
|
|
|
|
void apply(identifier & identifier)
|
|
{
|
|
// NB: cannot be a type
|
|
// The case of type constructors is resolved earlier in function_call node
|
|
|
|
bool crossed_function_scope = false;
|
|
for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
|
|
{
|
|
if (auto jt = it->functions.find(identifier.name); jt != it->functions.end())
|
|
{
|
|
identifier.function_node = jt->second;
|
|
return;
|
|
}
|
|
|
|
if (auto jt = it->foreign_functions.find(identifier.name); jt != it->foreign_functions.end())
|
|
{
|
|
identifier.foreign_function_node = jt->second;
|
|
return;
|
|
}
|
|
|
|
if (auto jt = it->constants.find(identifier.name); jt != it->constants.end())
|
|
{
|
|
identifier.constant_node = jt->second;
|
|
return;
|
|
}
|
|
|
|
if (!crossed_function_scope)
|
|
{
|
|
if (auto jt = it->variables.find(identifier.name); jt != it->variables.end())
|
|
{
|
|
identifier.variable_node = jt->second;
|
|
return;
|
|
}
|
|
}
|
|
|
|
crossed_function_scope |= it->is_function_scope;
|
|
}
|
|
|
|
throw parse_error("Identifier \"" + identifier.name + "\" not found", identifier.location);
|
|
}
|
|
|
|
void apply(unary_operation const & unary_operation)
|
|
{
|
|
apply(*unary_operation.arg1);
|
|
}
|
|
|
|
void apply(binary_operation const & binary_operation)
|
|
{
|
|
apply(*binary_operation.arg1);
|
|
apply(*binary_operation.arg2);
|
|
}
|
|
|
|
void apply(cast_operation const & cast_operation)
|
|
{
|
|
apply(*cast_operation.expression);
|
|
apply(*cast_operation.type);
|
|
}
|
|
|
|
void apply(function_call & function_call)
|
|
{
|
|
if (auto id = std::get_if<identifier>(function_call.function.get()))
|
|
{
|
|
if (auto type = types::builtin_type(id->name))
|
|
{
|
|
if (auto unit_type = std::get_if<types::unit_type>(type.get()))
|
|
function_call.type = std::make_unique<ast::type>(*unit_type);
|
|
else if (auto primitive_type = std::get_if<types::primitive_type>(type.get()))
|
|
function_call.type = std::make_unique<ast::type>(*primitive_type);
|
|
else
|
|
throw invalid_ast_error("Unknown built-in type \"" + id->name + "\"", get_location(*function_call.function));
|
|
|
|
function_call.function = nullptr;
|
|
}
|
|
else
|
|
{
|
|
for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
|
|
{
|
|
if (auto jt = it->structs.find(id->name); jt != it->structs.end())
|
|
{
|
|
function_call.type = std::make_unique<ast::type>(type_identifier{.name = id->name, .location = id->location, .node = jt->second});
|
|
function_call.function = nullptr;
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if (function_call.function)
|
|
apply(*function_call.function);
|
|
if (function_call.type)
|
|
apply(*function_call.type);
|
|
for (auto const & argument : function_call.arguments)
|
|
apply(*argument);
|
|
}
|
|
|
|
void apply(array const & array)
|
|
{
|
|
for (auto const & element : array.elements)
|
|
apply(*element);
|
|
}
|
|
|
|
void apply(array_access const & array_access)
|
|
{
|
|
apply(*array_access.array);
|
|
apply(*array_access.index);
|
|
}
|
|
|
|
void apply(field_access const & field_access)
|
|
{
|
|
apply(*field_access.object);
|
|
}
|
|
|
|
void apply(expression_ptr const & expression_ptr)
|
|
{
|
|
apply(*expression_ptr);
|
|
}
|
|
|
|
void apply(assignment const & assignment)
|
|
{
|
|
apply(assignment.lhs);
|
|
apply(assignment.rhs);
|
|
}
|
|
|
|
void apply(variable_declaration & variable_declaration)
|
|
{
|
|
if (scopes.back().contains(variable_declaration.name))
|
|
throw parse_error("Identifier \"" + variable_declaration.name + "\" is already defined at this scope", variable_declaration.location);
|
|
|
|
if (variable_declaration.type)
|
|
apply(*variable_declaration.type);
|
|
apply(*variable_declaration.initializer);
|
|
|
|
if (variable_declaration.category == value_category::compile_time)
|
|
scopes.back().constants[variable_declaration.name] = &variable_declaration;
|
|
else
|
|
scopes.back().variables[variable_declaration.name] = &variable_declaration;
|
|
}
|
|
|
|
void apply(if_chain const & if_chain)
|
|
{
|
|
for (auto const & block : if_chain.blocks)
|
|
{
|
|
if (block.condition)
|
|
apply(*block.condition);
|
|
scopes.emplace_back();
|
|
apply(*block.statements);
|
|
scopes.pop_back();
|
|
}
|
|
}
|
|
|
|
void apply(while_block const & while_block)
|
|
{
|
|
apply(*while_block.condition);
|
|
scopes.emplace_back();
|
|
apply(*while_block.statements);
|
|
scopes.pop_back();
|
|
}
|
|
|
|
void apply(function_definition & function_definition)
|
|
{
|
|
// Already added to scope by populate_globals_visitor
|
|
|
|
std::unordered_map<std::string, variable_base *> arguments;
|
|
for (auto & argument : function_definition.arguments)
|
|
{
|
|
if (arguments.count(argument.name) > 0)
|
|
throw parse_error("Duplicate argument name \"" + argument.name + "\" in function \"" + function_definition.name + "\"", argument.location);
|
|
arguments[argument.name] = &argument;
|
|
apply(*argument.type);
|
|
}
|
|
|
|
apply(*function_definition.return_type);
|
|
|
|
auto & scope = scopes.emplace_back();
|
|
scope.is_function_scope = true;
|
|
scope.variables = std::move(arguments);
|
|
apply(*function_definition.statements);
|
|
scopes.pop_back();
|
|
}
|
|
|
|
void apply(foreign_function_declaration const & foreign_function_declaration)
|
|
{
|
|
// Already added to scope by populate_globals_visitor
|
|
|
|
std::unordered_set<std::string> argument_names;
|
|
for (auto const & argument : foreign_function_declaration.arguments)
|
|
{
|
|
if (argument_names.count(argument.name) > 0)
|
|
throw parse_error("Duplicate argument name \"" + argument.name + "\" in function \"" + foreign_function_declaration.name + "\"", argument.location);
|
|
argument_names.insert(argument.name);
|
|
apply(*argument.type);
|
|
}
|
|
|
|
apply(*foreign_function_declaration.return_type);
|
|
}
|
|
|
|
void apply(return_statement const & return_statement)
|
|
{
|
|
if (return_statement.value)
|
|
apply(*return_statement.value);
|
|
}
|
|
|
|
void apply(struct_definition & struct_definition)
|
|
{
|
|
// Already added to scope by populate_globals_visitor
|
|
|
|
for (auto const & field : struct_definition.fields)
|
|
apply(*field.type);
|
|
|
|
scopes.back().structs[struct_definition.name] = &struct_definition;
|
|
}
|
|
|
|
void apply(statement_list & statement_list)
|
|
{
|
|
populate_globals_visitor populate_globals_visitor{{}, scopes};
|
|
populate_globals_visitor.apply(statement_list);
|
|
|
|
for (auto const & statement : statement_list.statements)
|
|
apply(*statement);
|
|
}
|
|
};
|
|
|
|
}
|
|
|
|
/// Runs identifier resolution over the statement tree rooted at `root`.
///
/// Starts with a scope stack holding one empty scope and applies
/// resolve_identifiers_visitor to `*root`. Exceptions thrown by the visitor
/// propagate to the caller. `root` is dereferenced without a null check.
void resolve_identifiers(statement_ptr & root)
{
    std::vector<scope> scope_stack(1);

    resolve_identifiers_visitor resolver{{}, {}, {}, scope_stack};
    resolver.apply(*root);
}
|
|
|
|
}
|