Refactor tree visitors (again)

This commit is contained in:
Nikita Lisitsa 2025-12-20 15:52:59 +03:00
parent 1f8c73614c
commit d33b00e368
6 changed files with 247 additions and 280 deletions

View file

@ -5,67 +5,52 @@
namespace pslang::ast namespace pslang::ast
{ {
namespace detail template <typename Derived>
struct const_expression_visitor
{ {
Derived & derived()
template <typename Visitor>
struct const_expression_visitor_helper
{ {
Visitor & visitor; return static_cast<Derived &>(*this);
}
template <typename Expression> auto make_visitor()
void operator()(Expression const & expression)
{
visitor(expression);
}
void operator()(literal const & expression)
{
std::visit(*this, expression);
}
void operator()(expression const & expression)
{
std::visit(*this, expression);
}
};
template <typename Visitor>
struct expression_visitor_helper
{ {
Visitor & visitor; return [this](auto const & type){ return derived().apply(type); };
}
template <typename Expression> auto apply(literal const & expression)
void operator()(Expression & expression) {
{ return std::visit(make_visitor(), expression);
visitor(expression); }
}
void operator()(literal & expression) auto apply(expression const & expression)
{ {
std::visit(*this, expression); return std::visit(make_visitor(), expression);
} }
};
void operator()(expression & expression) template <typename Derived>
{ struct expression_visitor
std::visit(*this, expression);
}
};
}
template <typename Visitor>
void apply(Visitor && visitor, expression const & expression)
{ {
detail::const_expression_visitor_helper<Visitor> helper{visitor}; Derived & derived()
helper(expression); {
} return static_cast<Derived &>(*this);
}
template <typename Visitor> auto make_visitor()
void apply(Visitor && visitor, expression & expression) {
{ return [this](auto & type){ return derived().apply(type); };
detail::expression_visitor_helper<Visitor> helper{visitor}; }
helper(expression);
} auto apply(literal & expression)
{
return std::visit(make_visitor(), expression);
}
auto apply(expression & expression)
{
return std::visit(make_visitor(), expression);
}
};
} }

View file

@ -5,83 +5,54 @@
namespace pslang::ast namespace pslang::ast
{ {
namespace detail template <typename Derived>
struct const_statement_visitor
{ {
Derived & derived()
template <typename Visitor>
struct const_statement_visitor_helper
{ {
Visitor & visitor; return static_cast<Derived &>(*this);
}
template <typename Statement> auto make_visitor()
void operator()(Statement const & statement)
{
visitor(statement);
}
void operator()(statement const & statement)
{
std::visit(*this, statement);
}
void operator()(statement_list const & statement_list)
{
for (auto const & statement : statement_list.statements)
operator()(*statement);
}
};
template <typename Visitor>
struct statement_visitor_helper
{ {
Visitor & visitor; return [this](auto const & type){ return derived().apply(type); };
}
template <typename Statement> auto apply(statement const & statement)
void operator()(Statement & statement) {
{ return std::visit(make_visitor(), statement);
visitor(statement); }
}
void operator()(statement & statement) void apply(statement_list const & statement_list)
{ {
std::visit(*this, statement); for (auto const & statement : statement_list.statements)
} derived().apply(*statement);
}
};
void operator()(statement_list & statement_list) template <typename Derived>
{ struct statement_visitor
for (auto & statement : statement_list.statements)
operator()(*statement);
}
};
}
template <typename Visitor>
void apply(Visitor && visitor, statement const & statement)
{ {
detail::const_statement_visitor_helper<Visitor> helper{visitor}; Derived & derived()
helper(statement); {
} return static_cast<Derived &>(*this);
}
template <typename Visitor> auto make_visitor()
void apply(Visitor && visitor, statement & statement) {
{ return [this](auto & type){ return derived().apply(type); };
detail::statement_visitor_helper<Visitor> helper{visitor}; }
helper(statement);
}
template <typename Visitor> auto apply(statement & statement)
void apply(Visitor && visitor, statement_list const & statement_list) {
{ return std::visit(make_visitor(), statement);
detail::const_statement_visitor_helper<Visitor> helper{visitor}; }
helper(statement_list);
}
template <typename Visitor> void apply(statement_list & statement_list)
void apply(Visitor && visitor, statement_list & statement_list) {
{ for (auto & statement : statement_list.statements)
detail::statement_visitor_helper<Visitor> helper{visitor}; derived().apply(*statement);
helper(statement_list); }
} };
} }

