From af815e215a80b157eb57efebb19126f8e0e8aa5b Mon Sep 17 00:00:00 2001 From: lisyarus Date: Thu, 18 Dec 2025 16:14:51 +0300 Subject: [PATCH] Add support for function values & calling non-identifier functions --- .../include/pslang/interpreter/context.hpp | 10 +- .../include/pslang/interpreter/value.hpp | 11 +- libs/interpreter/source/eval.cpp | 181 ++++++++++-------- libs/interpreter/source/exec.cpp | 9 +- libs/interpreter/source/value.cpp | 15 ++ 5 files changed, 127 insertions(+), 99 deletions(-) diff --git a/libs/interpreter/include/pslang/interpreter/context.hpp b/libs/interpreter/include/pslang/interpreter/context.hpp index 42d515e..ed82313 100644 --- a/libs/interpreter/include/pslang/interpreter/context.hpp +++ b/libs/interpreter/include/pslang/interpreter/context.hpp @@ -19,13 +19,6 @@ namespace pslang::interpreter interpreter::value value; }; - struct function_data - { - std::vector arguments; - type::type_ptr return_type; - ast::statement_list_ptr statements; - }; - struct struct_data { std::vector fields; @@ -34,12 +27,11 @@ namespace pslang::interpreter struct scope { std::unordered_map variables; - std::unordered_map functions; std::unordered_map structs; bool contains(std::string const & name) { - return variables.count(name) > 0 || functions.count(name) > 0 || structs.count(name) > 0; + return variables.count(name) > 0 || structs.count(name) > 0; } bool is_function_scope = false; diff --git a/libs/interpreter/include/pslang/interpreter/value.hpp b/libs/interpreter/include/pslang/interpreter/value.hpp index f737f2c..2de216c 100644 --- a/libs/interpreter/include/pslang/interpreter/value.hpp +++ b/libs/interpreter/include/pslang/interpreter/value.hpp @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -69,11 +70,19 @@ namespace pslang::interpreter std::unordered_map fields; }; + struct function_value + { + std::vector arguments; + type::type_ptr return_type; + ast::statement_list_ptr statements; + }; + using value_impl = std::variant< unit_value, primitive_value, array_value, - struct_value + struct_value, + function_value >; struct value diff --git a/libs/interpreter/source/eval.cpp b/libs/interpreter/source/eval.cpp index 3a7224b..f8fe084 100644 --- a/libs/interpreter/source/eval.cpp +++ b/libs/interpreter/source/eval.cpp @@ -491,6 +491,11 @@ namespace pslang::interpreter throw std::runtime_error("Cannot cast struct type to anything"); } + value cast_impl(function_value const &, type::type const &) + { + throw std::runtime_error("Cannot cast function type to anything"); + } + value eval_impl(context & context, ast::cast_operation const & cast_operation) { auto arg = eval(context, cast_operation.expression); @@ -500,103 +505,109 @@ namespace pslang::interpreter value eval_impl(context & context, ast::function_call const & function_call) { auto identifier = std::get_if(function_call.function.get()); - if (!identifier) - throw std::runtime_error("Calling non-identifier functions is not implemented"); - - for (auto it = context.scope_stack.rbegin(); it != context.scope_stack.rend(); ++it) + if (identifier) { - if (auto jt = it->functions.find(identifier->name); jt != it->functions.end()) + for (auto it = context.scope_stack.rbegin(); it != context.scope_stack.rend(); ++it) { - if (jt->second.arguments.size() != function_call.arguments.size()) + if (auto jt = it->structs.find(identifier->name); jt != it->structs.end()) { - std::ostringstream os; - os << "Cannot call function \"" << identifier->name << "\": expected " << jt->second.arguments.size() << " arguments, got " << function_call.arguments.size(); - throw std::runtime_error(os.str()); - } - - std::vector args; - for (auto const & expression : function_call.arguments) - args.push_back(eval(context, expression)); - - for (std::size_t i = 0; i < args.size(); ++i) - { - auto actual_type = type_of(args[i]); - if (!type::equal(actual_type, *jt->second.arguments[i].type)) + if (jt->second.fields.size() != function_call.arguments.size()) { std::ostringstream os; - os << "Cannot call function \"" << identifier->name << "\": argument #" << (i + 1) << " expects type "; - type::print(os, *jt->second.arguments[i].type); - os << " but actual type is "; - type::print(os, actual_type); - throw std::runtime_error(os.str()); - } - } - - auto & function_scope = context.scope_stack.emplace_back(); - function_scope.is_function_scope = true; - - for (std::size_t i = 0; i < args.size(); ++i) - function_scope.variables[jt->second.arguments[i].name] = {.category = ast::value_category::constant, .value = std::move(args[i])}; - - auto expected_return_type = jt->second.return_type; - - exec(context, jt->second.statements); - - auto actual_return_type = type_of(context.scope_stack.back().return_value); - if (!type::equal(actual_return_type, *expected_return_type)) - { - std::ostringstream os; - os << "Error returning from function \"" << identifier->name << "\": expected return type is "; - type::print(os, *expected_return_type); - os << " but actual type is "; - type::print(os, actual_return_type); - throw std::runtime_error(os.str()); - } - - auto result = std::move(context.scope_stack.back().return_value); - context.scope_stack.pop_back(); - return result; - } - - if (auto jt = it->structs.find(identifier->name); jt != it->structs.end()) - { - if (jt->second.fields.size() != function_call.arguments.size()) - { - std::ostringstream os; - os << "Cannot create struct \"" << identifier->name << "\": expected " << jt->second.fields.size() << " fields, got " << function_call.arguments.size(); - throw std::runtime_error(os.str()); - } - - std::vector args; - for (auto const & expression : function_call.arguments) - args.push_back(eval(context, expression)); - - std::unordered_map fields; - - for (std::size_t i = 0; i < args.size(); ++i) - { - auto actual_type = type_of(args[i]); - if (!type::equal(actual_type, *jt->second.fields[i].type)) - { - std::ostringstream os; - os << "Cannot create struct \"" << identifier->name << "\": field " << jt->second.fields[i].name << " expects type "; - type::print(os, *jt->second.fields[i].type); - os << " but actual type is "; - type::print(os, actual_type); + os << "Cannot create struct \"" << identifier->name << "\": expected " << jt->second.fields.size() << " fields, got " << function_call.arguments.size(); throw std::runtime_error(os.str()); } - fields[jt->second.fields[i].name] = std::make_unique(std::move(args[i])); - } + std::vector args; + for (auto const & expression : function_call.arguments) + args.push_back(eval(context, expression)); - return struct_value{ - .struct_type = std::make_unique(resolve_type(context, type::identifier{identifier->name})), - .fields = std::move(fields), - }; + std::unordered_map fields; + + for (std::size_t i = 0; i < args.size(); ++i) + { + auto actual_type = type_of(args[i]); + if (!type::equal(actual_type, *jt->second.fields[i].type)) + { + std::ostringstream os; + os << "Cannot create struct \"" << identifier->name << "\": field " << jt->second.fields[i].name << " expects type "; + type::print(os, *jt->second.fields[i].type); + os << " but actual type is "; + type::print(os, actual_type); + throw std::runtime_error(os.str()); + } + + fields[jt->second.fields[i].name] = std::make_unique(std::move(args[i])); + } + + return struct_value{ + .struct_type = std::make_unique(resolve_type(context, type::identifier{identifier->name})), + .fields = std::move(fields), + }; + } } } - throw std::runtime_error("Function \"" + identifier->name + "\" is not defined"); + auto lvalue = eval(context, function_call.function); + auto fvalue = std::get_if(&lvalue); + if (fvalue) + { + if (fvalue->arguments.size() != function_call.arguments.size()) + { + std::ostringstream os; + os << "Cannot call function: expected " << fvalue->arguments.size() << " arguments, got " << function_call.arguments.size(); + throw std::runtime_error(os.str()); + } + + std::vector args; + for (auto const & expression : function_call.arguments) + args.push_back(eval(context, expression)); + + for (std::size_t i = 0; i < args.size(); ++i) + { + auto actual_type = type_of(args[i]); + if (!type::equal(actual_type, *fvalue->arguments[i].type)) + { + std::ostringstream os; + os << "Cannot call function: argument #" << (i + 1) << " expects type "; + type::print(os, *fvalue->arguments[i].type); + os << " but actual type is "; + type::print(os, actual_type); + throw std::runtime_error(os.str()); + } + } + + auto & function_scope = context.scope_stack.emplace_back(); + function_scope.is_function_scope = true; + + for (std::size_t i = 0; i < args.size(); ++i) + function_scope.variables[fvalue->arguments[i].name] = {.category = ast::value_category::constant, .value = std::move(args[i])}; + + auto expected_return_type = fvalue->return_type; + + exec(context, fvalue->statements); + + auto actual_return_type = type_of(context.scope_stack.back().return_value); + if (!type::equal(actual_return_type, *expected_return_type)) + { + std::ostringstream os; + os << "Error returning from function: expected return type is "; + type::print(os, *expected_return_type); + os << " but actual type is "; + type::print(os, actual_return_type); + throw std::runtime_error(os.str()); + } + + auto result = std::move(context.scope_stack.back().return_value); + context.scope_stack.pop_back(); + return result; + } + + std::ostringstream os; + os << "Cannot call "; + print(os, lvalue); + os << ": not a function"; + throw std::runtime_error(os.str()); } value eval_impl(context & context, ast::array const & array) diff --git a/libs/interpreter/source/exec.cpp b/libs/interpreter/source/exec.cpp index 38b28a6..b82ce8a 100644 --- a/libs/interpreter/source/exec.cpp +++ b/libs/interpreter/source/exec.cpp @@ -159,13 +159,14 @@ namespace pslang::interpreter if (scope.contains(function_definition.name)) throw std::runtime_error("Identifier \"" + function_definition.name + "\" is already defined in this scope"); - auto & function = scope.functions[function_definition.name]; + function_value value; for (auto const & argument : function_definition.arguments) - function.arguments.push_back({.name = argument.name, .type = std::make_unique(resolve_type(context, *argument.type))}); + value.arguments.push_back({.name = argument.name, .type = std::make_unique(resolve_type(context, *argument.type))}); + value.return_type = std::make_unique(resolve_type(context, *function_definition.return_type)); + value.statements = function_definition.statements; - function.return_type = std::make_unique(resolve_type(context, *function_definition.return_type)); - function.statements = function_definition.statements; + scope.variables[function_definition.name] = {.category = ast::value_category::constant, .value = std::move(value)}; } void exec_impl(context & context, ast::return_statement const & return_statement) diff --git a/libs/interpreter/source/value.cpp b/libs/interpreter/source/value.cpp index fb8f65c..fcae499 100644 --- a/libs/interpreter/source/value.cpp +++ b/libs/interpreter/source/value.cpp @@ -1,4 +1,5 @@ #include +#include #include @@ -34,6 +35,15 @@ namespace pslang::interpreter return *value.struct_type; } + type::type type_of_impl(function_value const & value) + { + type::function_type result; + for (auto const & argument : value.arguments) + result.arguments.push_back(argument.type); + result.result = value.return_type; + return result; + } + type::type type_of_impl(value const & value) { return std::visit([](auto const & value){ return type_of_impl(value); }, value); @@ -105,6 +115,11 @@ namespace pslang::interpreter out << "}"; } + void print_impl(std::ostream & out, function_value const & value) + { + out << "func"; + } + void print_impl(std::ostream & out, value const & value) { std::visit([&](auto const & value){ return print_impl(out, value); }, value);