Refactor tree visitors (again)

This commit is contained in:
Nikita Lisitsa 2025-12-20 15:52:59 +03:00
parent 1f8c73614c
commit d33b00e368
6 changed files with 247 additions and 280 deletions

View file

@ -5,67 +5,52 @@
namespace pslang::ast
{
namespace detail
template <typename Derived>
struct const_expression_visitor
{
template <typename Visitor>
struct const_expression_visitor_helper
Derived & derived()
{
Visitor & visitor;
return static_cast<Derived &>(*this);
}
template <typename Expression>
void operator()(Expression const & expression)
{
visitor(expression);
}
void operator()(literal const & expression)
{
std::visit(*this, expression);
}
void operator()(expression const & expression)
{
std::visit(*this, expression);
}
};
template <typename Visitor>
struct expression_visitor_helper
auto make_visitor()
{
Visitor & visitor;
return [this](auto const & type){ return derived().apply(type); };
}
template <typename Expression>
void operator()(Expression & expression)
{
visitor(expression);
}
auto apply(literal const & expression)
{
return std::visit(make_visitor(), expression);
}
void operator()(literal & expression)
{
std::visit(*this, expression);
}
auto apply(expression const & expression)
{
return std::visit(make_visitor(), expression);
}
};
void operator()(expression & expression)
{
std::visit(*this, expression);
}
};
}
template <typename Visitor>
void apply(Visitor && visitor, expression const & expression)
template <typename Derived>
struct expression_visitor
{
detail::const_expression_visitor_helper<Visitor> helper{visitor};
helper(expression);
}
Derived & derived()
{
return static_cast<Derived &>(*this);
}
template <typename Visitor>
void apply(Visitor && visitor, expression & expression)
{
detail::expression_visitor_helper<Visitor> helper{visitor};
helper(expression);
}
auto make_visitor()
{
return [this](auto & type){ return derived().apply(type); };
}
auto apply(literal & expression)
{
return std::visit(make_visitor(), expression);
}
auto apply(expression & expression)
{
return std::visit(make_visitor(), expression);
}
};
}

View file

@ -5,83 +5,54 @@
namespace pslang::ast
{
namespace detail
template <typename Derived>
struct const_statement_visitor
{
template <typename Visitor>
struct const_statement_visitor_helper
Derived & derived()
{
Visitor & visitor;
return static_cast<Derived &>(*this);
}
template <typename Statement>
void operator()(Statement const & statement)
{
visitor(statement);
}
void operator()(statement const & statement)
{
std::visit(*this, statement);
}
void operator()(statement_list const & statement_list)
{
for (auto const & statement : statement_list.statements)
operator()(*statement);
}
};
template <typename Visitor>
struct statement_visitor_helper
auto make_visitor()
{
Visitor & visitor;
return [this](auto const & type){ return derived().apply(type); };
}
template <typename Statement>
void operator()(Statement & statement)
{
visitor(statement);
}
auto apply(statement const & statement)
{
return std::visit(make_visitor(), statement);
}
void operator()(statement & statement)
{
std::visit(*this, statement);
}
void apply(statement_list const & statement_list)
{
for (auto const & statement : statement_list.statements)
derived().apply(*statement);
}
};
void operator()(statement_list & statement_list)
{
for (auto & statement : statement_list.statements)
operator()(*statement);
}
};
}
template <typename Visitor>
void apply(Visitor && visitor, statement const & statement)
template <typename Derived>
struct statement_visitor
{
detail::const_statement_visitor_helper<Visitor> helper{visitor};
helper(statement);
}
Derived & derived()
{
return static_cast<Derived &>(*this);
}
template <typename Visitor>
void apply(Visitor && visitor, statement & statement)
{
detail::statement_visitor_helper<Visitor> helper{visitor};
helper(statement);
}
auto make_visitor()
{
return [this](auto & type){ return derived().apply(type); };
}
template <typename Visitor>
void apply(Visitor && visitor, statement_list const & statement_list)
{
detail::const_statement_visitor_helper<Visitor> helper{visitor};
helper(statement_list);
}
auto apply(statement & statement)
{
return std::visit(make_visitor(), statement);
}
template <typename Visitor>
void apply(Visitor && visitor, statement_list & statement_list)
{
detail::statement_visitor_helper<Visitor> helper{visitor};
helper(statement_list);
}
void apply(statement_list & statement_list)
{
for (auto & statement : statement_list.statements)
derived().apply(*statement);
}
};
}

View file

