Simplify tree visitors

This commit is contained in:
Nikita Lisitsa 2025-12-20 13:20:44 +03:00
parent aee506d102
commit 8c07e1950b
7 changed files with 99 additions and 194 deletions

View file

@ -16,14 +16,7 @@ namespace pslang::ast
template <typename Expression> template <typename Expression>
void operator()(Expression const & expression) void operator()(Expression const & expression)
{ {
if constexpr (std::is_invocable_v<Visitor, const_expression_visitor_helper &, Expression const &>) visitor(expression);
{
visitor(*this, expression);
}
else
{
visitor(expression);
}
} }
void operator()(literal const & expression) void operator()(literal const & expression)
@ -45,14 +38,7 @@ namespace pslang::ast
template <typename Expression> template <typename Expression>
void operator()(Expression & expression) void operator()(Expression & expression)
{ {
if constexpr (std::is_invocable_v<Visitor, expression_visitor_helper &, Expression &>) visitor(expression);
{
visitor(*this, expression);
}
else
{
visitor(expression);
}
} }
void operator()(literal & expression) void operator()(literal & expression)

View file

@ -16,14 +16,7 @@ namespace pslang::ast
template <typename Statement> template <typename Statement>
void operator()(Statement const & statement) void operator()(Statement const & statement)
{ {
if constexpr (std::is_invocable_v<Visitor, const_statement_visitor_helper &, Statement const &>) visitor(statement);
{
visitor(*this, statement);
}
else
{
visitor(statement);
}
} }
void operator()(statement const & statement) void operator()(statement const & statement)
@ -46,14 +39,7 @@ namespace pslang::ast
template <typename Statement> template <typename Statement>
void operator()(Statement & statement) void operator()(Statement & statement)
{ {
if constexpr (std::is_invocable_v<Visitor, statement_visitor_helper &, Statement &>) visitor(statement);
{
visitor(*this, statement);
}
else
{
visitor(statement);
}
} }
void operator()(statement & statement) void operator()(statement & statement)

View file

@ -16,14 +16,7 @@ namespace pslang::ast
template <typename Type> template <typename Type>
void operator()(Type const & type) void operator()(Type const & type)
{ {
if constexpr (std::is_invocable_v<Visitor, const_type_visitor_helper &, Type const &>) visitor(type);
{
visitor(*this, type);
}
else
{
visitor(type);
}
} }
void operator()(types::primitive_type const & type) void operator()(types::primitive_type const & type)
@ -45,14 +38,7 @@ namespace pslang::ast
template <typename Type> template <typename Type>
void operator()(Type & type) void operator()(Type & type)
{ {
if constexpr (std::is_invocable_v<Visitor, type_visitor_helper &, Type &>) visitor(type);
{
visitor(*this, type);
}
else
{
visitor(type);
}
} }
void operator()(types::primitive_type & type) void operator()(types::primitive_type & type)

View file

