diff --git a/apps/interpreter/source/main.cpp b/apps/interpreter/source/main.cpp index 845d776..e2bdbe3 100644 --- a/apps/interpreter/source/main.cpp +++ b/apps/interpreter/source/main.cpp @@ -111,6 +111,7 @@ int main(int argc, char ** argv) filenames.push_back(argv[arg]); auto ast = parser::parse(filenames.back()); ast::resolve_identifiers(ast); + ast::check_and_infer_types(ast); parsed.push_back(std::move(ast)); } catch (pslang::ast::parse_error const & error) diff --git a/libs/ast/include/pslang/ast/array.hpp b/libs/ast/include/pslang/ast/array.hpp index d307a0a..d57f91a 100644 --- a/libs/ast/include/pslang/ast/array.hpp +++ b/libs/ast/include/pslang/ast/array.hpp @@ -2,6 +2,7 @@ #include #include +#include #include @@ -12,6 +13,7 @@ namespace pslang::ast { std::vector elements; ast::location location; + types::type_ptr inferred_type = nullptr; }; struct array_access @@ -19,6 +21,7 @@ namespace pslang::ast expression_ptr array; expression_ptr index; ast::location location; + types::type_ptr inferred_type = nullptr; }; } diff --git a/libs/ast/include/pslang/ast/cast.hpp b/libs/ast/include/pslang/ast/cast.hpp index 14823cd..4089b26 100644 --- a/libs/ast/include/pslang/ast/cast.hpp +++ b/libs/ast/include/pslang/ast/cast.hpp @@ -3,6 +3,7 @@ #include #include #include +#include namespace pslang::ast { @@ -12,6 +13,7 @@ namespace pslang::ast expression_ptr expression; type_ptr type; ast::location location; + types::type_ptr inferred_type = nullptr; }; } diff --git a/libs/ast/include/pslang/ast/expression.hpp b/libs/ast/include/pslang/ast/expression.hpp index 8500c3f..07af6fc 100644 --- a/libs/ast/include/pslang/ast/expression.hpp +++ b/libs/ast/include/pslang/ast/expression.hpp @@ -9,6 +9,7 @@ #include #include #include +#include namespace pslang::ast { @@ -18,6 +19,7 @@ namespace pslang::ast unary_operation_type type; expression_ptr arg1; ast::location location; + types::type_ptr inferred_type = nullptr; }; struct binary_operation @@ -26,6 +28,7 @@ namespace pslang::ast expression_ptr arg1; expression_ptr arg2; ast::location location; + types::type_ptr inferred_type = nullptr; }; using expression_impl = std::variant< @@ -47,5 +50,6 @@ namespace pslang::ast }; location get_location(expression const & expression); + types::type_ptr get_type(expression const & expression); } diff --git a/libs/ast/include/pslang/ast/function.hpp b/libs/ast/include/pslang/ast/function.hpp index 592ffd1..5b81544 100644 --- a/libs/ast/include/pslang/ast/function.hpp +++ b/libs/ast/include/pslang/ast/function.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -32,6 +33,7 @@ namespace pslang::ast expression_ptr function; std::vector arguments; ast::location location; + types::type_ptr inferred_type = nullptr; }; } diff --git a/libs/ast/include/pslang/ast/identifier.hpp b/libs/ast/include/pslang/ast/identifier.hpp index d0b9961..210c6f3 100644 --- a/libs/ast/include/pslang/ast/identifier.hpp +++ b/libs/ast/include/pslang/ast/identifier.hpp @@ -1,6 +1,7 @@ #pragma once #include +#include #include @@ -12,6 +13,7 @@ namespace pslang::ast std::string name; ast::location location; std::size_t level = 0; + types::type_ptr inferred_type = nullptr; }; } diff --git a/libs/ast/include/pslang/ast/preprocess.hpp b/libs/ast/include/pslang/ast/preprocess.hpp index 2f29fb8..6aaeb24 100644 --- a/libs/ast/include/pslang/ast/preprocess.hpp +++ b/libs/ast/include/pslang/ast/preprocess.hpp @@ -6,5 +6,6 @@ namespace pslang::ast { void resolve_identifiers(statement_list_ptr & statements); + void check_and_infer_types(statement_list_ptr & statements); } diff --git a/libs/ast/include/pslang/ast/struct.hpp b/libs/ast/include/pslang/ast/struct.hpp index 40bfd89..e4c9558 100644 --- a/libs/ast/include/pslang/ast/struct.hpp +++ b/libs/ast/include/pslang/ast/struct.hpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -29,6 +30,7 @@ namespace pslang::ast expression_ptr object; std::string field_name; ast::location location; + types::type_ptr inferred_type = nullptr; }; } diff --git a/libs/ast/include/pslang/ast/type.hpp b/libs/ast/include/pslang/ast/type.hpp index 0eee52d..4a01968 100644 --- a/libs/ast/include/pslang/ast/type.hpp +++ b/libs/ast/include/pslang/ast/type.hpp @@ -17,12 +17,14 @@ namespace pslang::ast { type_ptr element_type; std::uint64_t size; + types::type_ptr inferred_type = nullptr; }; struct function_type { std::vector arguments; type_ptr result; + types::type_ptr inferred_type = nullptr; }; struct type_identifier @@ -30,6 +32,7 @@ namespace pslang::ast std::string name; ast::location location; std::size_t level = 0; + types::type_ptr inferred_type = nullptr; }; using type_impl = std::variant< diff --git a/libs/ast/include/pslang/ast/type_fwd.hpp b/libs/ast/include/pslang/ast/type_fwd.hpp index 7dfebe5..4ab080e 100644 --- a/libs/ast/include/pslang/ast/type_fwd.hpp +++ b/libs/ast/include/pslang/ast/type_fwd.hpp @@ -1,5 +1,7 @@ #pragma once +#include + #include namespace pslang::ast @@ -9,4 +11,6 @@ namespace pslang::ast using type_ptr = std::shared_ptr; + types::type_ptr get_type(type const & type); + } diff --git a/libs/ast/source/expression.cpp b/libs/ast/source/expression.cpp index b9bb5e5..990d868 100644 --- a/libs/ast/source/expression.cpp +++ b/libs/ast/source/expression.cpp @@ -1,5 +1,6 @@ #include #include +#include namespace pslang::ast { @@ -59,6 +60,58 @@ namespace pslang::ast } }; + struct get_type_visitor + : const_expression_visitor + { + using const_expression_visitor::apply; + + template + types::type_ptr apply(primitive_literal_base const & node) + { + return std::make_unique(types::primitive_type{types::primitive_type_base{}}); + } + + types::type_ptr apply(identifier const & node) + { + return node.inferred_type; + } + + types::type_ptr apply(unary_operation const & node) + { + return node.inferred_type; + } + + types::type_ptr apply(binary_operation const & node) + { + return node.inferred_type; + } + + types::type_ptr apply(cast_operation const & node) + { + return node.inferred_type; + } + + types::type_ptr apply(function_call const & node) + { + return node.inferred_type; + } + + types::type_ptr apply(array const & node) + { + return node.inferred_type; + } + + types::type_ptr apply(array_access const & node) + { + return node.inferred_type; + } + + types::type_ptr apply(field_access const & node) + { + return node.inferred_type; + } + }; + } location get_location(expression const & expression) @@ -66,4 +119,9 @@ namespace pslang::ast return get_location_visitor{}.apply(expression); } + types::type_ptr get_type(expression const & expression) + { + return get_type_visitor{}.apply(expression); + } + } diff --git a/libs/ast/source/type.cpp b/libs/ast/source/type.cpp new file mode 100644 index 0000000..2107fca --- /dev/null +++ b/libs/ast/source/type.cpp @@ -0,0 +1,49 @@ +#include +#include +#include + +namespace pslang::ast +{ + + namespace + { + + struct get_type_visitor + : const_type_visitor + { + using const_type_visitor::apply; + + types::type_ptr apply(types::unit_type const & type) + { + return std::make_unique(type); + } + + types::type_ptr apply(types::primitive_type const & type) + { + return std::make_unique(type); + } + + types::type_ptr apply(ast::array_type const & type) + { + return type.inferred_type; + } + + types::type_ptr apply(ast::function_type const & type) + { + return type.inferred_type; + } + + types::type_ptr apply(ast::type_identifier const & type) + { + return type.inferred_type; + } + }; + + } + + types::type_ptr get_type(type const & type) + { + return get_type_visitor{}.apply(type); + } + +} diff --git a/libs/ast/source/type_check.cpp b/libs/ast/source/type_check.cpp new file mode 100644 index 0000000..98e3167 --- /dev/null +++ b/libs/ast/source/type_check.cpp @@ -0,0 +1,677 @@ +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace pslang::ast +{ + + namespace + { + + struct variable_data + { + value_category category; + types::type_ptr type; + }; + + struct function_data + { + std::vector arguments; + types::type_ptr result_type; + }; + + struct struct_data + { + struct field_data + { + std::string name; + types::type_ptr type; + }; + + std::vector fields; + }; + + struct scope + { + std::unordered_map variables; + std::unordered_map functions; + std::unordered_map structs; + + bool is_function_scope = false; + bool is_global_scope = false; + + types::type_ptr expected_return_type = nullptr; + }; + + struct check_visitor + : type_visitor + , expression_visitor + , const_statement_visitor + { + std::vector scopes; + + using type_visitor::apply; + using expression_visitor::apply; + using const_statement_visitor::apply; + + void apply(types::unit_type const &) + {} + + void apply(types::primitive_type const &) + {} + + void apply(ast::array_type & node) + { + apply(*node.element_type); + + auto type = types::array_type{}; + type.element_type = get_type(*node.element_type); + type.size = node.size; + node.inferred_type = std::make_unique(std::move(type)); + } + + void apply(ast::function_type & node) + { + for (auto const & argument : node.arguments) + apply(*argument); + apply(*node.result); + + auto type = types::function_type{}; + for (auto const & argument : node.arguments) + type.arguments.push_back(get_type(*argument)); + type.result = get_type(*node.result); + + node.inferred_type = std::make_unique(std::move(type)); + } + + void apply(ast::type_identifier & node) + { + auto type = types::named_type{}; + type.name = node.name; + type.level = node.level; + node.inferred_type = std::make_unique(std::move(type)); + } + + void apply(literal &) + {} + + void apply(identifier & node) + { + if (auto type = types::builtin_type(node.name)) + { + node.inferred_type = type; + return; + } + + auto & scope = scopes.at(node.level); + if (auto it = scope.variables.find(node.name); it != scope.variables.end()) + { + node.inferred_type = it->second.type; + } + else if (auto it = scope.functions.find(node.name); it != scope.functions.end()) + { + auto type = types::function_type{}; + for (auto const & argument : it->second.arguments) + type.arguments.push_back(argument); + type.result = it->second.result_type; + node.inferred_type = std::make_unique(std::move(type)); + } + } + + void apply(unary_operation & node) + { + apply(*node.arg1); + auto arg1_type = get_type(*node.arg1); + + bool good = false; + + switch (node.type) + { + case unary_operation_type::negation: + if (types::is_integer_type(*arg1_type) || types::is_floating_point_type(*arg1_type)) + { + node.inferred_type = arg1_type; + return; + } + break; + case unary_operation_type::logical_not: + if (types::is_bool_type(*arg1_type) || types::is_integer_type(*arg1_type)) + { + node.inferred_type = arg1_type; + return; + } + break; + } + + std::ostringstream os; + os << "Cannot apply " << node.type << " to a value of type "; + types::print(os, *arg1_type); + throw type_error(os.str(), node.location); + } + + void apply(binary_operation & node) + { + apply(*node.arg1); + apply(*node.arg2); + + auto arg1_type = get_type(*node.arg1); + auto arg2_type = get_type(*node.arg2); + + bool equal = types::equal(*arg1_type, *arg2_type); + + switch (node.type) + { + case binary_operation_type::addition: + if (equal && types::is_numeric_type(*arg1_type)) + { + node.inferred_type = arg1_type; + return; + } + break; + case binary_operation_type::subtraction: + if (equal && types::is_numeric_type(*arg1_type)) + { + node.inferred_type = arg1_type; + return; + } + break; + case binary_operation_type::multiplication: + if (equal && types::is_numeric_type(*arg1_type)) + { + node.inferred_type = arg1_type; + return; + } + break; + case binary_operation_type::division: + if (equal && types::is_numeric_type(*arg1_type)) + { + node.inferred_type = arg1_type; + return; + } + break; + case binary_operation_type::remainder: + if (equal && types::is_integer_type(*arg1_type)) + { + node.inferred_type = arg1_type; + return; + } + break; + case binary_operation_type::logical_and: + if (equal && (types::is_bool_type(*arg1_type) || types::is_integer_type(*arg1_type))) + { + node.inferred_type = arg1_type; + return; + } + break; + case binary_operation_type::logical_or: + if (equal && (types::is_bool_type(*arg1_type) || types::is_integer_type(*arg1_type))) + { + node.inferred_type = arg1_type; + return; + } + break; + case binary_operation_type::logical_xor: + if (equal && (types::is_bool_type(*arg1_type) || types::is_integer_type(*arg1_type))) + { + node.inferred_type = arg1_type; + return; + } + break; + case binary_operation_type::equals: + if (equal) + { + node.inferred_type = std::make_unique(types::bool_type{}); + return; + } + break; + case binary_operation_type::not_equals: + if (equal) + return; + break; + case binary_operation_type::less: + if (equal) + { + node.inferred_type = std::make_unique(types::bool_type{}); + return; + } + break; + case binary_operation_type::greater: + if (equal) + { + node.inferred_type = std::make_unique(types::bool_type{}); + return; + } + break; + case binary_operation_type::less_equals: + if (equal) + { + node.inferred_type = std::make_unique(types::bool_type{}); + return; + } + break; + case binary_operation_type::greater_equals: + if (equal) + { + node.inferred_type = std::make_unique(types::bool_type{}); + return; + } + break; + } + + std::ostringstream os; + os << "Cannot apply " << node.type << " to values of types "; + types::print(os, *arg1_type); + os << " and "; + types::print(os, *arg2_type); + throw type_error(os.str(), node.location); + } + + void apply(cast_operation & node) + { + apply(*node.expression); + apply(*node.type); + + auto source_type = get_type(*node.expression); + auto target_type = get_type(*node.type); + + node.inferred_type = target_type; + + if (types::equal(*source_type, *target_type)) + return; + + if (std::get_if(target_type.get())) + if (!types::is_bool_type(*source_type) && !types::is_bool_type(*target_type)) + return; + + std::ostringstream os; + os << "Cannot cast a value of type "; + types::print(os, *source_type); + os << " to type "; + types::print(os, *target_type); + throw type_error(os.str(), node.location); + } + + void apply(function_call & node) + { + apply(*node.function); + for (auto const & argument : node.arguments) + apply(*argument); + + std::string function_name; + + if (auto identifier = std::get_if(node.function.get())) + { + if (types::type_ptr type = types::builtin_type(identifier->name)) + { + if (node.arguments.empty()) + { + node.inferred_type = type; + return; + } + + std::ostringstream os; + os << "Cannot create built-in type "; + types::print(os, *type); + os << ": expected 0 arguments, but got " << node.arguments.size(); + throw std::runtime_error(os.str()); + } + + auto & scope = scopes.at(identifier->level); + if (auto it = scope.structs.find(identifier->name); it != scope.structs.end()) + { + if (!node.arguments.empty()) + { + if (node.arguments.size() != it->second.fields.size()) + { + std::ostringstream os; + os << "Cannot create struct " << identifier->name << ": expected " << it->second.fields.size() << " arguments, but got " << node.arguments.size(); + throw type_error(os.str(), node.location); + } + + for (std::size_t i = 0; i < node.arguments.size(); ++i) + { + auto arg_type = get_type(*node.arguments[i]); + if (!types::equal(*arg_type, *it->second.fields[i].type)) + { + std::ostringstream os; + os << "Cannot create struct " << identifier->name << ": argument #" << i << " expected to have type "; + types::print(os, *it->second.fields[i].type); + os << " but got type "; + types::print(os, *arg_type); + throw type_error(os.str(), node.location); + } + } + } + + types::named_type type; + type.name = identifier->name; + type.level = identifier->level; + node.inferred_type = std::make_unique(std::move(type)); + return; + } + + function_name = identifier->name + " "; + } + + auto function_type = get_type(*node.function); + + auto ftype = std::get_if(function_type.get()); + if (!ftype) + { + std::ostringstream os; + os << "Cannot call a value of a non-function type "; + types::print(os, *function_type); + throw type_error(os.str(), node.location); + } + + if (ftype->arguments.size() != node.arguments.size()) + { + std::ostringstream os; + os << "Cannot call function " << function_name << " of type "; + types::print(os, *function_type); + os << ": expected " << ftype->arguments.size() << " arguments, but got " << node.arguments.size(); + throw type_error(os.str(), node.location); + } + + for (std::size_t i = 0; i < node.arguments.size(); ++i) + { + auto arg_type = get_type(*node.arguments[i]); + if (!types::equal(*arg_type, *ftype->arguments[i])) + { + std::ostringstream os; + os << "Cannot call function " << function_name << " of type "; + types::print(os, *function_type); + os << ": argument #" << i << " expected to have type "; + types::print(os, *ftype->arguments[i]); + os << " but got type "; + types::print(os, *arg_type); + throw type_error(os.str(), node.location); + } + } + + node.inferred_type = ftype->result; + } + + void apply(array & node) + { + if (node.elements.empty()) + throw invalid_ast_error("Empty array", node.location); + + for (auto const & element : node.elements) + apply(*element); + + types::type_ptr element_type = nullptr; + for (std::size_t i = 0; i < node.elements.size(); ++i) + { + auto current_type = get_type(*node.elements[0]); + + if (i == 0) + { + element_type = current_type; + } + else if (!types::equal(*element_type, *current_type)) + { + std::ostringstream os; + os << "Failed to infer array type: element #0 has type "; + types::print(os, *element_type); + os << " but element #" << i << " has type "; + types::print(os, *current_type); + } + } + + types::array_type type; + type.element_type = element_type; + type.size = node.elements.size(); + node.inferred_type = std::make_unique(std::move(type)); + } + + void apply(array_access & node) + { + apply(*node.array); + apply(*node.index); + + auto array_type = get_type(*node.array); + auto index_type = get_type(*node.index); + + if (!types::is_integer_type(*index_type)) + { + std::ostringstream os; + os << "Expected an integer type as index, but got "; + types::print(os, *index_type); + throw type_error(os.str(), get_location(*node.index)); + } + + auto atype = std::get_if(array_type.get()); + + if (!atype) + { + std::ostringstream os; + os << "Expected an array to index, but got "; + types::print(os, *array_type); + throw type_error(os.str(), get_location(*node.array)); + } + + node.inferred_type = atype->element_type; + } + + void apply(field_access & node) + { + apply(*node.object); + + auto object_type = get_type(*node.object); + auto named_type = std::get_if(object_type.get()); + + if (!named_type) + { + std::ostringstream os; + os << "Expected a struct, but got "; + types::print(os, *object_type); + throw type_error(os.str(), get_location(*node.object)); + } + + auto const & struct_data = scopes.at(named_type->level).structs.at(named_type->name); + + for (auto const & field : struct_data.fields) + { + if (field.name == node.field_name) + { + node.inferred_type = field.type; + return; + } + } + + std::ostringstream os; + os << "Struct \"" << named_type->name << "\" has no field named \"" << node.field_name << "\""; + throw type_error(os.str(), node.location); + } + + void apply(expression_ptr const & node) + { + apply(*node); + } + + void apply(assignment const & node) + { + apply(node.lhs); + apply(node.rhs); + auto ltype = get_type(*node.lhs); + auto rtype = get_type(*node.rhs); + // TODO: check lvalue + if (!types::equal(*ltype, *rtype)) + { + std::ostringstream os; + os << "Cannot assign a value of type "; + types::print(os, *rtype); + os << " to an expression of type "; + types::print(os, *ltype); + throw type_error(os.str(), node.location); + }; + } + + void apply(variable_declaration const & node) + { + apply(node.initializer); + auto actual_type = get_type(*node.initializer); + if (node.type) + { + apply(*node.type); + auto expected_type = get_type(*node.type); + if (!types::equal(*expected_type, *actual_type)) + { + std::ostringstream os; + os << "Cannot initialize a variable of type "; + types::print(os, *expected_type); + os << " with an expression of type "; + types::print(os, *actual_type); + throw type_error(os.str(), node.location); + } + } + + scopes.back().variables[node.name] = { + .category = node.category, + .type = actual_type, + }; + } + + void apply(if_block const & node) + { + throw invalid_ast_error("if blocks cannot be present in the final AST", node.location); + } + + void apply(else_if_block const & node) + { + throw invalid_ast_error("else if blocks cannot be present in the final AST", node.location); + } + + void apply(else_block const & node) + { + throw invalid_ast_error("else blocks cannot be present in the final AST", node.location); + } + + void apply(if_chain const & node) + { + for (auto const & block : node.blocks) + { + if (block.condition) + { + apply(block.condition); + auto actual_type = get_type(*block.condition); + if (!types::is_bool_type(*actual_type)) + { + std::ostringstream os; + os << "if condition expects a bool type, but got "; + types::print(os, *actual_type); + throw type_error(os.str(), get_location(*block.condition)); + } + } + + scopes.emplace_back(); + apply(*block.statements); + scopes.pop_back(); + } + } + + void apply(while_block const & node) + { + apply(node.condition); + auto actual_type = get_type(*node.condition); + if (!types::is_bool_type(*actual_type)) + { + std::ostringstream os; + os << "while condition expects a bool type, but got "; + types::print(os, *actual_type); + throw type_error(os.str(), get_location(*node.condition)); + } + + scopes.emplace_back(); + apply(*node.statements); + scopes.pop_back(); + } + + void apply(function_definition const & node) + { + for (auto const & argument : node.arguments) + apply(*argument.type); + apply(*node.return_type); + + auto & data = scopes.back().functions[node.name]; + for (auto const & argument : node.arguments) + data.arguments.push_back(get_type(*argument.type)); + data.result_type = get_type(*node.return_type); + + scopes.emplace_back().is_function_scope = true; + scopes.back().expected_return_type = get_type(*node.return_type); + + for (auto const & argument : node.arguments) + { + scopes.back().variables[argument.name] = { + .category = value_category::constant, + .type = get_type(*argument.type), + }; + } + + apply(*node.statements); + scopes.pop_back(); + } + + void apply(return_statement const & node) + { + types::type_ptr actual_type; + if (node.value) + { + apply(node.value); + actual_type = get_type(*node.value); + } + else + { + actual_type = std::make_unique(types::unit_type{}); + } + + auto & return_scope = scopes.at(node.level); + if (!return_scope.expected_return_type) + throw invalid_ast_error("Unexpected return level", node.location); + + if (!types::equal(*return_scope.expected_return_type, *actual_type)) + { + std::ostringstream os; + os << "Returning value of type "; + types::print(os, *actual_type); + os << " from a function returning "; + types::print(os, *return_scope.expected_return_type); + throw type_error(os.str(), node.location); + } + } + + void apply(field_definition const & node) + { + apply(*node.type); + } + + void apply(struct_definition const & node) + { + for (auto const & field : node.fields) + apply(field); + + auto & data = scopes.back().structs[node.name]; + for (auto const & field : node.fields) + data.fields.push_back({.name = field.name, .type = get_type(*field.type)}); + } + }; + + } + + void check_and_infer_types(statement_list_ptr & statements) + { + check_visitor visitor; + visitor.scopes.emplace_back().is_global_scope = true; + visitor.apply(*statements); + } + +} diff --git a/libs/types/include/pslang/types/type_fwd.hpp b/libs/types/include/pslang/types/type_fwd.hpp index b4dacdc..6f74226 100644 --- a/libs/types/include/pslang/types/type_fwd.hpp +++ b/libs/types/include/pslang/types/type_fwd.hpp @@ -14,4 +14,12 @@ namespace pslang::types type_ptr builtin_type(std::string const & name); + bool is_unit_type(type const & type); + bool is_bool_type(type const & type); + bool is_integer_type(type const & type); + bool is_signed_integer_type(type const & type); + bool is_unsigned_integer_type(type const & type); + bool is_floating_point_type(type const & type); + bool is_numeric_type(type const & type); + } diff --git a/libs/types/source/type.cpp b/libs/types/source/type.cpp index dbdb3f6..e02755b 100644 --- a/libs/types/source/type.cpp +++ b/libs/types/source/type.cpp @@ -38,4 +38,79 @@ namespace pslang::types return nullptr; } + bool is_unit_type(type const & type) + { + return equal(type, unit_type{}); + } + + bool is_bool_type(type const & type) + { + return equal(type, bool_type{}); + } + + bool is_integer_type(type const & type) + { + if (auto ptype = std::get_if(&type)) + { + return std::visit([](primitive_type_base const &) + { + return std::is_integral_v && !std::is_same_v; + }, *ptype); + } + + return false; + } + + bool is_signed_integer_type(type const & type) + { + if (auto ptype = std::get_if(&type)) + { + return std::visit([](primitive_type_base const &) + { + return std::is_integral_v && std::is_signed_v && !std::is_same_v; + }, *ptype); + } + + return false; + } + + bool is_unsigned_integer_type(type const & type) + { + if (auto ptype = std::get_if(&type)) + { + return std::visit([](primitive_type_base const &) + { + return std::is_integral_v && std::is_unsigned_v && !std::is_same_v; + }, *ptype); + } + + return false; + } + + bool is_floating_point_type(type const & type) + { + if (auto ptype = std::get_if(&type)) + { + return std::visit([](primitive_type_base const &) + { + return std::is_floating_point_v; + }, *ptype); + } + + return false; + } + + bool is_numeric_type(type const & type) + { + if (auto ptype = std::get_if(&type)) + { + return std::visit([](primitive_type_base const &) + { + return !std::is_same_v; + }, *ptype); + } + + return false; + } + }