@ -5,67 +5,52 @@
namespace pslang::ast
{
namespace detail
template <typename Derived>
struct const_type_visitor
{
template <typename Visitor>
struct const_type_visitor_helper
Derived & derived()
{
Visitor & visitor;
return static_cast<Derived &>(*this);
}
template <typename Type>
void operator()(Type const & type)
{
visitor(type);
}
void operator()(types::primitive_type const & type)
{
std::visit(*this, type);
}
void operator()(ast::type const & type)
{
std::visit(*this, type);
}
};
template <typename Visitor>
struct type_visitor_helper
auto make_visitor()
{
Visitor & visitor;
return [this](auto const & type){ return derived().apply(type); };
}
template <typename Type>
void operator()(Type & type)
{
visitor(type);
}
auto apply(types::primitive_type const & type)
{
return std::visit(make_visitor(), type);
}
void operator()(types::primitive_type & type)
{
std::visit(*this, type);
}
auto apply(ast::type const & type)
{
return std::visit(make_visitor(), type);
}
};
void operator()(ast::type & type)
{
std::visit(*this, type);
}
};
}
template <typename Visitor>
void apply(Visitor && visitor, type const & type)
template <typename Derived>
struct type_visitor
{
detail::const_type_visitor_helper<Visitor> helper{visitor};
helper(type);
}
Derived & derived()
{
return static_cast<Derived &>(*this);
}
template <typename Visitor>
void apply(Visitor && visitor, type & type)
{
detail::type_visitor_helper<Visitor> helper{visitor};
helper(type);
}
auto make_visitor()
{
return [this](auto & type){ return derived().apply(type); };
}
auto apply(types::primitive_type & type)
{
return std::visit(make_visitor(), type);
}
auto apply(ast::type & type)
{
return std::visit(make_visitor(), type);
}
};
}

View file