@ -41,21 +41,19 @@ namespace pslang::ast
types::print(out, types::primitive_type{type}); types::print(out, types::primitive_type{type});
} }
template <typename Self> void operator()(array_type const & type)
void operator()(Self & self, array_type const & type)
{ {
self(*type.element_type); apply(*this, *type.element_type);
out << "[" << type.size << "]"; out << "[" << type.size << "]";
} }
template <typename Self> void operator()(function_type const & type)
void operator()(Self & self, function_type const & type)
{ {
if (type.arguments.size() == 1) if (type.arguments.size() == 1)
{ {
self(*type.arguments.front()); apply(*this, *type.arguments.front());
out << " -> "; out << " -> ";
self(*type.result); apply(*this, *type.result);
return; return;
} }
@ -65,10 +63,10 @@ namespace pslang::ast
{ {
if (!first) out << ", "; if (!first) out << ", ";
first = false; first = false;
self(*argument); apply(*this, *argument);
} }
out << ") -> "; out << ") -> ";
self(*type.result); apply(*this, *type.result);
} }
void operator()(type_identifier const & type) void operator()(type_identifier const & type)
@ -82,11 +80,11 @@ namespace pslang::ast
std::ostream & out; std::ostream & out;
print_options options; print_options options;
template <typename Self, typename Node> template <typename Node>
void child(Self & self, Node const & node) void child(Node const & node)
{ {
++options.indent_level; ++options.indent_level;
self(node); apply(*this, node);
--options.indent_level; --options.indent_level;
} }
@ -162,67 +160,60 @@ namespace pslang::ast
out << "identifier { name = \"" << node.name << "\" }\n"; out << "identifier { name = \"" << node.name << "\" }\n";
} }
template <typename Self> void operator()(unary_operation const & node)
void operator()(Self & self, unary_operation const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << node.type << '\n'; out << node.type << '\n';
child(self, *node.arg1); child(*node.arg1);
} }
template <typename Self> void operator()(binary_operation const & node)
void operator()(Self & self, binary_operation const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << node.type << '\n'; out << node.type << '\n';
child(self, *node.arg1); child(*node.arg1);
child(self, *node.arg2); child(*node.arg2);
} }
template <typename Self> void operator()(cast_operation const & node)
void operator()(Self & self, cast_operation const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "cast as "; out << "cast as ";
print(out, *node.type); print(out, *node.type);
out << '\n'; out << '\n';
child(self, *node.expression); child(*node.expression);
} }
template <typename Self> void operator()(function_call const & node)
void operator()(Self & self, function_call const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "call\n"; out << "call\n";
child(self, *node.function); child(*node.function);
for (auto const & argument : node.arguments) for (auto const & argument : node.arguments)
child(self, *argument); child(*argument);
} }
template <typename Self> void operator()(array const & node)
void operator()(Self & self, array const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "array\n"; out << "array\n";
for (auto const & element : node.elements) for (auto const & element : node.elements)
child(self, *element); child(*element);
} }
template <typename Self> void operator()(array_access const & node)
void operator()(Self & self, array_access const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "array access\n"; out << "array access\n";
child(self, *node.array); child(*node.array);
child(self, *node.index); child(*node.index);
} }
template <typename Self> void operator()(field_access const & node)
void operator()(Self & self, field_access const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "field access { name = \"" << node.field_name << "\" }\n"; out << "field access { name = \"" << node.field_name << "\" }\n";
child(self, *node.object); child(*node.object);
} }
}; };
@ -231,11 +222,11 @@ namespace pslang::ast
std::ostream & out; std::ostream & out;
print_options options; print_options options;
template <typename Self, typename Node> template <typename Node>
void child(Self & self, Node const & node) void child(Node const & node)
{ {
++options.indent_level; ++options.indent_level;
self(node); ast::apply(*this, node);
--options.indent_level; --options.indent_level;
} }
@ -244,17 +235,15 @@ namespace pslang::ast
print(out, *node, options); print(out, *node, options);
} }
template <typename Self> void operator()(assignment const & node)
void operator()(Self & self, assignment const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "assignment\n"; out << "assignment\n";
child(self, node.lhs); child(node.lhs);
child(self, node.rhs); child(node.rhs);
} }
template <typename Self> void operator()(variable_declaration const & node)
void operator()(Self & self, variable_declaration const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "variable declaration { category = " << node.category << ", name = \"" << node.name << "\""; out << "variable declaration { category = " << node.category << ", name = \"" << node.name << "\"";
@ -264,15 +253,14 @@ namespace pslang::ast
print(out, *node.type); print(out, *node.type);
} }
out << " }\n"; out << " }\n";
child(self, node.initializer); child(node.initializer);
} }
template <typename Self> void operator()(if_block const & node)
void operator()(Self & self, if_block const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "if\n"; out << "if\n";
child(self, node.condition); child(node.condition);
} }
void operator()(else_block const & node) void operator()(else_block const & node)
@ -281,16 +269,14 @@ namespace pslang::ast
out << "else\n"; out << "else\n";
} }
template <typename Self> void operator()(else_if_block const & node)
void operator()(Self & self, else_if_block const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "else if\n"; out << "else if\n";
child(self, node.condition); child(node.condition);
} }
template <typename Self> void operator()(if_chain const & node)
void operator()(Self & self, if_chain const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "if chain\n"; out << "if chain\n";
@ -300,7 +286,7 @@ namespace pslang::ast
put_indent(out, options); put_indent(out, options);
out << "condition\n"; out << "condition\n";
if (block.condition) if (block.condition)
child(self, block.condition); child(block.condition);
else else
{ {
put_indent(out, as_child(options)); put_indent(out, as_child(options));
@ -309,22 +295,20 @@ namespace pslang::ast
put_indent(out, options); put_indent(out, options);
out << "body\n"; out << "body\n";
child(self, *block.statements); child(*block.statements);
} }
--options.indent_level; --options.indent_level;
} }
template <typename Self> void operator()(while_block const & node)
void operator()(Self & self, while_block const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "while\n"; out << "while\n";
child(self, node.condition); child(node.condition);
child(self, *node.statements); child(*node.statements);
} }
template <typename Self> void operator()(function_definition const & node)
void operator()(Self & self, function_definition const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "function { name = \"" << node.name << "\", return type = "; out << "function { name = \"" << node.name << "\", return type = ";
@ -340,16 +324,15 @@ namespace pslang::ast
} }
put_indent(out, options); put_indent(out, options);
out << "body\n"; out << "body\n";
child(self, *node.statements); child(*node.statements);
} }
template <typename Self> void operator()(return_statement const & node)
void operator()(Self & self, return_statement const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "return\n"; out << "return\n";
if (node.value) if (node.value)
child(self, node.value); child(node.value);
} }
void operator()(field_definition const & node) void operator()(field_definition const & node)
@ -360,13 +343,12 @@ namespace pslang::ast
out << " }\n"; out << " }\n";
} }
template <typename Self> void operator()(struct_definition const & node)
void operator()(Self & self, struct_definition const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "struct { name = \"" << node.name << "\" }\n"; out << "struct { name = \"" << node.name << "\" }\n";
for (auto const & field : node.fields) for (auto const & field : node.fields)
child(self, field); child(field);
} }
}; };

