diff --git a/libs/ast/include/pslang/ast/expression_visitor.hpp b/libs/ast/include/pslang/ast/expression_visitor.hpp index 00c2016..1839cab 100644 --- a/libs/ast/include/pslang/ast/expression_visitor.hpp +++ b/libs/ast/include/pslang/ast/expression_visitor.hpp @@ -5,67 +5,52 @@ namespace pslang::ast { - namespace detail + template + struct const_expression_visitor { - - template - struct const_expression_visitor_helper + Derived & derived() { - Visitor & visitor; + return static_cast(*this); + } - template - void operator()(Expression const & expression) - { - visitor(expression); - } - - void operator()(literal const & expression) - { - std::visit(*this, expression); - } - - void operator()(expression const & expression) - { - std::visit(*this, expression); - } - }; - - template - struct expression_visitor_helper + auto make_visitor() { - Visitor & visitor; + return [this](auto const & type){ return derived().apply(type); }; + } - template - void operator()(Expression & expression) - { - visitor(expression); - } + auto apply(literal const & expression) + { + return std::visit(make_visitor(), expression); + } - void operator()(literal & expression) - { - std::visit(*this, expression); - } + auto apply(expression const & expression) + { + return std::visit(make_visitor(), expression); + } + }; - void operator()(expression & expression) - { - std::visit(*this, expression); - } - }; - - } - - template - void apply(Visitor && visitor, expression const & expression) + template + struct expression_visitor { - detail::const_expression_visitor_helper helper{visitor}; - helper(expression); - } + Derived & derived() + { + return static_cast(*this); + } - template - void apply(Visitor && visitor, expression & expression) - { - detail::expression_visitor_helper helper{visitor}; - helper(expression); - } + auto make_visitor() + { + return [this](auto & type){ return derived().apply(type); }; + } + + auto apply(literal & expression) + { + return std::visit(make_visitor(), expression); + } + + auto apply(expression & expression) + { + return std::visit(make_visitor(), expression); + } + }; } diff --git a/libs/ast/include/pslang/ast/statement_visitor.hpp b/libs/ast/include/pslang/ast/statement_visitor.hpp index 888a001..19b8fd2 100644 --- a/libs/ast/include/pslang/ast/statement_visitor.hpp +++ b/libs/ast/include/pslang/ast/statement_visitor.hpp @@ -5,83 +5,54 @@ namespace pslang::ast { - namespace detail + template + struct const_statement_visitor { - - template - struct const_statement_visitor_helper + Derived & derived() { - Visitor & visitor; + return static_cast(*this); + } - template - void operator()(Statement const & statement) - { - visitor(statement); - } - - void operator()(statement const & statement) - { - std::visit(*this, statement); - } - - void operator()(statement_list const & statement_list) - { - for (auto const & statement : statement_list.statements) - operator()(*statement); - } - }; - - template - struct statement_visitor_helper + auto make_visitor() { - Visitor & visitor; + return [this](auto const & type){ return derived().apply(type); }; + } - template - void operator()(Statement & statement) - { - visitor(statement); - } + auto apply(statement const & statement) + { + return std::visit(make_visitor(), statement); + } - void operator()(statement & statement) - { - std::visit(*this, statement); - } + void apply(statement_list const & statement_list) + { + for (auto const & statement : statement_list.statements) + derived().apply(*statement); + } + }; - void operator()(statement_list & statement_list) - { - for (auto & statement : statement_list.statements) - operator()(*statement); - } - }; - - } - - template - void apply(Visitor && visitor, statement const & statement) + template + struct statement_visitor { - detail::const_statement_visitor_helper helper{visitor}; - helper(statement); - } + Derived & derived() + { + return static_cast(*this); + } - template - void apply(Visitor && visitor, statement & statement) - { - detail::statement_visitor_helper helper{visitor}; - helper(statement); - } + auto make_visitor() + { + return [this](auto & type){ return derived().apply(type); }; + } - template - void apply(Visitor && visitor, statement_list const & statement_list) - { - detail::const_statement_visitor_helper helper{visitor}; - helper(statement_list); - } + auto apply(statement & statement) + { + return std::visit(make_visitor(), statement); + } - template - void apply(Visitor && visitor, statement_list & statement_list) - { - detail::statement_visitor_helper helper{visitor}; - helper(statement_list); - } + void apply(statement_list & statement_list) + { + for (auto & statement : statement_list.statements) + derived().apply(*statement); + } + }; } diff --git a/libs/ast/include/pslang/ast/type_visitor.hpp b/libs/ast/include/pslang/ast/type_visitor.hpp index 3dd8d82..879a3ec 100644 --- a/libs/ast/include/pslang/ast/type_visitor.hpp +++ b/libs/ast/include/pslang/ast/type_visitor.hpp @@ -5,67 +5,52 @@ namespace pslang::ast { - namespace detail + template + struct const_type_visitor { - - template - struct const_type_visitor_helper + Derived & derived() { - Visitor & visitor; + return static_cast(*this); + } - template - void operator()(Type const & type) - { - visitor(type); - } - - void operator()(types::primitive_type const & type) - { - std::visit(*this, type); - } - - void operator()(ast::type const & type) - { - std::visit(*this, type); - } - }; - - template - struct type_visitor_helper + auto make_visitor() { - Visitor & visitor; + return [this](auto const & type){ return derived().apply(type); }; + } - template - void operator()(Type & type) - { - visitor(type); - } + auto apply(types::primitive_type const & type) + { + return std::visit(make_visitor(), type); + } - void operator()(types::primitive_type & type) - { - std::visit(*this, type); - } + auto apply(ast::type const & type) + { + return std::visit(make_visitor(), type); + } + }; - void operator()(ast::type & type) - { - std::visit(*this, type); - } - }; - - } - - template - void apply(Visitor && visitor, type const & type) + template + struct type_visitor { - detail::const_type_visitor_helper helper{visitor}; - helper(type); - } + Derived & derived() + { + return static_cast(*this); + } - template - void apply(Visitor && visitor, type & type) - { - detail::type_visitor_helper helper{visitor}; - helper(type); - } + auto make_visitor() + { + return [this](auto & type){ return derived().apply(type); }; + } + + auto apply(types::primitive_type & type) + { + return std::visit(make_visitor(), type); + } + + auto apply(ast::type & type) + { + return std::visit(make_visitor(), type); + } + }; } diff --git a/libs/ast/source/print.cpp b/libs/ast/source/print.cpp index db34ffa..486c004 100644 --- a/libs/ast/source/print.cpp +++ b/libs/ast/source/print.cpp @@ -27,33 +27,36 @@ namespace pslang::ast } struct type_print_visitor + : const_type_visitor { std::ostream & out; - void operator()(types::unit_type const &) + using const_type_visitor::apply; + + void apply(types::unit_type const &) { out << "unit"; } template - void operator()(types::primitive_type_base const & type) + void apply(types::primitive_type_base const & type) { types::print(out, types::primitive_type{type}); } - void operator()(array_type const & type) + void apply(array_type const & type) { - apply(*this, *type.element_type); + apply(*type.element_type); out << "[" << type.size << "]"; } - void operator()(function_type const & type) + void apply(function_type const & type) { if (type.arguments.size() == 1) { - apply(*this, *type.arguments.front()); + apply(*type.arguments.front()); out << " -> "; - apply(*this, *type.result); + apply(*type.result); return; } @@ -63,111 +66,114 @@ namespace pslang::ast { if (!first) out << ", "; first = false; - apply(*this, *argument); + apply(*argument); } out << ") -> "; - apply(*this, *type.result); + apply(*type.result); } - void operator()(type_identifier const & type) + void apply(type_identifier const & type) { out << type.name; } }; struct expression_print_visitor + : const_expression_visitor { std::ostream & out; print_options options; + using const_expression_visitor::apply; + template void child(Node const & node) { ++options.indent_level; - apply(*this, node); + apply(node); --options.indent_level; } - void operator()(bool_literal const & node) + void apply(bool_literal const & node) { put_indent(out, options); out << "bool literal { value = " << (node.value ? "true" : "false") << " }\n"; } - void operator()(i8_literal const & node) + void apply(i8_literal const & node) { put_indent(out, options); out << "i8 literal { value = " << (std::int32_t)node.value << " }\n"; } - void operator()(u8_literal const & node) + void apply(u8_literal const & node) { put_indent(out, options); out << "u8 literal { value = " << (std::uint32_t)node.value << " }\n"; } - void operator()(i16_literal const & node) + void apply(i16_literal const & node) { put_indent(out, options); out << "i16 literal { value = " << node.value << " }\n"; } - void operator()(u16_literal const & node) + void apply(u16_literal const & node) { put_indent(out, options); out << "u16 literal { value = " << node.value << " }\n"; } - void operator()(i32_literal const & node) + void apply(i32_literal const & node) { put_indent(out, options); out << "i32 literal { value = " << node.value << " }\n"; } - void operator()(u32_literal const & node) + void apply(u32_literal const & node) { put_indent(out, options); out << "u32 literal { value = " << node.value << " }\n"; } - void operator()(i64_literal const & node) + void apply(i64_literal const & node) { put_indent(out, options); out << "i64 literal { value = " << node.value << " }\n"; } - void operator()(u64_literal const & node) + void apply(u64_literal const & node) { put_indent(out, options); out << "u64 literal { value = " << node.value << " }\n"; } - void operator()(f32_literal const & node) + void apply(f32_literal const & node) { put_indent(out, options); out << "f32 literal { value = " << std::setprecision(7) << node.value << " }\n"; } - void operator()(f64_literal const & node) + void apply(f64_literal const & node) { put_indent(out, options); out << "f64 literal { value = " << std::setprecision(15) << node.value << " }\n"; } - void operator()(identifier const & node) + void apply(identifier const & node) { put_indent(out, options); out << "identifier { name = \"" << node.name << "\" }\n"; } - void operator()(unary_operation const & node) + void apply(unary_operation const & node) { put_indent(out, options); out << node.type << '\n'; child(*node.arg1); } - void operator()(binary_operation const & node) + void apply(binary_operation const & node) { put_indent(out, options); out << node.type << '\n'; @@ -175,7 +181,7 @@ namespace pslang::ast child(*node.arg2); } - void operator()(cast_operation const & node) + void apply(cast_operation const & node) { put_indent(out, options); out << "cast as "; @@ -184,7 +190,7 @@ namespace pslang::ast child(*node.expression); } - void operator()(function_call const & node) + void apply(function_call const & node) { put_indent(out, options); out << "call\n"; @@ -193,7 +199,7 @@ namespace pslang::ast child(*argument); } - void operator()(array const & node) + void apply(array const & node) { put_indent(out, options); out << "array\n"; @@ -201,7 +207,7 @@ namespace pslang::ast child(*element); } - void operator()(array_access const & node) + void apply(array_access const & node) { put_indent(out, options); out << "array access\n"; @@ -209,7 +215,7 @@ namespace pslang::ast child(*node.index); } - void operator()(field_access const & node) + void apply(field_access const & node) { put_indent(out, options); out << "field access { name = \"" << node.field_name << "\" }\n"; @@ -218,24 +224,27 @@ namespace pslang::ast }; struct statement_print_visitor + : const_statement_visitor { std::ostream & out; print_options options; + using const_statement_visitor::apply; + template void child(Node const & node) { ++options.indent_level; - ast::apply(*this, node); + apply(node); --options.indent_level; } - void operator()(expression_ptr const & node) + void apply(expression_ptr const & node) { print(out, *node, options); } - void operator()(assignment const & node) + void apply(assignment const & node) { put_indent(out, options); out << "assignment\n"; @@ -243,7 +252,7 @@ namespace pslang::ast child(node.rhs); } - void operator()(variable_declaration const & node) + void apply(variable_declaration const & node) { put_indent(out, options); out << "variable declaration { category = " << node.category << ", name = \"" << node.name << "\""; @@ -256,27 +265,27 @@ namespace pslang::ast child(node.initializer); } - void operator()(if_block const & node) + void apply(if_block const & node) { put_indent(out, options); out << "if\n"; child(node.condition); } - void operator()(else_block const & node) + void apply(else_block const & node) { put_indent(out, options); out << "else\n"; } - void operator()(else_if_block const & node) + void apply(else_if_block const & node) { put_indent(out, options); out << "else if\n"; child(node.condition); } - void operator()(if_chain const & node) + void apply(if_chain const & node) { put_indent(out, options); out << "if chain\n"; @@ -300,7 +309,7 @@ namespace pslang::ast --options.indent_level; } - void operator()(while_block const & node) + void apply(while_block const & node) { put_indent(out, options); out << "while\n"; @@ -308,7 +317,7 @@ namespace pslang::ast child(*node.statements); } - void operator()(function_definition const & node) + void apply(function_definition const & node) { put_indent(out, options); out << "function { name = \"" << node.name << "\", return type = "; @@ -327,7 +336,7 @@ namespace pslang::ast child(*node.statements); } - void operator()(return_statement const & node) + void apply(return_statement const & node) { put_indent(out, options); out << "return\n"; @@ -335,7 +344,7 @@ namespace pslang::ast child(node.value); } - void operator()(field_definition const & node) + void apply(field_definition const & node) { put_indent(out, options); out << "field { name = \"" << node.name << "\", type = "; @@ -343,7 +352,7 @@ namespace pslang::ast out << " }\n"; } - void operator()(struct_definition const & node) + void apply(struct_definition const & node) { put_indent(out, options); out << "struct { name = \"" << node.name << "\" }\n"; @@ -356,22 +365,22 @@ namespace pslang::ast void print(std::ostream & out, type const & node) { - apply(type_print_visitor{out}, node); + type_print_visitor{{}, out}.apply(node); } void print(std::ostream & out, expression const & node, print_options const & options) { - apply(expression_print_visitor{out, options}, node); + expression_print_visitor{{}, out, options}.apply(node); } void print(std::ostream & out, statement const & node, print_options const & options) { - apply(statement_print_visitor{out, options}, node); + statement_print_visitor{{}, out, options}.apply(node); } void print(std::ostream & out, statement_list const & node, print_options const & options) { - apply(statement_print_visitor{out, options}, node); + statement_print_visitor{{}, out, options}.apply(node); } } diff --git a/libs/ast/source/resolve_identifiers.cpp b/libs/ast/source/resolve_identifiers.cpp index 2e22496..9e180e6 100644 --- a/libs/ast/source/resolve_identifiers.cpp +++ b/libs/ast/source/resolve_identifiers.cpp @@ -56,28 +56,35 @@ namespace pslang::ast }; struct resolve_identifiers_visitor + : type_visitor + , expression_visitor + , statement_visitor { std::vector scopes; - void operator()(types::unit_type const &) + using type_visitor::apply; + using expression_visitor::apply; + using statement_visitor::apply; + + void apply(types::unit_type const &) {} - void operator()(types::primitive_type const &) + void apply(types::primitive_type const &) {} - void operator()(array_type const & array_type) + void apply(array_type const & array_type) { - apply(*this, *array_type.element_type); + apply(*array_type.element_type); } - void operator()(function_type const & function_type) + void apply(function_type const & function_type) { for (auto const & argument : function_type.arguments) - apply(*this, *argument); - apply(*this, *function_type.result); + apply(*argument); + apply(*function_type.result); } - void operator()(type_identifier & identifier) + void apply(type_identifier & identifier) { for (auto it = scopes.rbegin(); it != scopes.rend(); ++it) { @@ -91,10 +98,10 @@ namespace pslang::ast throw parse_error("Identifier \"" + identifier.name + "\" not found", identifier.location); } - void operator()(literal const &) + void apply(literal const &) {} - void operator()(identifier & identifier) + void apply(identifier & identifier) { bool crossed_function_scope = false; for (auto it = scopes.rbegin(); it != scopes.rend(); ++it) @@ -111,106 +118,106 @@ namespace pslang::ast throw parse_error("Identifier \"" + identifier.name + "\" not found", identifier.location); } - void operator()(unary_operation const & unary_operation) + void apply(unary_operation const & unary_operation) { - apply(*this, *unary_operation.arg1); + apply(*unary_operation.arg1); } - void operator()(binary_operation const & binary_operation) + void apply(binary_operation const & binary_operation) { - apply(*this, *binary_operation.arg1); - apply(*this, *binary_operation.arg2); + apply(*binary_operation.arg1); + apply(*binary_operation.arg2); } - void operator()(cast_operation const & cast_operation) + void apply(cast_operation const & cast_operation) { - apply(*this, *cast_operation.expression); - apply(*this, *cast_operation.type); + apply(*cast_operation.expression); + apply(*cast_operation.type); } - void operator()(function_call const & function_call) + void apply(function_call const & function_call) { - apply(*this, *function_call.function); + apply(*function_call.function); for (auto const & argument : function_call.arguments) - apply(*this, *argument); + apply(*argument); } - void operator()(array const & array) + void apply(array const & array) { for (auto const & element : array.elements) - apply(*this, *element); + apply(*element); } - void operator()(array_access const & array_access) + void apply(array_access const & array_access) { - apply(*this, *array_access.array); - apply(*this, *array_access.index); + apply(*array_access.array); + apply(*array_access.index); } - void operator()(field_access const & field_access) + void apply(field_access const & field_access) { - apply(*this, *field_access.object); + apply(*field_access.object); } - void operator()(expression_ptr const & expression_ptr) + void apply(expression_ptr const & expression_ptr) { - apply(*this, *expression_ptr); + apply(*expression_ptr); } - void operator()(assignment const & assignment) + void apply(assignment const & assignment) { - ast::apply(*this, assignment.lhs); - ast::apply(*this, assignment.rhs); + apply(assignment.lhs); + apply(assignment.rhs); } - void operator()(variable_declaration const & variable_declaration) + void apply(variable_declaration const & variable_declaration) { if (scopes.back().contains(variable_declaration.name)) throw parse_error("Identifier \"" + variable_declaration.name + "\" is already defined at this scope", variable_declaration.location); if (variable_declaration.type) - apply(*this, *variable_declaration.type); - apply(*this, *variable_declaration.initializer); + apply(*variable_declaration.type); + apply(*variable_declaration.initializer); scopes.back().variables.insert(variable_declaration.name); } - void operator()(if_block const & if_block) + void apply(if_block const & if_block) { throw invalid_ast_error("if blocks cannot be present in the final AST", if_block.location); } - void operator()(else_if_block const & else_if_block) + void apply(else_if_block const & else_if_block) { throw invalid_ast_error("else if blocks cannot be present in the final AST", else_if_block.location); } - void operator()(else_block const & else_block) + void apply(else_block const & else_block) { throw invalid_ast_error("else blocks cannot be present in the final AST", else_block.location); } - void operator()(if_chain const & if_chain) + void apply(if_chain const & if_chain) { for (auto const & block : if_chain.blocks) { if (block.condition) - apply(*this, *block.condition); + apply(*block.condition); scopes.emplace_back(); - apply(*this, *block.statements); + apply(*block.statements); scopes.pop_back(); } } - void operator()(while_block const & while_block) + void apply(while_block const & while_block) { - apply(*this, *while_block.condition); + apply(*while_block.condition); scopes.emplace_back(); - apply(*this, *while_block.statements); + apply(*while_block.statements); scopes.pop_back(); } - void operator()(function_definition const & function_definition) + void apply(function_definition const & function_definition) { if (scopes.back().contains(function_definition.name)) throw parse_error("Identifier \"" + function_definition.name + "\" is already defined at this scope", function_definition.location); @@ -221,38 +228,38 @@ namespace pslang::ast if (argument_names.count(argument.name) > 0) throw parse_error("Duplicate argument name \"" + argument.name + "\" in function \"" + function_definition.name + "\"", argument.location); argument_names.insert(argument.name); - apply(*this, *argument.type); + apply(*argument.type); } - apply(*this, *function_definition.return_type); + apply(*function_definition.return_type); scopes.back().functions.insert(function_definition.name); auto & scope = scopes.emplace_back(); scope.is_function_scope = true; scope.variables = std::move(argument_names); - apply(*this, *function_definition.statements); + apply(*function_definition.statements); scopes.pop_back(); } - void operator()(return_statement const & return_statement) + void apply(return_statement const & return_statement) { if (return_statement.value) - apply(*this, *return_statement.value); + apply(*return_statement.value); } - void operator()(field_definition const & field_definition) + void apply(field_definition const & field_definition) { - apply(*this, *field_definition.type); + apply(*field_definition.type); } - void operator()(struct_definition const & struct_definition) + void apply(struct_definition const & struct_definition) { if (scopes.back().contains(struct_definition.name)) throw parse_error("Identifier \"" + struct_definition.name + "\" is already defined at this scope", struct_definition.location); for (auto const & field : struct_definition.fields) - apply(*this, field); + apply(field); scopes.back().structs.insert(struct_definition.name); } @@ -265,7 +272,7 @@ namespace pslang::ast { resolve_identifiers_visitor visitor; visitor.scopes.emplace_back().is_global_scope = true; - apply(visitor, *statements); + visitor.apply(*statements); } } diff --git a/libs/types/include/pslang/types/type_visitor.hpp b/libs/types/include/pslang/types/type_visitor.hpp index 8dfaa1d..b759ecf 100644 --- a/libs/types/include/pslang/types/type_visitor.hpp +++ b/libs/types/include/pslang/types/type_visitor.hpp @@ -13,14 +13,19 @@ namespace pslang::types return static_cast(*this); } + auto make_visitor() + { + return [this](auto const & type){ return derived().apply(type); }; + } + auto apply(types::primitive_type const & type) { - return std::visit([this](auto const & type){ return derived().apply(type); }, type); + return std::visit(make_visitor(), type); } auto apply(types::type const & type) { - return std::visit([this](auto const & type){ return derived().apply(type); }, type); + return std::visit(make_visitor(), type); } }; @@ -32,14 +37,19 @@ namespace pslang::types return static_cast(*this); } + auto make_visitor() + { + return [this](auto & type){ return derived().apply(type); }; + } + auto apply(types::primitive_type & type) { - return std::visit([this](auto & type){ return derived().apply(type); }, type); + return std::visit(make_visitor(), type); } auto apply(types::type & type) { - return std::visit([this](auto & type){ return derived().apply(type); }, type); + return std::visit(make_visitor(), type); } };