// Name-resolution pass for the pslang AST.
//
// resolve_identifiers() walks a parsed statement list and binds every
// identifier, type identifier and struct-constructor call to the AST node that
// defines it, throwing parse_error for unknown or duplicate names.
//
// NOTE(review): this copy of the file appears to have lost all angle-bracketed
// text: the #include lines below carry no header names, and template arguments
// are missing (e.g. `std::vector & scopes`, `std::unordered_map functions`,
// `std::get_if(...)`, `std::make_unique(...)`). Restore them from version
// control before building.
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

namespace pslang::ast
{

namespace
{

// One lexical scope: the names declared directly in it, each mapped to the AST
// node that defines it. The visitors below keep a stack of these, innermost
// scope at back().
struct scope
{
    // Function definitions declared in this scope.
    std::unordered_map functions;
    // Foreign function declarations declared in this scope.
    std::unordered_map foreign_functions;
    // Struct definitions declared in this scope.
    std::unordered_map structs;
    // Variable declarations in this scope; for a function scope this also holds
    // the function's arguments.
    std::unordered_map variables;

    // True if `name` is a function, foreign function or struct in this scope,
    // i.e. one of the kinds of names that stay visible from inside a nested
    // function body. Variables are deliberately excluded.
    bool contains_transitive(std::string const & name) const
    {
        return false
            || (functions.count(name) > 0)
            || (foreign_functions.count(name) > 0)
            || (structs.count(name) > 0)
            ;
    }

    // True if `name` is declared as a struct in this scope.
    // NOTE(review): not referenced anywhere in the visible code.
    bool contains_type(std::string const & name) const
    {
        return false
            || (structs.count(name) > 0)
            ;
    }

    // True if `name` is declared in this scope as anything at all (function,
    // foreign function, struct or variable). The visitors call this on
    // scopes.back() only, to reject a redeclaration within the same scope;
    // shadowing a name from an outer scope is therefore allowed.
    bool contains(std::string const & name) const
    {
        return false
            || (functions.count(name) > 0)
            || (foreign_functions.count(name) > 0)
            || (structs.count(name) > 0)
            || (variables.count(name) > 0)
            ;
    }

    // Visibility test from a use site: after a function boundary has been
    // crossed, only the names accepted by contains_transitive() count.
    // NOTE(review): not referenced anywhere in the visible code; the identifier
    // lookup in resolve_identifiers_visitor applies the same rule inline.
    bool contains(std::string const & name, bool crossed_function_scope) const
    {
        if (crossed_function_scope)
            return contains_transitive(name);
        return contains(name);
    }

