#include #include #include #include #include namespace pslang::ast { namespace { struct scope { std::unordered_set functions; std::unordered_set structs; std::unordered_set variables; bool contains_transitive(std::string const & name) const { return false || (functions.count(name) > 0) || (structs.count(name) > 0) ; } bool contains_type(std::string const & name) const { return false || (structs.count(name) > 0) ; } bool contains(std::string const & name) const { return false || (functions.count(name) > 0) || (structs.count(name) > 0) || (variables.count(name) > 0) ; } bool contains(std::string const & name, bool crossed_function_scope) const { if (crossed_function_scope) return contains_transitive(name); return contains(name); } bool is_function_scope = false; bool is_global_scope = false; }; struct context { std::vector scopes; }; void resolve_identifiers(context & context, type & type); void resolve_identifiers(context & context, expression & expression); void resolve_identifiers(context & context, statement_list const & statements); void resolve_identifiers_impl(context &, types::unit_type const &) {} void resolve_identifiers_impl(context &, types::primitive_type const &) {} void resolve_identifiers_impl(context & context, array_type const & array_type) { resolve_identifiers(context, *array_type.element_type); } void resolve_identifiers_impl(context & context, function_type const & function_type) { for (auto const & argument : function_type.arguments) resolve_identifiers(context, *argument); resolve_identifiers(context, *function_type.result); } void resolve_identifiers_impl(context & context, type_identifier & identifier) { for (auto it = context.scopes.rbegin(); it != context.scopes.rend(); ++it) { if (it->contains_type(identifier.name)) { identifier.level = it.base() - context.scopes.begin() - 1; return; } } // TODO location throw parse_error("Identifier \"" + identifier.name + "\" not found", {}); } void resolve_identifiers(context & context, type & type) { return std::visit([&](auto & type){ return resolve_identifiers_impl(context, type); }, type); } void resolve_identifiers_impl(context &, literal const &) {} void resolve_identifiers_impl(context & context, identifier & identifier) { bool crossed_function_scope = false; for (auto it = context.scopes.rbegin(); it != context.scopes.rend(); ++it) { if (it->contains(identifier.name, crossed_function_scope && !it->is_global_scope)) { identifier.level = it.base() - context.scopes.begin() - 1; return; } crossed_function_scope |= it->is_function_scope; } throw parse_error("Identifier \"" + identifier.name + "\" not found", identifier.location); } void resolve_identifiers_impl(context & context, unary_operation const & unary_operation) { resolve_identifiers(context, *unary_operation.arg1); } void resolve_identifiers_impl(context & context, binary_operation const & binary_operation) { resolve_identifiers(context, *binary_operation.arg1); resolve_identifiers(context, *binary_operation.arg2); } void resolve_identifiers_impl(context & context, cast_operation const & cast_operation) { resolve_identifiers(context, *cast_operation.expression); resolve_identifiers(context, *cast_operation.type); } void resolve_identifiers_impl(context & context, function_call const & function_call) { resolve_identifiers(context, *function_call.function); for (auto const & argument : function_call.arguments) resolve_identifiers(context, *argument); } void resolve_identifiers_impl(context & context, array const & array) { for (auto const & element : array.elements) resolve_identifiers(context, *element); } void resolve_identifiers_impl(context & context, array_access const & array_access) { resolve_identifiers(context, *array_access.array); resolve_identifiers(context, *array_access.index); } void resolve_identifiers_impl(context & context, field_access const & field_access) { resolve_identifiers(context, *field_access.object); } void resolve_identifiers(context & context, expression & expression) { return std::visit([&](auto & expression){ return resolve_identifiers_impl(context, expression); }, expression); } void resolve_identifiers_impl(context & context, expression_ptr const & expression_ptr) { resolve_identifiers(context, *expression_ptr); } void resolve_identifiers_impl(context & context, assignment const & assignment) { resolve_identifiers(context, *assignment.lhs); resolve_identifiers(context, *assignment.rhs); } void resolve_identifiers_impl(context & context, variable_declaration const & variable_declaration) { if (context.scopes.back().contains(variable_declaration.name)) throw parse_error("Identifier \"" + variable_declaration.name + "\" is already defined at this scope", variable_declaration.location); if (variable_declaration.type) resolve_identifiers(context, *variable_declaration.type); resolve_identifiers(context, *variable_declaration.initializer); context.scopes.back().variables.insert(variable_declaration.name); } void resolve_identifiers_impl(context &, if_block const & if_block) { throw invalid_ast_error("if blocks cannot be present in the final AST", if_block.location); } void resolve_identifiers_impl(context &, else_if_block const & else_if_block) { throw invalid_ast_error("else if blocks cannot be present in the final AST", else_if_block.location); } void resolve_identifiers_impl(context &, else_block const & else_block) { throw invalid_ast_error("else blocks cannot be present in the final AST", else_block.location); } void resolve_identifiers_impl(context & context, if_chain const & if_chain) { for (auto const & block : if_chain.blocks) { if (block.condition) resolve_identifiers(context, *block.condition); context.scopes.emplace_back(); resolve_identifiers(context, *block.statements); context.scopes.pop_back(); } } void resolve_identifiers_impl(context & context, while_block const & while_block) { resolve_identifiers(context, *while_block.condition); context.scopes.emplace_back(); resolve_identifiers(context, *while_block.statements); context.scopes.pop_back(); } void resolve_identifiers_impl(context & context, function_definition const & function_definition) { if (context.scopes.back().contains(function_definition.name)) throw parse_error("Identifier \"" + function_definition.name + "\" is already defined at this scope", function_definition.location); std::unordered_set argument_names; for (auto const & argument : function_definition.arguments) { if (argument_names.count(argument.name) > 0) throw parse_error("Duplicate argument name \"" + argument.name + "\" in function \"" + function_definition.name + "\"", argument.location); argument_names.insert(argument.name); resolve_identifiers(context, *argument.type); } resolve_identifiers(context, *function_definition.return_type); context.scopes.back().functions.insert(function_definition.name); auto & scope = context.scopes.emplace_back(); scope.is_function_scope = true; scope.variables = std::move(argument_names); resolve_identifiers(context, *function_definition.statements); context.scopes.pop_back(); } void resolve_identifiers_impl(context & context, return_statement const & return_statement) { if (return_statement.value) resolve_identifiers(context, *return_statement.value); } void resolve_identifiers_impl(context & context, field_definition const & field_definition) { resolve_identifiers(context, *field_definition.type); } void resolve_identifiers_impl(context & context, struct_definition const & struct_definition) { if (context.scopes.back().contains(struct_definition.name)) throw parse_error("Identifier \"" + struct_definition.name + "\" is already defined at this scope", struct_definition.location); for (auto const & field : struct_definition.fields) resolve_identifiers_impl(context, field); context.scopes.back().structs.insert(struct_definition.name); } void resolve_identifiers(context & context, statement const & statement) { return std::visit([&](auto const & statement){ return resolve_identifiers_impl(context, statement); }, statement); } void resolve_identifiers(context & context, statement_list const & statements) { for (auto const & statement : statements.statements) { resolve_identifiers(context, *statement); } } } void resolve_identifiers(statement_list_ptr & statements) { context context; context.scopes.emplace_back().is_global_scope = true; resolve_identifiers(context, *statements); } }