View file

@ -5,67 +5,52 @@
namespace pslang::ast namespace pslang::ast
{ {
namespace detail template <typename Derived>
struct const_type_visitor
{ {
Derived & derived()
template <typename Visitor>
struct const_type_visitor_helper
{ {
Visitor & visitor; return static_cast<Derived &>(*this);
}
template <typename Type> auto make_visitor()
void operator()(Type const & type)
{
visitor(type);
}
void operator()(types::primitive_type const & type)
{
std::visit(*this, type);
}
void operator()(ast::type const & type)
{
std::visit(*this, type);
}
};
template <typename Visitor>
struct type_visitor_helper
{ {
Visitor & visitor; return [this](auto const & type){ return derived().apply(type); };
}
template <typename Type> auto apply(types::primitive_type const & type)
void operator()(Type & type) {
{ return std::visit(make_visitor(), type);
visitor(type); }
}
void operator()(types::primitive_type & type) auto apply(ast::type const & type)
{ {
std::visit(*this, type); return std::visit(make_visitor(), type);
} }
};
void operator()(ast::type & type) template <typename Derived>
{ struct type_visitor
std::visit(*this, type);
}
};
}
template <typename Visitor>
void apply(Visitor && visitor, type const & type)
{ {
detail::const_type_visitor_helper<Visitor> helper{visitor}; Derived & derived()
helper(type); {
} return static_cast<Derived &>(*this);
}
template <typename Visitor> auto make_visitor()
void apply(Visitor && visitor, type & type) {
{ return [this](auto & type){ return derived().apply(type); };
detail::type_visitor_helper<Visitor> helper{visitor}; }
helper(type);
} auto apply(types::primitive_type & type)
{
return std::visit(make_visitor(), type);
}
auto apply(ast::type & type)
{
return std::visit(make_visitor(), type);
}
};
} }

View file