View file

@ -55,11 +55,6 @@ namespace pslang::ast
bool is_global_scope = false; bool is_global_scope = false;
}; };
struct context
{
std::vector<scope> scopes;
};
struct resolve_identifiers_visitor struct resolve_identifiers_visitor
{ {
std::vector<scope> scopes; std::vector<scope> scopes;
@ -70,18 +65,16 @@ namespace pslang::ast
void operator()(types::primitive_type const &) void operator()(types::primitive_type const &)
{} {}
template <typename Self> void operator()(array_type const & array_type)
void operator()(Self & self, array_type const & array_type)
{ {
self(*array_type.element_type); apply(*this, *array_type.element_type);
} }
template <typename Self> void operator()(function_type const & function_type)
void operator()(Self & self, function_type const & function_type)
{ {
for (auto const & argument : function_type.arguments) for (auto const & argument : function_type.arguments)
self(*argument); apply(*this, *argument);
self(*function_type.result); apply(*this, *function_type.result);
} }
void operator()(type_identifier & identifier) void operator()(type_identifier & identifier)
@ -118,52 +111,45 @@ namespace pslang::ast
throw parse_error("Identifier \"" + identifier.name + "\" not found", identifier.location); throw parse_error("Identifier \"" + identifier.name + "\" not found", identifier.location);
} }
template <typename Self> void operator()(unary_operation const & unary_operation)
void operator()(Self & self, unary_operation const & unary_operation)
{ {
self(*unary_operation.arg1); apply(*this, *unary_operation.arg1);
} }
template <typename Self> void operator()(binary_operation const & binary_operation)
void operator()(Self & self, binary_operation const & binary_operation)
{ {
self(*binary_operation.arg1); apply(*this, *binary_operation.arg1);
self(*binary_operation.arg2); apply(*this, *binary_operation.arg2);
} }
template <typename Self> void operator()(cast_operation const & cast_operation)
void operator()(Self & self, cast_operation const & cast_operation)
{ {
self(*cast_operation.expression); apply(*this, *cast_operation.expression);
apply(*this, *cast_operation.type); apply(*this, *cast_operation.type);
} }
template <typename Self> void operator()(function_call const & function_call)
void operator()(Self & self, function_call const & function_call)
{ {
self(*function_call.function); apply(*this, *function_call.function);
for (auto const & argument : function_call.arguments) for (auto const & argument : function_call.arguments)
self(*argument); apply(*this, *argument);
} }
template <typename Self> void operator()(array const & array)
void operator()(Self & self, array const & array)
{ {
for (auto const & element : array.elements) for (auto const & element : array.elements)
self(*element); apply(*this, *element);
} }
template <typename Self> void operator()(array_access const & array_access)
void operator()(Self & self, array_access const & array_access)
{ {
self(*array_access.array); apply(*this, *array_access.array);
self(*array_access.index); apply(*this, *array_access.index);
} }
template <typename Self> void operator()(field_access const & field_access)
void operator()(Self & self, field_access const & field_access)
{ {
self(*field_access.object); apply(*this, *field_access.object);
} }
void operator()(expression_ptr const & expression_ptr) void operator()(expression_ptr const & expression_ptr)
@ -171,11 +157,10 @@ namespace pslang::ast
apply(*this, *expression_ptr); apply(*this, *expression_ptr);
} }
template <typename Self> void operator()(assignment const & assignment)
void operator()(Self & self, assignment const & assignment)
{ {
self(assignment.lhs); ast::apply(*this, assignment.lhs);
self(assignment.rhs); ast::apply(*this, assignment.rhs);
} }
void operator()(variable_declaration const & variable_declaration) void operator()(variable_declaration const & variable_declaration)
@ -205,30 +190,27 @@ namespace pslang::ast
throw invalid_ast_error("else blocks cannot be present in the final AST", else_block.location); throw invalid_ast_error("else blocks cannot be present in the final AST", else_block.location);
} }
template <typename Self> void operator()(if_chain const & if_chain)
void operator()(Self & self, if_chain const & if_chain)
{ {
for (auto const & block : if_chain.blocks) for (auto const & block : if_chain.blocks)
{ {
if (block.condition) if (block.condition)
apply(*this, *block.condition); apply(*this, *block.condition);
scopes.emplace_back(); scopes.emplace_back();
self(*block.statements); apply(*this, *block.statements);
scopes.pop_back(); scopes.pop_back();
} }
} }
template <typename Self> void operator()(while_block const & while_block)
void operator()(Self & self, while_block const & while_block)
{ {
apply(*this, *while_block.condition); apply(*this, *while_block.condition);
scopes.emplace_back(); scopes.emplace_back();
self(*while_block.statements); apply(*this, *while_block.statements);
scopes.pop_back(); scopes.pop_back();
} }
template <typename Self> void operator()(function_definition const & function_definition)
void operator()(Self & self, function_definition const & function_definition)
{ {
if (scopes.back().contains(function_definition.name)) if (scopes.back().contains(function_definition.name))
throw parse_error("Identifier \"" + function_definition.name + "\" is already defined at this scope", function_definition.location); throw parse_error("Identifier \"" + function_definition.name + "\" is already defined at this scope", function_definition.location);
@ -249,7 +231,7 @@ namespace pslang::ast
auto & scope = scopes.emplace_back(); auto & scope = scopes.emplace_back();
scope.is_function_scope = true; scope.is_function_scope = true;
scope.variables = std::move(argument_names); scope.variables = std::move(argument_names);
self(*function_definition.statements); apply(*this, *function_definition.statements);
scopes.pop_back(); scopes.pop_back();
} }
@ -264,14 +246,13 @@ namespace pslang::ast
apply(*this, *field_definition.type); apply(*this, *field_definition.type);
} }
template <typename Self> void operator()(struct_definition const & struct_definition)
void operator()(Self & self, struct_definition const & struct_definition)
{ {
if (scopes.back().contains(struct_definition.name)) if (scopes.back().contains(struct_definition.name))
throw parse_error("Identifier \"" + struct_definition.name + "\" is already defined at this scope", struct_definition.location); throw parse_error("Identifier \"" + struct_definition.name + "\" is already defined at this scope", struct_definition.location);
for (auto const & field : struct_definition.fields) for (auto const & field : struct_definition.fields)
self(field); apply(*this, field);
scopes.back().structs.insert(struct_definition.name); scopes.back().structs.insert(struct_definition.name);
} }