    // Set on the scope pushed for a function body (seeded with its arguments).
    // Variable lookup does not continue past such a scope.
    bool is_function_scope = false;
};

// First pass over one statement list: registers the functions, foreign
// functions and structs it declares in the innermost scope, so that every
// statement of that list can refer to them regardless of declaration order.
// All other statements are ignored and nested blocks are not entered.
// The statement list itself is presumably walked by statement_visitor's own
// statement_list overload (not visible here).
struct populate_globals_visitor
    : statement_visitor
{
    // Scope stack shared with resolve_identifiers_visitor; innermost at back().
    std::vector & scopes;

    using statement_visitor::apply;

    // Not declarations: nothing to register.
    void apply(expression_ptr const &) {}
    void apply(assignment const &) {}
    // Variables are registered later, in statement order, by
    // resolve_identifiers_visitor, so they are visible only after declaration.
    void apply(variable_declaration const &) {}
    // Blocks are not entered here; their contents are registered when
    // resolve_identifiers_visitor pushes their scope and visits them.
    void apply(if_chain const &) {}
    void apply(while_block const &) {}

    // Registers the function; throws parse_error if the name is taken in this scope.
    void apply(function_definition & function_definition)
    {
        if (scopes.back().contains(function_definition.name))
            throw parse_error("Identifier \"" + function_definition.name + "\" is already defined at this scope", function_definition.location);
        scopes.back().functions[function_definition.name] = &function_definition;
    }

    // Registers the foreign function; throws parse_error if the name is taken in this scope.
    void apply(foreign_function_declaration & foreign_function_declaration)
    {
        if (scopes.back().contains(foreign_function_declaration.name))
            throw parse_error("Identifier \"" + foreign_function_declaration.name + "\" is already defined at this scope", foreign_function_declaration.location);
        scopes.back().foreign_functions[foreign_function_declaration.name] = &foreign_function_declaration;
    }

    void apply(return_statement const &) {}

    // Registers the struct; throws parse_error if the name is taken in this scope.
    void apply(struct_definition & struct_definition)
    {
        if (scopes.back().contains(struct_definition.name))
            throw parse_error("Identifier \"" + struct_definition.name + "\" is already defined at this scope", struct_definition.location);
        scopes.back().structs[struct_definition.name] = &struct_definition;
    }
};

// Second pass: walks types, expressions and statements, resolves every name
// against the scope stack and records the result on the AST node. Pushes and
// pops scopes for if/while blocks and function bodies.
struct resolve_identifiers_visitor
    : type_visitor
    , expression_visitor
    , statement_visitor
{
    // Scope stack; the innermost scope is back().
    std::vector & scopes;

    using type_visitor::apply;
    using expression_visitor::apply;
    using statement_visitor::apply;

    // ---- Types ----

    // Built-in types contain no names to resolve.
    void apply(types::unit_type const &) {}
    void apply(types::primitive_type const &) {}

    void apply(array_type const & array_type)
    {
        apply(*array_type.element_type);
    }

    void apply(function_type const & function_type)
    {
        for (auto const & argument : function_type.arguments)
            apply(*argument);
        apply(*function_type.result);
    }

    // A named type must be a struct: scopes are searched innermost-first, and
    // function boundaries do not limit the search. Throws parse_error if no
    // enclosing scope declares a struct with this name.
    void apply(type_identifier & identifier)
    {
        for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
        {
            if (auto jt = it->structs.find(identifier.name); jt != it->structs.end())
            {
                identifier.node = jt->second;
                return;
            }
        }
        throw parse_error("Identifier \"" + identifier.name + "\" not found", identifier.location);
    }

    // ---- Expressions ----

    void apply(literal const &) {}

    // Resolves a value name, innermost scope first. Functions and foreign
    // functions are found in any enclosing scope; variables (including
    // arguments) only up to and including the nearest enclosing function
    // scope, so a nested function cannot see its parent's locals.
    // Struct names are not considered here. Throws parse_error if nothing matches.
    void apply(identifier & identifier)
    {
        // NB: cannot be a type
        // The case of type constructors is resolved earlier in function_call node
        bool crossed_function_scope = false;
        for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
        {
            if (auto jt = it->functions.find(identifier.name); jt != it->functions.end())
            {
                identifier.function_node = jt->second;
                return;
            }
            if (auto jt = it->foreign_functions.find(identifier.name); jt != it->foreign_functions.end())
            {
                identifier.foreign_function_node = jt->second;
                return;
            }
            if (!crossed_function_scope)
            {
                if (auto jt = it->variables.find(identifier.name); jt != it->variables.end())
                {
                    identifier.variable_node = jt->second;
                    return;
                }
            }
            // Updated after the checks above, so the function scope's own
            // variables (its arguments and top-level locals) are still searched.
            crossed_function_scope |= it->is_function_scope;
        }
        throw parse_error("Identifier \"" + identifier.name + "\" not found", identifier.location);
    }

    void apply(unary_operation const & unary_operation)
    {
        apply(*unary_operation.arg1);
    }

    void apply(binary_operation const & binary_operation)
    {
        apply(*binary_operation.arg1);
        apply(*binary_operation.arg2);
    }

    void apply(cast_operation const & cast_operation)
    {
        apply(*cast_operation.expression);
        apply(*cast_operation.type);
    }

    // A call whose callee is a bare identifier may really be a type
    // constructor. Built-in type names are checked first (so they win over any
    // user binding of the same name), then structs in any enclosing scope. On a
    // match the callee is replaced: `function_call.type` is set and
    // `function_call.function` is cleared.
    // Throws invalid_ast_error if builtin_type() yields a type that is neither
    // unit nor primitive.
    // NOTE(review): the struct search does not stop at an inner variable or
    // function of the same name. With an outer struct `point` and a local
    // variable `point`, a bare `point` resolves to the variable but
    // `point(...)` becomes a constructor call. Confirm this is intended.
    void apply(function_call & function_call)
    {
        if (auto id = std::get_if(function_call.function.get()))
        {
            if (auto type = types::builtin_type(id->name))
            {
                if (auto unit_type = std::get_if(type.get()))
                    function_call.type = std::make_unique(*unit_type);
                else if (auto primitive_type = std::get_if(type.get()))
                    function_call.type = std::make_unique(*primitive_type);
                else
                    throw invalid_ast_error("Unknown built-in type \"" + id->name + "\"", get_location(*function_call.function));
                // Resets the callee that `id` points into; `id` must not be used after this.
                function_call.function = nullptr;
            }
            else
            {
                for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
                {
                    if (auto jt = it->structs.find(id->name); jt != it->structs.end())
                    {
                        function_call.type = std::make_unique(type_identifier{.name = id->name, .location = id->location, .node = jt->second});
                        // Only after `id` has been read: this resets the callee it points into.
                        function_call.function = nullptr;
                        break;
                    }
                }
            }
        }

        // Resolve whatever remains: the callee (if it is still an expression),
        // the constructed type (if one was set), and every argument.
        if (function_call.function)
            apply(*function_call.function);
        if (function_call.type)
            apply(*function_call.type);
        for (auto const & argument : function_call.arguments)
            apply(*argument);
    }

    void apply(array const & array)
    {
        for (auto const & element : array.elements)
            apply(*element);
    }

    void apply(array_access const & array_access)
    {
        apply(*array_access.array);
        apply(*array_access.index);
    }

    // Only the object expression is resolved; the field name is not looked up
    // here (presumably that happens in a later pass — confirm).
    void apply(field_access const & field_access)
    {
        apply(*field_access.object);
    }

    // ---- Statements ----

    // Forwards to the overload for the pointed-to expression.
    void apply(expression_ptr const & expression_ptr)
    {
        apply(*expression_ptr);
    }

    // Resolves both sides of the assignment.
    void apply(assignment const & assignment)
    {
        apply(assignment.lhs);
        apply(assignment.rhs);
    }

    // Declares a variable in the innermost scope; throws parse_error if the name
    // is already declared there. The optional type and the initializer are
    // resolved *before* the name is registered. As a result the initializer
    // cannot refer to the variable being declared; it sees an outer binding of
    // the same name, if any.
    void apply(variable_declaration & variable_declaration)
    {
        if (scopes.back().contains(variable_declaration.name))
            throw parse_error("Identifier \"" + variable_declaration.name + "\" is already defined at this scope", variable_declaration.location);
        if (variable_declaration.type)
            apply(*variable_declaration.type);
        apply(*variable_declaration.initializer);
        scopes.back().variables[variable_declaration.name] = &variable_declaration;
    }

    // Each condition is resolved in the enclosing scope; each block body gets
    // its own scope. A block without a condition is accepted (condition is optional).
    void apply(if_chain const & if_chain)
    {
        for (auto const & block : if_chain.blocks)
        {
            if (block.condition)
                apply(*block.condition);
            scopes.emplace_back();
            apply(*block.statements);
            scopes.pop_back();
        }
    }

    // The condition is resolved in the enclosing scope; the body gets its own scope.
    void apply(while_block const & while_block)
    {
        apply(*while_block.condition);
        scopes.emplace_back();
        apply(*while_block.statements);
        scopes.pop_back();
    }

    // Throws parse_error on duplicate argument names. Argument and return
    // types are resolved in the enclosing scope. The body is resolved in a new
    // function scope seeded with the arguments, so it cannot see variables of
    // enclosing functions.
    void apply(function_definition & function_definition)
    {
        // Already added to scope by populate_globals_visitor
        std::unordered_map arguments;
        for (auto & argument : function_definition.arguments)
        {
            if (arguments.count(argument.name) > 0)
                throw parse_error("Duplicate argument name \"" + argument.name + "\" in function \"" + function_definition.name + "\"", argument.location);
            arguments[argument.name] = &argument;
            apply(*argument.type);
        }
        apply(*function_definition.return_type);
        // `scope` refers into `scopes`. It is deliberately not used after the
        // body is resolved: scopes pushed while resolving the body may
        // reallocate the vector.
        auto & scope = scopes.emplace_back();
        scope.is_function_scope = true;
        scope.variables = std::move(arguments);
        apply(*function_definition.statements);
        scopes.pop_back();
    }

    // Throws parse_error on duplicate argument names, then resolves argument and
    // return types. There is no body, so no scope is pushed.
    void apply(foreign_function_declaration const & foreign_function_declaration)
    {
        // Already added to scope by populate_globals_visitor
        std::unordered_set argument_names;
        for (auto const & argument : foreign_function_declaration.arguments)
        {
            if (argument_names.count(argument.name) > 0)
                throw parse_error("Duplicate argument name \"" + argument.name + "\" in function \"" + foreign_function_declaration.name + "\"", argument.location);
            argument_names.insert(argument.name);
            apply(*argument.type);
        }
        apply(*foreign_function_declaration.return_type);
    }

    void apply(return_statement const & return_statement)
    {
        if (return_statement.value)
            apply(*return_statement.value);
    }

    // Resolves the field types. Because the enclosing statement list was
    // pre-registered, fields may name structs declared later in that list,
    // including this struct itself.
    // NOTE(review): populate_globals_visitor already stored this exact entry,
    // so the assignment below looks redundant.
    void apply(struct_definition & struct_definition)
    {
        // Already added to scope by populate_globals_visitor
        for (auto const & field : struct_definition.fields)
            apply(*field.type);
        scopes.back().structs[struct_definition.name] = &struct_definition;
    }

    // Two passes per statement list, both against the current innermost scope.
    // 1. Register every function, foreign function and struct the list declares,
    //    which allows forward references to them.
    // 2. Resolve each statement in order, so variables are visible only after
    //    their declaration.
    void apply(statement_list & statement_list)
    {
        populate_globals_visitor populate_globals_visitor{{}, scopes};
        populate_globals_visitor.apply(statement_list);
        for (auto const & statement : statement_list.statements)
            apply(*statement);
    }
};

}

// Entry point. Resolves every name in `statements`, starting from one empty
// global scope, and mutates the AST in place: it sets the *_node fields and
// rewrites constructor-style calls into typed calls.
// Throws parse_error for unknown or duplicate names, and invalid_ast_error for
// an unrecognised built-in type.
void resolve_identifiers(statement_list_ptr & statements)
{
    std::vector scopes;
    scopes.emplace_back();
    resolve_identifiers_visitor visitor{{}, {}, {}, scopes};
    visitor.apply(*statements);
}

}