@ -27,33 +27,36 @@ namespace pslang::ast
} }
struct type_print_visitor struct type_print_visitor
: const_type_visitor<type_print_visitor>
{ {
std::ostream & out; std::ostream & out;
void operator()(types::unit_type const &) using const_type_visitor::apply;
void apply(types::unit_type const &)
{ {
out << "unit"; out << "unit";
} }
template <typename T> template <typename T>
void operator()(types::primitive_type_base<T> const & type) void apply(types::primitive_type_base<T> const & type)
{ {
types::print(out, types::primitive_type{type}); types::print(out, types::primitive_type{type});
} }
void operator()(array_type const & type) void apply(array_type const & type)
{ {
apply(*this, *type.element_type); apply(*type.element_type);
out << "[" << type.size << "]"; out << "[" << type.size << "]";
} }
void operator()(function_type const & type) void apply(function_type const & type)
{ {
if (type.arguments.size() == 1) if (type.arguments.size() == 1)
{ {
apply(*this, *type.arguments.front()); apply(*type.arguments.front());
out << " -> "; out << " -> ";
apply(*this, *type.result); apply(*type.result);
return; return;
} }
@ -63,111 +66,114 @@ namespace pslang::ast
{ {
if (!first) out << ", "; if (!first) out << ", ";
first = false; first = false;
apply(*this, *argument); apply(*argument);
} }
out << ") -> "; out << ") -> ";
apply(*this, *type.result); apply(*type.result);
} }
void operator()(type_identifier const & type) void apply(type_identifier const & type)
{ {
out << type.name; out << type.name;
} }
}; };
struct expression_print_visitor struct expression_print_visitor
: const_expression_visitor<expression_print_visitor>
{ {
std::ostream & out; std::ostream & out;
print_options options; print_options options;
using const_expression_visitor::apply;
template <typename Node> template <typename Node>
void child(Node const & node) void child(Node const & node)
{ {
++options.indent_level; ++options.indent_level;
apply(*this, node); apply(node);
--options.indent_level; --options.indent_level;
} }
void operator()(bool_literal const & node) void apply(bool_literal const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "bool literal { value = " << (node.value ? "true" : "false") << " }\n"; out << "bool literal { value = " << (node.value ? "true" : "false") << " }\n";
} }
void operator()(i8_literal const & node) void apply(i8_literal const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "i8 literal { value = " << (std::int32_t)node.value << " }\n"; out << "i8 literal { value = " << (std::int32_t)node.value << " }\n";
} }
void operator()(u8_literal const & node) void apply(u8_literal const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "u8 literal { value = " << (std::uint32_t)node.value << " }\n"; out << "u8 literal { value = " << (std::uint32_t)node.value << " }\n";
} }
void operator()(i16_literal const & node) void apply(i16_literal const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "i16 literal { value = " << node.value << " }\n"; out << "i16 literal { value = " << node.value << " }\n";
} }
void operator()(u16_literal const & node) void apply(u16_literal const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "u16 literal { value = " << node.value << " }\n"; out << "u16 literal { value = " << node.value << " }\n";
} }
void operator()(i32_literal const & node) void apply(i32_literal const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "i32 literal { value = " << node.value << " }\n"; out << "i32 literal { value = " << node.value << " }\n";
} }
void operator()(u32_literal const & node) void apply(u32_literal const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "u32 literal { value = " << node.value << " }\n"; out << "u32 literal { value = " << node.value << " }\n";
} }
void operator()(i64_literal const & node) void apply(i64_literal const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "i64 literal { value = " << node.value << " }\n"; out << "i64 literal { value = " << node.value << " }\n";
} }
void operator()(u64_literal const & node) void apply(u64_literal const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "u64 literal { value = " << node.value << " }\n"; out << "u64 literal { value = " << node.value << " }\n";
} }
void operator()(f32_literal const & node) void apply(f32_literal const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "f32 literal { value = " << std::setprecision(7) << node.value << " }\n"; out << "f32 literal { value = " << std::setprecision(7) << node.value << " }\n";
} }
void operator()(f64_literal const & node) void apply(f64_literal const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "f64 literal { value = " << std::setprecision(15) << node.value << " }\n"; out << "f64 literal { value = " << std::setprecision(15) << node.value << " }\n";
} }
void operator()(identifier const & node) void apply(identifier const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "identifier { name = \"" << node.name << "\" }\n"; out << "identifier { name = \"" << node.name << "\" }\n";
} }
void operator()(unary_operation const & node) void apply(unary_operation const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << node.type << '\n'; out << node.type << '\n';
child(*node.arg1); child(*node.arg1);
} }
void operator()(binary_operation const & node) void apply(binary_operation const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << node.type << '\n'; out << node.type << '\n';
@ -175,7 +181,7 @@ namespace pslang::ast
child(*node.arg2); child(*node.arg2);
} }
void operator()(cast_operation const & node) void apply(cast_operation const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "cast as "; out << "cast as ";
@ -184,7 +190,7 @@ namespace pslang::ast
child(*node.expression); child(*node.expression);
} }
void operator()(function_call const & node) void apply(function_call const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "call\n"; out << "call\n";
@ -193,7 +199,7 @@ namespace pslang::ast
child(*argument); child(*argument);
} }
void operator()(array const & node) void apply(array const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "array\n"; out << "array\n";
@ -201,7 +207,7 @@ namespace pslang::ast
child(*element); child(*element);
} }
void operator()(array_access const & node) void apply(array_access const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "array access\n"; out << "array access\n";
@ -209,7 +215,7 @@ namespace pslang::ast
child(*node.index); child(*node.index);
} }
void operator()(field_access const & node) void apply(field_access const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "field access { name = \"" << node.field_name << "\" }\n"; out << "field access { name = \"" << node.field_name << "\" }\n";
@ -218,24 +224,27 @@ namespace pslang::ast
}; };
struct statement_print_visitor struct statement_print_visitor
: const_statement_visitor<statement_print_visitor>
{ {
std::ostream & out; std::ostream & out;
print_options options; print_options options;
using const_statement_visitor::apply;
template <typename Node> template <typename Node>
void child(Node const & node) void child(Node const & node)
{ {
++options.indent_level; ++options.indent_level;
ast::apply(*this, node); apply(node);
--options.indent_level; --options.indent_level;
} }
void operator()(expression_ptr const & node) void apply(expression_ptr const & node)
{ {
print(out, *node, options); print(out, *node, options);
} }
void operator()(assignment const & node) void apply(assignment const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "assignment\n"; out << "assignment\n";
@ -243,7 +252,7 @@ namespace pslang::ast
child(node.rhs); child(node.rhs);
} }
void operator()(variable_declaration const & node) void apply(variable_declaration const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "variable declaration { category = " << node.category << ", name = \"" << node.name << "\""; out << "variable declaration { category = " << node.category << ", name = \"" << node.name << "\"";
@ -256,27 +265,27 @@ namespace pslang::ast
child(node.initializer); child(node.initializer);
} }
void operator()(if_block const & node) void apply(if_block const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "if\n"; out << "if\n";
child(node.condition); child(node.condition);
} }
void operator()(else_block const & node) void apply(else_block const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "else\n"; out << "else\n";
} }
void operator()(else_if_block const & node) void apply(else_if_block const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "else if\n"; out << "else if\n";
child(node.condition); child(node.condition);
} }
void operator()(if_chain const & node) void apply(if_chain const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "if chain\n"; out << "if chain\n";
@ -300,7 +309,7 @@ namespace pslang::ast
--options.indent_level; --options.indent_level;
} }
void operator()(while_block const & node) void apply(while_block const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "while\n"; out << "while\n";
@ -308,7 +317,7 @@ namespace pslang::ast
child(*node.statements); child(*node.statements);
} }
void operator()(function_definition const & node) void apply(function_definition const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "function { name = \"" << node.name << "\", return type = "; out << "function { name = \"" << node.name << "\", return type = ";
@ -327,7 +336,7 @@ namespace pslang::ast
child(*node.statements); child(*node.statements);
} }
void operator()(return_statement const & node) void apply(return_statement const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "return\n"; out << "return\n";
@ -335,7 +344,7 @@ namespace pslang::ast
child(node.value); child(node.value);
} }
void operator()(field_definition const & node) void apply(field_definition const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "field { name = \"" << node.name << "\", type = "; out << "field { name = \"" << node.name << "\", type = ";
@ -343,7 +352,7 @@ namespace pslang::ast
out << " }\n"; out << " }\n";
} }
void operator()(struct_definition const & node) void apply(struct_definition const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "struct { name = \"" << node.name << "\" }\n"; out << "struct { name = \"" << node.name << "\" }\n";
@ -356,22 +365,22 @@ namespace pslang::ast
void print(std::ostream & out, type const & node) void print(std::ostream & out, type const & node)
{ {
apply(type_print_visitor{out}, node); type_print_visitor{{}, out}.apply(node);
} }
void print(std::ostream & out, expression const & node, print_options const & options) void print(std::ostream & out, expression const & node, print_options const & options)
{ {
apply(expression_print_visitor{out, options}, node); expression_print_visitor{{}, out, options}.apply(node);
} }
void print(std::ostream & out, statement const & node, print_options const & options) void print(std::ostream & out, statement const & node, print_options const & options)
{ {
apply(statement_print_visitor{out, options}, node); statement_print_visitor{{}, out, options}.apply(node);
} }
void print(std::ostream & out, statement_list const & node, print_options const & options) void print(std::ostream & out, statement_list const & node, print_options const & options)
{ {
apply(statement_print_visitor{out, options}, node); statement_print_visitor{{}, out, options}.apply(node);
} }
} }