View file

@ -16,14 +16,7 @@ namespace pslang::types
template <typename Type> template <typename Type>
void operator()(Type const & type) void operator()(Type const & type)
{ {
if constexpr (std::is_invocable_v<Visitor, const_visitor_helper &, Type const &>) visitor(type);
{
visitor(*this, type);
}
else
{
visitor(type);
}
} }
void operator()(types::primitive_type const & type) void operator()(types::primitive_type const & type)
@ -45,14 +38,7 @@ namespace pslang::types
template <typename Type> template <typename Type>
void operator()(Type & type) void operator()(Type & type)
{ {
if constexpr (std::is_invocable_v<Visitor, visitor_helper &, Type &>) visitor(type);
{
visitor(*this, type);
}
else
{
visitor(type);
}
} }
void operator()(types::primitive_type & type) void operator()(types::primitive_type & type)

View file

@ -71,21 +71,19 @@ namespace pslang::types
out << "f64"; out << "f64";
} }
template <typename Self> void operator()(array_type const & type)
void operator()(Self & self, array_type const & type)
{ {
self(*type.element_type); apply(*this, *type.element_type);
out << "[" << type.size << "]"; out << "[" << type.size << "]";
} }
template <typename Self> void operator()(function_type const & type)
void operator()(Self & self, function_type const & type)
{ {
if (type.arguments.size() == 1) if (type.arguments.size() == 1)
{ {
self(*type.arguments.front()); apply(*this, *type.arguments.front());
out << " -> "; out << " -> ";
self(*type.result); apply(*this, *type.result);
return; return;
} }
@ -95,10 +93,10 @@ namespace pslang::types
{ {
if (!first) out << ", "; if (!first) out << ", ";
first = false; first = false;
self(*argument); apply(*this, *argument);
} }
out << ") -> "; out << ") -> ";
self(*type.result); apply(*this, *type.result);
} }
void operator()(named_type const & type) void operator()(named_type const & type)