From dea5c18cfd9585dccd7591237a25327ea107388c Mon Sep 17 00:00:00 2001 From: lisyarus Date: Tue, 23 Dec 2025 13:43:54 +0300 Subject: [PATCH] Explicitly mark constructor AST nodes --- libs/ast/include/pslang/ast/function.hpp | 5 +- libs/ast/source/print.cpp | 13 +- libs/ast/source/resolve_identifiers.cpp | 28 ++- libs/ast/source/type_check.cpp | 124 ++++++------ libs/interpreter/source/eval.cpp | 193 +++++++++---------- libs/parser/rules/pslang.y | 26 +-- libs/types/include/pslang/types/type_fwd.hpp | 1 + libs/types/source/type.cpp | 10 + 8 files changed, 221 insertions(+), 179 deletions(-) diff --git a/libs/ast/include/pslang/ast/function.hpp b/libs/ast/include/pslang/ast/function.hpp index 6d2955c..3482c7d 100644 --- a/libs/ast/include/pslang/ast/function.hpp +++ b/libs/ast/include/pslang/ast/function.hpp @@ -38,7 +38,10 @@ namespace pslang::ast struct function_call { - expression_ptr function; + // Exactly one of these 2 is non-null + expression_ptr function; // function call + type_ptr type; // type constructor + std::vector arguments; ast::location location; types::type_ptr inferred_type = nullptr; diff --git a/libs/ast/source/print.cpp b/libs/ast/source/print.cpp index 486c004..e2ce63c 100644 --- a/libs/ast/source/print.cpp +++ b/libs/ast/source/print.cpp @@ -193,8 +193,17 @@ namespace pslang::ast void apply(function_call const & node) { put_indent(out, options); - out << "call\n"; - child(*node.function); + if (node.function) + { + out << "call\n"; + child(*node.function); + } + if (node.type) + { + out << "constructor { type = "; + print(out, *node.type); + out << " }\n"; + } for (auto const & argument : node.arguments) child(*argument); } diff --git a/libs/ast/source/resolve_identifiers.cpp b/libs/ast/source/resolve_identifiers.cpp index 1a30abc..ea93e9b 100644 --- a/libs/ast/source/resolve_identifiers.cpp +++ b/libs/ast/source/resolve_identifiers.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -138,11 +139,34 @@ namespace pslang::ast apply(*cast_operation.type); } - void apply(function_call const & function_call) + void apply(function_call & function_call) { - apply(*function_call.function); + if (function_call.function) + apply(*function_call.function); + if (function_call.type) + apply(*function_call.type); for (auto const & argument : function_call.arguments) apply(*argument); + + if (auto id = std::get_if(function_call.function.get())) + { + if (auto type = types::builtin_type(id->name)) + { + if (auto unit_type = std::get_if(type.get())) + function_call.type = std::make_unique(*unit_type); + else if (auto primitive_type = std::get_if(type.get())) + function_call.type = std::make_unique(*primitive_type); + else + throw invalid_ast_error("Unknown built-in type \"" + id->name + "\"", get_location(*function_call.function)); + + function_call.function = nullptr; + } + else if (scopes.at(id->level).structs.contains(id->name)) + { + function_call.type = std::make_unique(type_identifier{.name = id->name, .location = id->location, .level = id->level}); + function_call.function = nullptr; + } + } } void apply(array const & array) diff --git a/libs/ast/source/type_check.cpp b/libs/ast/source/type_check.cpp index 98e3167..9c8b1ea 100644 --- a/libs/ast/source/type_check.cpp +++ b/libs/ast/source/type_check.cpp @@ -300,19 +300,67 @@ namespace pslang::ast void apply(function_call & node) { - apply(*node.function); + if (node.function) + apply(*node.function); + if (node.type) + apply(*node.type); for (auto const & argument : node.arguments) apply(*argument); - std::string function_name; - - if (auto identifier = std::get_if(node.function.get())) + if (node.function) { - if (types::type_ptr type = types::builtin_type(identifier->name)) + std::string function_name; + + if (auto identifier = std::get_if(node.function.get())) + function_name = identifier->name + " "; + + auto function_type = get_type(*node.function); + + auto ftype = std::get_if(function_type.get()); + if (!ftype) + { + std::ostringstream os; + os << "Cannot call a value of a non-function type "; + types::print(os, *function_type); + throw type_error(os.str(), node.location); + } + + if (ftype->arguments.size() != node.arguments.size()) + { + std::ostringstream os; + os << "Cannot call function " << function_name << " of type "; + types::print(os, *function_type); + os << ": expected " << ftype->arguments.size() << " arguments, but got " << node.arguments.size(); + throw type_error(os.str(), node.location); + } + + for (std::size_t i = 0; i < node.arguments.size(); ++i) + { + auto arg_type = get_type(*node.arguments[i]); + if (!types::equal(*arg_type, *ftype->arguments[i])) + { + std::ostringstream os; + os << "Cannot call function " << function_name << " of type "; + types::print(os, *function_type); + os << ": argument #" << i << " expected to have type "; + types::print(os, *ftype->arguments[i]); + os << " but got type "; + types::print(os, *arg_type); + throw type_error(os.str(), node.location); + } + } + + node.inferred_type = ftype->result; + } + else if (node.type) + { + auto type = get_type(*node.type); + + if (types::is_builtin_type(*type)) { if (node.arguments.empty()) { - node.inferred_type = type; + node.inferred_type = get_type(*node.type); return; } @@ -322,27 +370,28 @@ namespace pslang::ast os << ": expected 0 arguments, but got " << node.arguments.size(); throw std::runtime_error(os.str()); } - - auto & scope = scopes.at(identifier->level); - if (auto it = scope.structs.find(identifier->name); it != scope.structs.end()) + else if (auto named_type = std::get_if(type.get())) { + auto const & scope = scopes.at(named_type->level); + auto const & data = scope.structs.at(named_type->name); + if (!node.arguments.empty()) { - if (node.arguments.size() != it->second.fields.size()) + if (node.arguments.size() != data.fields.size()) { std::ostringstream os; - os << "Cannot create struct " << identifier->name << ": expected " << it->second.fields.size() << " arguments, but got " << node.arguments.size(); + os << "Cannot create struct " << named_type->name << ": expected " << data.fields.size() << " arguments, but got " << node.arguments.size(); throw type_error(os.str(), node.location); } for (std::size_t i = 0; i < node.arguments.size(); ++i) { auto arg_type = get_type(*node.arguments[i]); - if (!types::equal(*arg_type, *it->second.fields[i].type)) + if (!types::equal(*arg_type, *data.fields[i].type)) { std::ostringstream os; - os << "Cannot create struct " << identifier->name << ": argument #" << i << " expected to have type "; - types::print(os, *it->second.fields[i].type); + os << "Cannot create struct " << named_type->name << ": argument #" << i << " expected to have type "; + types::print(os, *data.fields[i].type); os << " but got type "; types::print(os, *arg_type); throw type_error(os.str(), node.location); @@ -350,53 +399,12 @@ namespace pslang::ast } } - types::named_type type; - type.name = identifier->name; - type.level = identifier->level; - node.inferred_type = std::make_unique(std::move(type)); + node.inferred_type = type; return; } - - function_name = identifier->name + " "; } - - auto function_type = get_type(*node.function); - - auto ftype = std::get_if(function_type.get()); - if (!ftype) - { - std::ostringstream os; - os << "Cannot call a value of a non-function type "; - types::print(os, *function_type); - throw type_error(os.str(), node.location); - } - - if (ftype->arguments.size() != node.arguments.size()) - { - std::ostringstream os; - os << "Cannot call function " << function_name << " of type "; - types::print(os, *function_type); - os << ": expected " << ftype->arguments.size() << " arguments, but got " << node.arguments.size(); - throw type_error(os.str(), node.location); - } - - for (std::size_t i = 0; i < node.arguments.size(); ++i) - { - auto arg_type = get_type(*node.arguments[i]); - if (!types::equal(*arg_type, *ftype->arguments[i])) - { - std::ostringstream os; - os << "Cannot call function " << function_name << " of type "; - types::print(os, *function_type); - os << ": argument #" << i << " expected to have type "; - types::print(os, *ftype->arguments[i]); - os << " but got type "; - types::print(os, *arg_type); - throw type_error(os.str(), node.location); - } - } - - node.inferred_type = ftype->result; + else + throw invalid_ast_error("Function call node has neither function nor type", node.location); } void apply(array & node) diff --git a/libs/interpreter/source/eval.cpp b/libs/interpreter/source/eval.cpp index 55f2997..2de4731 100644 --- a/libs/interpreter/source/eval.cpp +++ b/libs/interpreter/source/eval.cpp @@ -507,10 +507,73 @@ namespace pslang::interpreter value eval_impl(context & context, ast::function_call const & function_call) { - auto identifier = std::get_if(function_call.function.get()); - if (identifier) + if (function_call.function) { - if (auto type = types::builtin_type(identifier->name)) + auto lvalue = eval(context, function_call.function); + auto fvalue = std::get_if(&lvalue); + if (fvalue) + { + if (fvalue->arguments.size() != function_call.arguments.size()) + { + std::ostringstream os; + os << "Cannot call function: expected " << fvalue->arguments.size() << " arguments, got " << function_call.arguments.size(); + throw std::runtime_error(os.str()); + } + + std::vector args; + for (auto const & expression : function_call.arguments) + args.push_back(eval(context, expression)); + + for (std::size_t i = 0; i < args.size(); ++i) + { + auto actual_type = type_of(args[i]); + if (!types::equal(actual_type, *fvalue->arguments[i].type)) + { + std::ostringstream os; + os << "Cannot call function: argument #" << (i + 1) << " expects type "; + types::print(os, *fvalue->arguments[i].type); + os << " but actual type is "; + types::print(os, actual_type); + throw std::runtime_error(os.str()); + } + } + + auto & function_scope = context.scope_stack.emplace_back(); + function_scope.is_function_scope = true; + + for (std::size_t i = 0; i < args.size(); ++i) + function_scope.variables[fvalue->arguments[i].name] = {.category = ast::value_category::constant, .value = std::move(args[i])}; + + auto expected_return_type = fvalue->return_type; + + exec(context, fvalue->statements); + + auto actual_return_type = type_of(context.scope_stack.back().return_value); + if (!types::equal(actual_return_type, *expected_return_type)) + { + std::ostringstream os; + os << "Error returning from function: expected return type is "; + types::print(os, *expected_return_type); + os << " but actual type is "; + types::print(os, actual_return_type); + throw std::runtime_error(os.str()); + } + + auto result = std::move(context.scope_stack.back().return_value); + context.scope_stack.pop_back(); + return result; + } + + std::ostringstream os; + os << "Cannot call "; + print(os, lvalue); + os << ": not a function"; + throw std::runtime_error(os.str()); + } + else if (function_call.type) + { + auto type = get_type(*function_call.type); + if (types::is_builtin_type(*type)) { if (function_call.arguments.empty()) return zero_value(context, *type); @@ -521,111 +584,35 @@ namespace pslang::interpreter os << ": expected 0 arguments, but got " << function_call.arguments.size(); throw std::runtime_error(os.str()); } - - for (auto it = context.scope_stack.rbegin(); it != context.scope_stack.rend(); ++it) + else if (auto named_type = std::get_if(type.get())) { - if (auto jt = it->structs.find(identifier->name); jt != it->structs.end()) + auto const & scope = context.scope_stack.at(named_type->level); + auto const & data = scope.structs.at(named_type->name); + + if (function_call.arguments.empty()) + return zero_value(context, *type); + + std::vector args; + for (auto const & expression : function_call.arguments) + args.push_back(eval(context, expression)); + + std::unordered_map fields; + + for (std::size_t i = 0; i < args.size(); ++i) { - if (function_call.arguments.empty()) - return zero_value(context, types::named_type{.name = identifier->name, .level = identifier->level}); - - if (jt->second.fields.size() != function_call.arguments.size()) - { - std::ostringstream os; - os << "Cannot create struct \"" << identifier->name << "\": expected " << jt->second.fields.size() << " fields, got " << function_call.arguments.size(); - throw std::runtime_error(os.str()); - } - - std::vector args; - for (auto const & expression : function_call.arguments) - args.push_back(eval(context, expression)); - - std::unordered_map fields; - - for (std::size_t i = 0; i < args.size(); ++i) - { - auto actual_type = type_of(args[i]); - if (!types::equal(actual_type, *jt->second.fields[i].type)) - { - std::ostringstream os; - os << "Cannot create struct \"" << identifier->name << "\": field " << jt->second.fields[i].name << " expects type "; - types::print(os, *jt->second.fields[i].type); - os << " but actual type is "; - types::print(os, actual_type); - throw std::runtime_error(os.str()); - } - - fields[jt->second.fields[i].name] = std::make_unique(std::move(args[i])); - } - - return struct_value{ - .struct_type = std::make_unique(types::named_type{.name = identifier->name, .level = identifier->level}), - .fields = std::move(fields), - }; + fields[data.fields[i].name] = std::make_unique(std::move(args[i])); } + + return struct_value{ + .struct_type = type, + .fields = std::move(fields), + }; } + else + throw std::runtime_error("Unknown type in constructor"); } - - auto lvalue = eval(context, function_call.function); - auto fvalue = std::get_if(&lvalue); - if (fvalue) - { - if (fvalue->arguments.size() != function_call.arguments.size()) - { - std::ostringstream os; - os << "Cannot call function: expected " << fvalue->arguments.size() << " arguments, got " << function_call.arguments.size(); - throw std::runtime_error(os.str()); - } - - std::vector args; - for (auto const & expression : function_call.arguments) - args.push_back(eval(context, expression)); - - for (std::size_t i = 0; i < args.size(); ++i) - { - auto actual_type = type_of(args[i]); - if (!types::equal(actual_type, *fvalue->arguments[i].type)) - { - std::ostringstream os; - os << "Cannot call function: argument #" << (i + 1) << " expects type "; - types::print(os, *fvalue->arguments[i].type); - os << " but actual type is "; - types::print(os, actual_type); - throw std::runtime_error(os.str()); - } - } - - auto & function_scope = context.scope_stack.emplace_back(); - function_scope.is_function_scope = true; - - for (std::size_t i = 0; i < args.size(); ++i) - function_scope.variables[fvalue->arguments[i].name] = {.category = ast::value_category::constant, .value = std::move(args[i])}; - - auto expected_return_type = fvalue->return_type; - - exec(context, fvalue->statements); - - auto actual_return_type = type_of(context.scope_stack.back().return_value); - if (!types::equal(actual_return_type, *expected_return_type)) - { - std::ostringstream os; - os << "Error returning from function: expected return type is "; - types::print(os, *expected_return_type); - os << " but actual type is "; - types::print(os, actual_return_type); - throw std::runtime_error(os.str()); - } - - auto result = std::move(context.scope_stack.back().return_value); - context.scope_stack.pop_back(); - return result; - } - - std::ostringstream os; - os << "Cannot call "; - print(os, lvalue); - os << ": not a function"; - throw std::runtime_error(os.str()); + else + throw std::runtime_error("Function call node has neither function nor type"); } value eval_impl(context & context, ast::array const & array) diff --git a/libs/parser/rules/pslang.y b/libs/parser/rules/pslang.y index f1c7e79..856c725 100644 --- a/libs/parser/rules/pslang.y +++ b/libs/parser/rules/pslang.y @@ -289,19 +289,19 @@ postfix_expression : base_expression | postfix_expression lbracket expression rbracket { $$ = ast::array_access{std::make_unique($1), std::make_unique($3), @$}; } | postfix_expression dot name { $$ = ast::field_access{std::make_unique($1), $3, @$}; } -| postfix_expression lparen comma_separated_expression_list rparen { $$ = ast::function_call{std::make_unique($1), $3, @$}; } -| unit lparen comma_separated_expression_list rparen { $$ = ast::function_call{std::make_unique(ast::identifier{"unit"}), $3, @$}; } -| bool lparen comma_separated_expression_list rparen { $$ = ast::function_call{std::make_unique(ast::identifier{"bool"}), $3, @$}; } -| i8 lparen comma_separated_expression_list rparen { $$ = ast::function_call{std::make_unique(ast::identifier{"i8"}), $3, @$}; } -| u8 lparen comma_separated_expression_list rparen { $$ = ast::function_call{std::make_unique(ast::identifier{"u8"}), $3, @$}; } -| i16 lparen comma_separated_expression_list rparen { $$ = ast::function_call{std::make_unique(ast::identifier{"i16"}), $3, @$}; } -| u16 lparen comma_separated_expression_list rparen { $$ = ast::function_call{std::make_unique(ast::identifier{"u16"}), $3, @$}; } -| i32 lparen comma_separated_expression_list rparen { $$ = ast::function_call{std::make_unique(ast::identifier{"i32"}), $3, @$}; } -| u32 lparen comma_separated_expression_list rparen { $$ = ast::function_call{std::make_unique(ast::identifier{"u32"}), $3, @$}; } -| i64 lparen comma_separated_expression_list rparen { $$ = ast::function_call{std::make_unique(ast::identifier{"i64"}), $3, @$}; } -| u64 lparen comma_separated_expression_list rparen { $$ = ast::function_call{std::make_unique(ast::identifier{"u64"}), $3, @$}; } -| f32 lparen comma_separated_expression_list rparen { $$ = ast::function_call{std::make_unique(ast::identifier{"f32"}), $3, @$}; } -| f64 lparen comma_separated_expression_list rparen { $$ = ast::function_call{std::make_unique(ast::identifier{"f64"}), $3, @$}; } +| postfix_expression lparen comma_separated_expression_list rparen { $$ = ast::function_call{std::make_unique($1), nullptr, $3, @$}; } +| unit lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique(types::unit_type{}), $3, @$}; } +| bool lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique(types::bool_type{}), $3, @$}; } +| i8 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique(types::i8_type{}), $3, @$}; } +| u8 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique(types::u8_type{}), $3, @$}; } +| i16 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique(types::i16_type{}), $3, @$}; } +| u16 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique(types::u16_type{}), $3, @$}; } +| i32 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique(types::i32_type{}), $3, @$}; } +| u32 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique(types::u32_type{}), $3, @$}; } +| i64 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique(types::i64_type{}), $3, @$}; } +| u64 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique(types::u64_type{}), $3, @$}; } +| f32 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique(types::f32_type{}), $3, @$}; } +| f64 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique(types::f64_type{}), $3, @$}; } ; base_expression diff --git a/libs/types/include/pslang/types/type_fwd.hpp b/libs/types/include/pslang/types/type_fwd.hpp index 6f74226..7197c70 100644 --- a/libs/types/include/pslang/types/type_fwd.hpp +++ b/libs/types/include/pslang/types/type_fwd.hpp @@ -21,5 +21,6 @@ namespace pslang::types bool is_unsigned_integer_type(type const & type); bool is_floating_point_type(type const & type); bool is_numeric_type(type const & type); + bool is_builtin_type(type const & type); } diff --git a/libs/types/source/type.cpp b/libs/types/source/type.cpp index e02755b..e78897d 100644 --- a/libs/types/source/type.cpp +++ b/libs/types/source/type.cpp @@ -113,4 +113,14 @@ namespace pslang::types return false; } + bool is_builtin_type(type const & type) + { + if (std::get_if(&type)) + return true; + if (std::get_if(&type)) + return true; + + return false; + } + }