View file

@ -56,28 +56,35 @@ namespace pslang::ast
}; };
struct resolve_identifiers_visitor struct resolve_identifiers_visitor
: type_visitor<resolve_identifiers_visitor>
, expression_visitor<resolve_identifiers_visitor>
, statement_visitor<resolve_identifiers_visitor>
{ {
std::vector<scope> scopes; std::vector<scope> scopes;
void operator()(types::unit_type const &) using type_visitor::apply;
using expression_visitor::apply;
using statement_visitor::apply;
void apply(types::unit_type const &)
{} {}
void operator()(types::primitive_type const &) void apply(types::primitive_type const &)
{} {}
void operator()(array_type const & array_type) void apply(array_type const & array_type)
{ {
apply(*this, *array_type.element_type); apply(*array_type.element_type);
} }
void operator()(function_type const & function_type) void apply(function_type const & function_type)
{ {
for (auto const & argument : function_type.arguments) for (auto const & argument : function_type.arguments)
apply(*this, *argument); apply(*argument);
apply(*this, *function_type.result); apply(*function_type.result);
} }
void operator()(type_identifier & identifier) void apply(type_identifier & identifier)
{ {
for (auto it = scopes.rbegin(); it != scopes.rend(); ++it) for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
{ {
@ -91,10 +98,10 @@ namespace pslang::ast
throw parse_error("Identifier \"" + identifier.name + "\" not found", identifier.location); throw parse_error("Identifier \"" + identifier.name + "\" not found", identifier.location);
} }
void operator()(literal const &) void apply(literal const &)
{} {}
void operator()(identifier & identifier) void apply(identifier & identifier)
{ {
bool crossed_function_scope = false; bool crossed_function_scope = false;
for (auto it = scopes.rbegin(); it != scopes.rend(); ++it) for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
@ -111,106 +118,106 @@ namespace pslang::ast
throw parse_error("Identifier \"" + identifier.name + "\" not found", identifier.location); throw parse_error("Identifier \"" + identifier.name + "\" not found", identifier.location);
} }
void operator()(unary_operation const & unary_operation) void apply(unary_operation const & unary_operation)
{ {
apply(*this, *unary_operation.arg1); apply(*unary_operation.arg1);
} }
void operator()(binary_operation const & binary_operation) void apply(binary_operation const & binary_operation)
{ {
apply(*this, *binary_operation.arg1); apply(*binary_operation.arg1);
apply(*this, *binary_operation.arg2); apply(*binary_operation.arg2);
} }
void operator()(cast_operation const & cast_operation) void apply(cast_operation const & cast_operation)
{ {
apply(*this, *cast_operation.expression); apply(*cast_operation.expression);
apply(*this, *cast_operation.type); apply(*cast_operation.type);
} }
void operator()(function_call const & function_call) void apply(function_call const & function_call)
{ {
apply(*this, *function_call.function); apply(*function_call.function);
for (auto const & argument : function_call.arguments) for (auto const & argument : function_call.arguments)
apply(*this, *argument); apply(*argument);
} }
void operator()(array const & array) void apply(array const & array)
{ {
for (auto const & element : array.elements) for (auto const & element : array.elements)
apply(*this, *element); apply(*element);
} }
void operator()(array_access const & array_access) void apply(array_access const & array_access)
{ {
apply(*this, *array_access.array); apply(*array_access.array);
apply(*this, *array_access.index); apply(*array_access.index);
} }
void operator()(field_access const & field_access) void apply(field_access const & field_access)
{ {
apply(*this, *field_access.object); apply(*field_access.object);
} }
void operator()(expression_ptr const & expression_ptr) void apply(expression_ptr const & expression_ptr)
{ {
apply(*this, *expression_ptr); apply(*expression_ptr);
} }
void operator()(assignment const & assignment) void apply(assignment const & assignment)
{ {
ast::apply(*this, assignment.lhs); apply(assignment.lhs);
ast::apply(*this, assignment.rhs); apply(assignment.rhs);
} }
void operator()(variable_declaration const & variable_declaration) void apply(variable_declaration const & variable_declaration)
{ {
if (scopes.back().contains(variable_declaration.name)) if (scopes.back().contains(variable_declaration.name))
throw parse_error("Identifier \"" + variable_declaration.name + "\" is already defined at this scope", variable_declaration.location); throw parse_error("Identifier \"" + variable_declaration.name + "\" is already defined at this scope", variable_declaration.location);
if (variable_declaration.type) if (variable_declaration.type)
apply(*this, *variable_declaration.type); apply(*variable_declaration.type);
apply(*this, *variable_declaration.initializer); apply(*variable_declaration.initializer);
scopes.back().variables.insert(variable_declaration.name); scopes.back().variables.insert(variable_declaration.name);
} }
void operator()(if_block const & if_block) void apply(if_block const & if_block)
{ {
throw invalid_ast_error("if blocks cannot be present in the final AST", if_block.location); throw invalid_ast_error("if blocks cannot be present in the final AST", if_block.location);
} }
void operator()(else_if_block const & else_if_block) void apply(else_if_block const & else_if_block)
{ {
throw invalid_ast_error("else if blocks cannot be present in the final AST", else_if_block.location); throw invalid_ast_error("else if blocks cannot be present in the final AST", else_if_block.location);
} }
void operator()(else_block const & else_block) void apply(else_block const & else_block)
{ {
throw invalid_ast_error("else blocks cannot be present in the final AST", else_block.location); throw invalid_ast_error("else blocks cannot be present in the final AST", else_block.location);
} }
void operator()(if_chain const & if_chain) void apply(if_chain const & if_chain)
{ {
for (auto const & block : if_chain.blocks) for (auto const & block : if_chain.blocks)
{ {
if (block.condition) if (block.condition)
apply(*this, *block.condition); apply(*block.condition);
scopes.emplace_back(); scopes.emplace_back();
apply(*this, *block.statements); apply(*block.statements);
scopes.pop_back(); scopes.pop_back();
} }
} }
void operator()(while_block const & while_block) void apply(while_block const & while_block)
{ {
apply(*this, *while_block.condition); apply(*while_block.condition);
scopes.emplace_back(); scopes.emplace_back();
apply(*this, *while_block.statements); apply(*while_block.statements);
scopes.pop_back(); scopes.pop_back();
} }
void operator()(function_definition const & function_definition) void apply(function_definition const & function_definition)
{ {
if (scopes.back().contains(function_definition.name)) if (scopes.back().contains(function_definition.name))
throw parse_error("Identifier \"" + function_definition.name + "\" is already defined at this scope", function_definition.location); throw parse_error("Identifier \"" + function_definition.name + "\" is already defined at this scope", function_definition.location);
@ -221,38 +228,38 @@ namespace pslang::ast
if (argument_names.count(argument.name) > 0) if (argument_names.count(argument.name) > 0)
throw parse_error("Duplicate argument name \"" + argument.name + "\" in function \"" + function_definition.name + "\"", argument.location); throw parse_error("Duplicate argument name \"" + argument.name + "\" in function \"" + function_definition.name + "\"", argument.location);
argument_names.insert(argument.name); argument_names.insert(argument.name);
apply(*this, *argument.type); apply(*argument.type);
} }
apply(*this, *function_definition.return_type); apply(*function_definition.return_type);
scopes.back().functions.insert(function_definition.name); scopes.back().functions.insert(function_definition.name);
auto & scope = scopes.emplace_back(); auto & scope = scopes.emplace_back();
scope.is_function_scope = true; scope.is_function_scope = true;
scope.variables = std::move(argument_names); scope.variables = std::move(argument_names);
apply(*this, *function_definition.statements); apply(*function_definition.statements);
scopes.pop_back(); scopes.pop_back();
} }
void operator()(return_statement const & return_statement) void apply(return_statement const & return_statement)
{ {
if (return_statement.value) if (return_statement.value)
apply(*this, *return_statement.value); apply(*return_statement.value);
} }
void operator()(field_definition const & field_definition) void apply(field_definition const & field_definition)
{ {
apply(*this, *field_definition.type); apply(*field_definition.type);
} }
void operator()(struct_definition const & struct_definition) void apply(struct_definition const & struct_definition)
{ {
if (scopes.back().contains(struct_definition.name)) if (scopes.back().contains(struct_definition.name))
throw parse_error("Identifier \"" + struct_definition.name + "\" is already defined at this scope", struct_definition.location); throw parse_error("Identifier \"" + struct_definition.name + "\" is already defined at this scope", struct_definition.location);
for (auto const & field : struct_definition.fields) for (auto const & field : struct_definition.fields)
apply(*this, field); apply(field);
scopes.back().structs.insert(struct_definition.name); scopes.back().structs.insert(struct_definition.name);
} }
@ -265,7 +272,7 @@ namespace pslang::ast
{ {
resolve_identifiers_visitor visitor; resolve_identifiers_visitor visitor;
visitor.scopes.emplace_back().is_global_scope = true; visitor.scopes.emplace_back().is_global_scope = true;
apply(visitor, *statements); visitor.apply(*statements);
} }
} }

View file

@ -13,14 +13,19 @@ namespace pslang::types
return static_cast<Derived &>(*this); return static_cast<Derived &>(*this);
} }
auto make_visitor()
{
return [this](auto const & type){ return derived().apply(type); };
}
auto apply(types::primitive_type const & type) auto apply(types::primitive_type const & type)
{ {
return std::visit([this](auto const & type){ return derived().apply(type); }, type); return std::visit(make_visitor(), type);
} }
auto apply(types::type const & type) auto apply(types::type const & type)
{ {
return std::visit([this](auto const & type){ return derived().apply(type); }, type); return std::visit(make_visitor(), type);
} }
}; };
@ -32,14 +37,19 @@ namespace pslang::types
return static_cast<Derived &>(*this); return static_cast<Derived &>(*this);
} }
auto make_visitor()
{
return [this](auto & type){ return derived().apply(type); };
}
auto apply(types::primitive_type & type) auto apply(types::primitive_type & type)
{ {
return std::visit([this](auto & type){ return derived().apply(type); }, type); return std::visit(make_visitor(), type);
} }
auto apply(types::type & type) auto apply(types::type & type)
{ {
return std::visit([this](auto & type){ return derived().apply(type); }, type); return std::visit(make_visitor(), type);
} }
}; };