@ -27,33 +27,36 @@ namespace pslang::ast
}
struct type_print_visitor
: const_type_visitor<type_print_visitor>
{
std::ostream & out;
void operator()(types::unit_type const &)
using const_type_visitor::apply;
void apply(types::unit_type const &)
{
out << "unit";
}
template <typename T>
void operator()(types::primitive_type_base<T> const & type)
void apply(types::primitive_type_base<T> const & type)
{
types::print(out, types::primitive_type{type});
}
void operator()(array_type const & type)
void apply(array_type const & type)
{
apply(*this, *type.element_type);
apply(*type.element_type);
out << "[" << type.size << "]";
}
void operator()(function_type const & type)
void apply(function_type const & type)
{
if (type.arguments.size() == 1)
{
apply(*this, *type.arguments.front());
apply(*type.arguments.front());
out << " -> ";
apply(*this, *type.result);
apply(*type.result);
return;
}
@ -63,111 +66,114 @@ namespace pslang::ast
{
if (!first) out << ", ";
first = false;
apply(*this, *argument);
apply(*argument);
}
out << ") -> ";
apply(*this, *type.result);
apply(*type.result);
}
void operator()(type_identifier const & type)
void apply(type_identifier const & type)
{
out << type.name;
}
};
struct expression_print_visitor
: const_expression_visitor<expression_print_visitor>
{
std::ostream & out;
print_options options;
using const_expression_visitor::apply;
template <typename Node>
void child(Node const & node)
{
++options.indent_level;
apply(*this, node);
apply(node);
--options.indent_level;
}
void operator()(bool_literal const & node)
void apply(bool_literal const & node)
{
put_indent(out, options);
out << "bool literal { value = " << (node.value ? "true" : "false") << " }\n";
}
void operator()(i8_literal const & node)
void apply(i8_literal const & node)
{
put_indent(out, options);
out << "i8 literal { value = " << (std::int32_t)node.value << " }\n";
}
void operator()(u8_literal const & node)
void apply(u8_literal const & node)
{
put_indent(out, options);
out << "u8 literal { value = " << (std::uint32_t)node.value << " }\n";
}
void operator()(i16_literal const & node)
void apply(i16_literal const & node)
{
put_indent(out, options);
out << "i16 literal { value = " << node.value << " }\n";
}
void operator()(u16_literal const & node)
void apply(u16_literal const & node)
{
put_indent(out, options);
out << "u16 literal { value = " << node.value << " }\n";
}
void operator()(i32_literal const & node)
void apply(i32_literal const & node)
{
put_indent(out, options);
out << "i32 literal { value = " << node.value << " }\n";
}
void operator()(u32_literal const & node)
void apply(u32_literal const & node)
{
put_indent(out, options);
out << "u32 literal { value = " << node.value << " }\n";
}
void operator()(i64_literal const & node)
void apply(i64_literal const & node)
{
put_indent(out, options);
out << "i64 literal { value = " << node.value << " }\n";
}
void operator()(u64_literal const & node)
void apply(u64_literal const & node)
{
put_indent(out, options);
out << "u64 literal { value = " << node.value << " }\n";
}
void operator()(f32_literal const & node)
void apply(f32_literal const & node)
{
put_indent(out, options);
out << "f32 literal { value = " << std::setprecision(7) << node.value << " }\n";
}
void operator()(f64_literal const & node)
void apply(f64_literal const & node)
{
put_indent(out, options);
out << "f64 literal { value = " << std::setprecision(15) << node.value << " }\n";
}
void operator()(identifier const & node)
void apply(identifier const & node)
{
put_indent(out, options);
out << "identifier { name = \"" << node.name << "\" }\n";
}
void operator()(unary_operation const & node)
void apply(unary_operation const & node)
{
put_indent(out, options);
out << node.type << '\n';
child(*node.arg1);
}
void operator()(binary_operation const & node)
void apply(binary_operation const & node)
{
put_indent(out, options);
out << node.type << '\n';
@ -175,7 +181,7 @@ namespace pslang::ast
child(*node.arg2);
}
void operator()(cast_operation const & node)
void apply(cast_operation const & node)
{
put_indent(out, options);
out << "cast as ";
@ -184,7 +190,7 @@ namespace pslang::ast
child(*node.expression);
}
void operator()(function_call const & node)
void apply(function_call const & node)
{
put_indent(out, options);
out << "call\n";
@ -193,7 +199,7 @@ namespace pslang::ast
child(*argument);
}
void operator()(array const & node)
void apply(array const & node)
{
put_indent(out, options);
out << "array\n";
@ -201,7 +207,7 @@ namespace pslang::ast
child(*element);
}
void operator()(array_access const & node)
void apply(array_access const & node)
{
put_indent(out, options);
out << "array access\n";
@ -209,7 +215,7 @@ namespace pslang::ast
child(*node.index);
}
void operator()(field_access const & node)
void apply(field_access const & node)
{
put_indent(out, options);
out << "field access { name = \"" << node.field_name << "\" }\n";
@ -218,24 +224,27 @@ namespace pslang::ast
};
struct statement_print_visitor
: const_statement_visitor<statement_print_visitor>
{
std::ostream & out;
print_options options;
using const_statement_visitor::apply;
template <typename Node>
void child(Node const & node)
{
++options.indent_level;
ast::apply(*this, node);
apply(node);
--options.indent_level;
}
void operator()(expression_ptr const & node)
void apply(expression_ptr const & node)
{
print(out, *node, options);
}
void operator()(assignment const & node)
void apply(assignment const & node)
{
put_indent(out, options);
out << "assignment\n";
@ -243,7 +252,7 @@ namespace pslang::ast
child(node.rhs);
}
void operator()(variable_declaration const & node)
void apply(variable_declaration const & node)
{
put_indent(out, options);
out << "variable declaration { category = " << node.category << ", name = \"" << node.name << "\"";
@ -256,27 +265,27 @@ namespace pslang::ast
child(node.initializer);
}
void operator()(if_block const & node)
void apply(if_block const & node)
{
put_indent(out, options);
out << "if\n";
child(node.condition);
}
void operator()(else_block const & node)
void apply(else_block const & node)
{
put_indent(out, options);
out << "else\n";
}
void operator()(else_if_block const & node)
void apply(else_if_block const & node)
{
put_indent(out, options);
out << "else if\n";
child(node.condition);
}
void operator()(if_chain const & node)
void apply(if_chain const & node)
{
put_indent(out, options);
out << "if chain\n";
@ -300,7 +309,7 @@ namespace pslang::ast
--options.indent_level;
}
void operator()(while_block const & node)
void apply(while_block const & node)
{
put_indent(out, options);
out << "while\n";
@ -308,7 +317,7 @@ namespace pslang::ast
child(*node.statements);
}
void operator()(function_definition const & node)
void apply(function_definition const & node)
{
put_indent(out, options);
out << "function { name = \"" << node.name << "\", return type = ";
@ -327,7 +336,7 @@ namespace pslang::ast
child(*node.statements);
}
void operator()(return_statement const & node)
void apply(return_statement const & node)
{
put_indent(out, options);
out << "return\n";
@ -335,7 +344,7 @@ namespace pslang::ast
child(node.value);
}
void operator()(field_definition const & node)
void apply(field_definition const & node)
{
put_indent(out, options);
out << "field { name = \"" << node.name << "\", type = ";
@ -343,7 +352,7 @@ namespace pslang::ast
out << " }\n";
}
void operator()(struct_definition const & node)
void apply(struct_definition const & node)
{
put_indent(out, options);
out << "struct { name = \"" << node.name << "\" }\n";
@ -356,22 +365,22 @@ namespace pslang::ast
void print(std::ostream & out, type const & node)
{
apply(type_print_visitor{out}, node);
type_print_visitor{{}, out}.apply(node);
}
void print(std::ostream & out, expression const & node, print_options const & options)
{
apply(expression_print_visitor{out, options}, node);
expression_print_visitor{{}, out, options}.apply(node);
}
void print(std::ostream & out, statement const & node, print_options const & options)
{
apply(statement_print_visitor{out, options}, node);
statement_print_visitor{{}, out, options}.apply(node);
}
void print(std::ostream & out, statement_list const & node, print_options const & options)
{
apply(statement_print_visitor{out, options}, node);
statement_print_visitor{{}, out, options}.apply(node);
}
}

View file

@ -56,28 +56,35 @@ namespace pslang::ast
};
struct resolve_identifiers_visitor
: type_visitor<resolve_identifiers_visitor>
, expression_visitor<resolve_identifiers_visitor>
, statement_visitor<resolve_identifiers_visitor>
{
std::vector<scope> scopes;
void operator()(types::unit_type const &)
using type_visitor::apply;
using expression_visitor::apply;
using statement_visitor::apply;
void apply(types::unit_type const &)
{}
void operator()(types::primitive_type const &)
void apply(types::primitive_type const &)
{}
void operator()(array_type const & array_type)
void apply(array_type const & array_type)
{
apply(*this, *array_type.element_type);
apply(*array_type.element_type);
}
void operator()(function_type const & function_type)
void apply(function_type const & function_type)
{
for (auto const & argument : function_type.arguments)
apply(*this, *argument);
apply(*this, *function_type.result);
apply(*argument);
apply(*function_type.result);
}
void operator()(type_identifier & identifier)
void apply(type_identifier & identifier)
{
for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
{
@ -91,10 +98,10 @@ namespace pslang::ast
throw parse_error("Identifier \"" + identifier.name + "\" not found", identifier.location);
}
void operator()(literal const &)
void apply(literal const &)
{}
void operator()(identifier & identifier)
void apply(identifier & identifier)
{
bool crossed_function_scope = false;
for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
@ -111,106 +118,106 @@ namespace pslang::ast
throw parse_error("Identifier \"" + identifier.name + "\" not found", identifier.location);
}
void operator()(unary_operation const & unary_operation)
void apply(unary_operation const & unary_operation)
{
apply(*this, *unary_operation.arg1);
apply(*unary_operation.arg1);
}
void operator()(binary_operation const & binary_operation)
void apply(binary_operation const & binary_operation)
{
apply(*this, *binary_operation.arg1);
apply(*this, *binary_operation.arg2);
apply(*binary_operation.arg1);
apply(*binary_operation.arg2);
}
void operator()(cast_operation const & cast_operation)
void apply(cast_operation const & cast_operation)
{
apply(*this, *cast_operation.expression);
apply(*this, *cast_operation.type);
apply(*cast_operation.expression);
apply(*cast_operation.type);
}
void operator()(function_call const & function_call)
void apply(function_call const & function_call)
{
apply(*this, *function_call.function);
apply(*function_call.function);
for (auto const & argument : function_call.arguments)
apply(*this, *argument);
apply(*argument);
}
void operator()(array const & array)
void apply(array const & array)
{
for (auto const & element : array.elements)
apply(*this, *element);
apply(*element);
}
void operator()(array_access const & array_access)
void apply(array_access const & array_access)
{
apply(*this, *array_access.array);
apply(*this, *array_access.index);
apply(*array_access.array);
apply(*array_access.index);
}
void operator()(field_access const & field_access)
void apply(field_access const & field_access)
{
apply(*this, *field_access.object);
apply(*field_access.object);
}
void operator()(expression_ptr const & expression_ptr)
void apply(expression_ptr const & expression_ptr)
{
apply(*this, *expression_ptr);
apply(*expression_ptr);
}
void operator()(assignment const & assignment)
void apply(assignment const & assignment)
{
ast::apply(*this, assignment.lhs);
ast::apply(*this, assignment.rhs);
apply(assignment.lhs);
apply(assignment.rhs);
}
void operator()(variable_declaration const & variable_declaration)
void apply(variable_declaration const & variable_declaration)
{
if (scopes.back().contains(variable_declaration.name))
throw parse_error("Identifier \"" + variable_declaration.name + "\" is already defined at this scope", variable_declaration.location);
if (variable_declaration.type)
apply(*this, *variable_declaration.type);
apply(*this, *variable_declaration.initializer);
apply(*variable_declaration.type);
apply(*variable_declaration.initializer);
scopes.back().variables.insert(variable_declaration.name);
}
void operator()(if_block const & if_block)
void apply(if_block const & if_block)
{
throw invalid_ast_error("if blocks cannot be present in the final AST", if_block.location);
}
void operator()(else_if_block const & else_if_block)
void apply(else_if_block const & else_if_block)
{
throw invalid_ast_error("else if blocks cannot be present in the final AST", else_if_block.location);
}
void operator()(else_block const & else_block)
void apply(else_block const & else_block)
{
throw invalid_ast_error("else blocks cannot be present in the final AST", else_block.location);
}
void operator()(if_chain const & if_chain)
void apply(if_chain const & if_chain)
{
for (auto const & block : if_chain.blocks)
{
if (block.condition)
apply(*this, *block.condition);
apply(*block.condition);
scopes.emplace_back();
apply(*this, *block.statements);
apply(*block.statements);
scopes.pop_back();
}
}
void operator()(while_block const & while_block)
void apply(while_block const & while_block)
{
apply(*this, *while_block.condition);
apply(*while_block.condition);
scopes.emplace_back();
apply(*this, *while_block.statements);
apply(*while_block.statements);
scopes.pop_back();
}
void operator()(function_definition const & function_definition)
void apply(function_definition const & function_definition)
{
if (scopes.back().contains(function_definition.name))
throw parse_error("Identifier \"" + function_definition.name + "\" is already defined at this scope", function_definition.location);
@ -221,38 +228,38 @@ namespace pslang::ast
if (argument_names.count(argument.name) > 0)
throw parse_error("Duplicate argument name \"" + argument.name + "\" in function \"" + function_definition.name + "\"", argument.location);
argument_names.insert(argument.name);
apply(*this, *argument.type);
apply(*argument.type);
}
apply(*this, *function_definition.return_type);
apply(*function_definition.return_type);
scopes.back().functions.insert(function_definition.name);
auto & scope = scopes.emplace_back();
scope.is_function_scope = true;
scope.variables = std::move(argument_names);
apply(*this, *function_definition.statements);
apply(*function_definition.statements);
scopes.pop_back();
}
void operator()(return_statement const & return_statement)
void apply(return_statement const & return_statement)
{
if (return_statement.value)
apply(*this, *return_statement.value);
apply(*return_statement.value);
}
void operator()(field_definition const & field_definition)
void apply(field_definition const & field_definition)
{
apply(*this, *field_definition.type);
apply(*field_definition.type);
}
void operator()(struct_definition const & struct_definition)
void apply(struct_definition const & struct_definition)
{
if (scopes.back().contains(struct_definition.name))
throw parse_error("Identifier \"" + struct_definition.name + "\" is already defined at this scope", struct_definition.location);
for (auto const & field : struct_definition.fields)
apply(*this, field);
apply(field);
scopes.back().structs.insert(struct_definition.name);
}
@ -265,7 +272,7 @@ namespace pslang::ast
{
resolve_identifiers_visitor visitor;
visitor.scopes.emplace_back().is_global_scope = true;
apply(visitor, *statements);
visitor.apply(*statements);
}
}

View file

@ -13,14 +13,19 @@ namespace pslang::types
return static_cast<Derived &>(*this);
}
auto make_visitor()
{
return [this](auto const & type){ return derived().apply(type); };
}
auto apply(types::primitive_type const & type)
{
return std::visit([this](auto const & type){ return derived().apply(type); }, type);
return std::visit(make_visitor(), type);
}
auto apply(types::type const & type)
{
return std::visit([this](auto const & type){ return derived().apply(type); }, type);
return std::visit(make_visitor(), type);
}
};
@ -32,14 +37,19 @@ namespace pslang::types
return static_cast<Derived &>(*this);
}
auto make_visitor()
{
return [this](auto & type){ return derived().apply(type); };
}
auto apply(types::primitive_type & type)
{
return std::visit([this](auto & type){ return derived().apply(type); }, type);
return std::visit(make_visitor(), type);
}
auto apply(types::type & type)
{
return std::visit([this](auto & type){ return derived().apply(type); }, type);
return std::visit(make_visitor(), type);
}
};