From aee506d102114393587c0835f5640a06cf7074fc Mon Sep 17 00:00:00 2001 From: lisyarus Date: Sat, 20 Dec 2025 13:05:51 +0300 Subject: [PATCH] Rewrite identifier resolution using visitors --- .../include/pslang/ast/statement_visitor.hpp | 12 + libs/ast/source/print.cpp | 7 - libs/ast/source/resolve_identifiers.cpp | 406 +++++++++--------- plans.txt | 1 - 4 files changed, 213 insertions(+), 213 deletions(-) diff --git a/libs/ast/include/pslang/ast/statement_visitor.hpp b/libs/ast/include/pslang/ast/statement_visitor.hpp index 7c26e2d..faa8b32 100644 --- a/libs/ast/include/pslang/ast/statement_visitor.hpp +++ b/libs/ast/include/pslang/ast/statement_visitor.hpp @@ -30,6 +30,12 @@ namespace pslang::ast { std::visit(*this, statement); } + + void operator()(statement_list const & statement_list) + { + for (auto const & statement : statement_list.statements) + operator()(*statement); + } }; template @@ -54,6 +60,12 @@ namespace pslang::ast { std::visit(*this, statement); } + + void operator()(statement_list & statement_list) + { + for (auto & statement : statement_list.statements) + operator()(*statement); + } }; } diff --git a/libs/ast/source/print.cpp b/libs/ast/source/print.cpp index 16790e5..0dd8819 100644 --- a/libs/ast/source/print.cpp +++ b/libs/ast/source/print.cpp @@ -368,13 +368,6 @@ namespace pslang::ast for (auto const & field : node.fields) child(self, field); } - - template - void operator()(Self & self, statement_list const & node) - { - for (auto const & statement : node.statements) - self(*statement); - } }; } diff --git a/libs/ast/source/resolve_identifiers.cpp b/libs/ast/source/resolve_identifiers.cpp index 5999a3b..a5161f8 100644 --- a/libs/ast/source/resolve_identifiers.cpp +++ b/libs/ast/source/resolve_identifiers.cpp @@ -1,5 +1,8 @@ #include #include +#include +#include +#include #include #include @@ -57,238 +60,231 @@ namespace pslang::ast std::vector scopes; }; - void resolve_identifiers(context & context, type & type); - void resolve_identifiers(context & context, expression & expression); - void resolve_identifiers(context & context, statement_list const & statements); - - void resolve_identifiers_impl(context &, types::unit_type const &) - {} - - void resolve_identifiers_impl(context &, types::primitive_type const &) - {} - - void resolve_identifiers_impl(context & context, array_type const & array_type) + struct resolve_identifiers_visitor { - resolve_identifiers(context, *array_type.element_type); - } + std::vector scopes; - void resolve_identifiers_impl(context & context, function_type const & function_type) - { - for (auto const & argument : function_type.arguments) - resolve_identifiers(context, *argument); - resolve_identifiers(context, *function_type.result); - } + void operator()(types::unit_type const &) + {} - void resolve_identifiers_impl(context & context, type_identifier & identifier) - { - for (auto it = context.scopes.rbegin(); it != context.scopes.rend(); ++it) + void operator()(types::primitive_type const &) + {} + + template + void operator()(Self & self, array_type const & array_type) { - if (it->contains_type(identifier.name)) + self(*array_type.element_type); + } + + template + void operator()(Self & self, function_type const & function_type) + { + for (auto const & argument : function_type.arguments) + self(*argument); + self(*function_type.result); + } + + void operator()(type_identifier & identifier) + { + for (auto it = scopes.rbegin(); it != scopes.rend(); ++it) { - identifier.level = it.base() - context.scopes.begin() - 1; - return; + if (it->contains_type(identifier.name)) + { + identifier.level = it.base() - scopes.begin() - 1; + return; + } + } + + throw parse_error("Identifier \"" + identifier.name + "\" not found", identifier.location); + } + + void operator()(literal const &) + {} + + void operator()(identifier & identifier) + { + bool crossed_function_scope = false; + for (auto it = scopes.rbegin(); it != scopes.rend(); ++it) + { + if (it->contains(identifier.name, crossed_function_scope && !it->is_global_scope)) + { + identifier.level = it.base() - scopes.begin() - 1; + return; + } + + crossed_function_scope |= it->is_function_scope; + } + + throw parse_error("Identifier \"" + identifier.name + "\" not found", identifier.location); + } + + template + void operator()(Self & self, unary_operation const & unary_operation) + { + self(*unary_operation.arg1); + } + + template + void operator()(Self & self, binary_operation const & binary_operation) + { + self(*binary_operation.arg1); + self(*binary_operation.arg2); + } + + template + void operator()(Self & self, cast_operation const & cast_operation) + { + self(*cast_operation.expression); + apply(*this, *cast_operation.type); + } + + template + void operator()(Self & self, function_call const & function_call) + { + self(*function_call.function); + for (auto const & argument : function_call.arguments) + self(*argument); + } + + template + void operator()(Self & self, array const & array) + { + for (auto const & element : array.elements) + self(*element); + } + + template + void operator()(Self & self, array_access const & array_access) + { + self(*array_access.array); + self(*array_access.index); + } + + template + void operator()(Self & self, field_access const & field_access) + { + self(*field_access.object); + } + + void operator()(expression_ptr const & expression_ptr) + { + apply(*this, *expression_ptr); + } + + template + void operator()(Self & self, assignment const & assignment) + { + self(assignment.lhs); + self(assignment.rhs); + } + + void operator()(variable_declaration const & variable_declaration) + { + if (scopes.back().contains(variable_declaration.name)) + throw parse_error("Identifier \"" + variable_declaration.name + "\" is already defined at this scope", variable_declaration.location); + + if (variable_declaration.type) + apply(*this, *variable_declaration.type); + apply(*this, *variable_declaration.initializer); + + scopes.back().variables.insert(variable_declaration.name); + } + + void operator()(if_block const & if_block) + { + throw invalid_ast_error("if blocks cannot be present in the final AST", if_block.location); + } + + void operator()(else_if_block const & else_if_block) + { + throw invalid_ast_error("else if blocks cannot be present in the final AST", else_if_block.location); + } + + void operator()(else_block const & else_block) + { + throw invalid_ast_error("else blocks cannot be present in the final AST", else_block.location); + } + + template + void operator()(Self & self, if_chain const & if_chain) + { + for (auto const & block : if_chain.blocks) + { + if (block.condition) + apply(*this, *block.condition); + scopes.emplace_back(); + self(*block.statements); + scopes.pop_back(); } } - throw parse_error("Identifier \"" + identifier.name + "\" not found", identifier.location); - } - - void resolve_identifiers(context & context, type & type) - { - return std::visit([&](auto & type){ return resolve_identifiers_impl(context, type); }, type); - } - - void resolve_identifiers_impl(context &, literal const &) - {} - - void resolve_identifiers_impl(context & context, identifier & identifier) - { - bool crossed_function_scope = false; - for (auto it = context.scopes.rbegin(); it != context.scopes.rend(); ++it) + template + void operator()(Self & self, while_block const & while_block) { - if (it->contains(identifier.name, crossed_function_scope && !it->is_global_scope)) + apply(*this, *while_block.condition); + scopes.emplace_back(); + self(*while_block.statements); + scopes.pop_back(); + } + + template + void operator()(Self & self, function_definition const & function_definition) + { + if (scopes.back().contains(function_definition.name)) + throw parse_error("Identifier \"" + function_definition.name + "\" is already defined at this scope", function_definition.location); + + std::unordered_set argument_names; + for (auto const & argument : function_definition.arguments) { - identifier.level = it.base() - context.scopes.begin() - 1; - return; + if (argument_names.count(argument.name) > 0) + throw parse_error("Duplicate argument name \"" + argument.name + "\" in function \"" + function_definition.name + "\"", argument.location); + argument_names.insert(argument.name); + apply(*this, *argument.type); } - crossed_function_scope |= it->is_function_scope; + apply(*this, *function_definition.return_type); + + scopes.back().functions.insert(function_definition.name); + + auto & scope = scopes.emplace_back(); + scope.is_function_scope = true; + scope.variables = std::move(argument_names); + self(*function_definition.statements); + scopes.pop_back(); } - throw parse_error("Identifier \"" + identifier.name + "\" not found", identifier.location); - } - - void resolve_identifiers_impl(context & context, unary_operation const & unary_operation) - { - resolve_identifiers(context, *unary_operation.arg1); - } - - void resolve_identifiers_impl(context & context, binary_operation const & binary_operation) - { - resolve_identifiers(context, *binary_operation.arg1); - resolve_identifiers(context, *binary_operation.arg2); - } - - void resolve_identifiers_impl(context & context, cast_operation const & cast_operation) - { - resolve_identifiers(context, *cast_operation.expression); - resolve_identifiers(context, *cast_operation.type); - } - - void resolve_identifiers_impl(context & context, function_call const & function_call) - { - resolve_identifiers(context, *function_call.function); - for (auto const & argument : function_call.arguments) - resolve_identifiers(context, *argument); - } - - void resolve_identifiers_impl(context & context, array const & array) - { - for (auto const & element : array.elements) - resolve_identifiers(context, *element); - } - - void resolve_identifiers_impl(context & context, array_access const & array_access) - { - resolve_identifiers(context, *array_access.array); - resolve_identifiers(context, *array_access.index); - } - - void resolve_identifiers_impl(context & context, field_access const & field_access) - { - resolve_identifiers(context, *field_access.object); - } - - void resolve_identifiers(context & context, expression & expression) - { - return std::visit([&](auto & expression){ return resolve_identifiers_impl(context, expression); }, expression); - } - - void resolve_identifiers_impl(context & context, expression_ptr const & expression_ptr) - { - resolve_identifiers(context, *expression_ptr); - } - - void resolve_identifiers_impl(context & context, assignment const & assignment) - { - resolve_identifiers(context, *assignment.lhs); - resolve_identifiers(context, *assignment.rhs); - } - - void resolve_identifiers_impl(context & context, variable_declaration const & variable_declaration) - { - if (context.scopes.back().contains(variable_declaration.name)) - throw parse_error("Identifier \"" + variable_declaration.name + "\" is already defined at this scope", variable_declaration.location); - - if (variable_declaration.type) - resolve_identifiers(context, *variable_declaration.type); - resolve_identifiers(context, *variable_declaration.initializer); - - context.scopes.back().variables.insert(variable_declaration.name); - } - - void resolve_identifiers_impl(context &, if_block const & if_block) - { - throw invalid_ast_error("if blocks cannot be present in the final AST", if_block.location); - } - - void resolve_identifiers_impl(context &, else_if_block const & else_if_block) - { - throw invalid_ast_error("else if blocks cannot be present in the final AST", else_if_block.location); - } - - void resolve_identifiers_impl(context &, else_block const & else_block) - { - throw invalid_ast_error("else blocks cannot be present in the final AST", else_block.location); - } - - void resolve_identifiers_impl(context & context, if_chain const & if_chain) - { - for (auto const & block : if_chain.blocks) + void operator()(return_statement const & return_statement) { - if (block.condition) - resolve_identifiers(context, *block.condition); - context.scopes.emplace_back(); - resolve_identifiers(context, *block.statements); - context.scopes.pop_back(); + if (return_statement.value) + apply(*this, *return_statement.value); } - } - void resolve_identifiers_impl(context & context, while_block const & while_block) - { - resolve_identifiers(context, *while_block.condition); - context.scopes.emplace_back(); - resolve_identifiers(context, *while_block.statements); - context.scopes.pop_back(); - } - - void resolve_identifiers_impl(context & context, function_definition const & function_definition) - { - if (context.scopes.back().contains(function_definition.name)) - throw parse_error("Identifier \"" + function_definition.name + "\" is already defined at this scope", function_definition.location); - - std::unordered_set argument_names; - for (auto const & argument : function_definition.arguments) + void operator()(field_definition const & field_definition) { - if (argument_names.count(argument.name) > 0) - throw parse_error("Duplicate argument name \"" + argument.name + "\" in function \"" + function_definition.name + "\"", argument.location); - argument_names.insert(argument.name); - resolve_identifiers(context, *argument.type); + apply(*this, *field_definition.type); } - resolve_identifiers(context, *function_definition.return_type); - - context.scopes.back().functions.insert(function_definition.name); - - auto & scope = context.scopes.emplace_back(); - scope.is_function_scope = true; - scope.variables = std::move(argument_names); - resolve_identifiers(context, *function_definition.statements); - context.scopes.pop_back(); - } - - void resolve_identifiers_impl(context & context, return_statement const & return_statement) - { - if (return_statement.value) - resolve_identifiers(context, *return_statement.value); - } - - void resolve_identifiers_impl(context & context, field_definition const & field_definition) - { - resolve_identifiers(context, *field_definition.type); - } - - void resolve_identifiers_impl(context & context, struct_definition const & struct_definition) - { - if (context.scopes.back().contains(struct_definition.name)) - throw parse_error("Identifier \"" + struct_definition.name + "\" is already defined at this scope", struct_definition.location); - - for (auto const & field : struct_definition.fields) - resolve_identifiers_impl(context, field); - - context.scopes.back().structs.insert(struct_definition.name); - } - - void resolve_identifiers(context & context, statement const & statement) - { - return std::visit([&](auto const & statement){ return resolve_identifiers_impl(context, statement); }, statement); - } - - void resolve_identifiers(context & context, statement_list const & statements) - { - for (auto const & statement : statements.statements) + template + void operator()(Self & self, struct_definition const & struct_definition) { - resolve_identifiers(context, *statement); + if (scopes.back().contains(struct_definition.name)) + throw parse_error("Identifier \"" + struct_definition.name + "\" is already defined at this scope", struct_definition.location); + + for (auto const & field : struct_definition.fields) + self(field); + + scopes.back().structs.insert(struct_definition.name); } - } + + }; } void resolve_identifiers(statement_list_ptr & statements) { - context context; - context.scopes.emplace_back().is_global_scope = true; - resolve_identifiers(context, *statements); + resolve_identifiers_visitor visitor; + visitor.scopes.emplace_back().is_global_scope = true; + apply(visitor, *statements); } } diff --git a/plans.txt b/plans.txt index 1e78ada..9c86592 100644 --- a/plans.txt +++ b/plans.txt @@ -1,5 +1,4 @@ Future plans: -* Rewrite identifier resolution using visitors * Default-constructing (no arguments, zero-initialization) primitive types & struct types * Introduce type checking & inference as a separate pre-execution/compilation AST pass + add type information into expression nodes * Pointers: pointer types, address-of operator (&), dereferencing, scope-based lifetime tracking in interpreter