Add support for function values & calling non-identifier functions

This commit is contained in:
Nikita Lisitsa 2025-12-18 16:14:51 +03:00
parent 61f1a9c079
commit af815e215a
5 changed files with 127 additions and 99 deletions

View file

@ -19,13 +19,6 @@ namespace pslang::interpreter
interpreter::value value; interpreter::value value;
}; };
struct function_data
{
std::vector<ast::function_definition::argument> arguments;
type::type_ptr return_type;
ast::statement_list_ptr statements;
};
struct struct_data struct struct_data
{ {
std::vector<ast::field_definition> fields; std::vector<ast::field_definition> fields;
@ -34,12 +27,11 @@ namespace pslang::interpreter
struct scope struct scope
{ {
std::unordered_map<std::string, variable_data> variables; std::unordered_map<std::string, variable_data> variables;
std::unordered_map<std::string, function_data> functions;
std::unordered_map<std::string, struct_data> structs; std::unordered_map<std::string, struct_data> structs;
bool contains(std::string const & name) bool contains(std::string const & name)
{ {
return variables.count(name) > 0 || functions.count(name) > 0 || structs.count(name) > 0; return variables.count(name) > 0 || structs.count(name) > 0;
} }
bool is_function_scope = false; bool is_function_scope = false;

View file

@ -1,6 +1,7 @@
#pragma once #pragma once
#include <pslang/type/type.hpp> #include <pslang/type/type.hpp>
#include <pslang/ast/function.hpp>
#include <pslang/interpreter/value_fwd.hpp> #include <pslang/interpreter/value_fwd.hpp>
#include <cstdint> #include <cstdint>
@ -69,11 +70,19 @@ namespace pslang::interpreter
std::unordered_map<std::string, value_ptr> fields; std::unordered_map<std::string, value_ptr> fields;
}; };
struct function_value
{
std::vector<ast::function_definition::argument> arguments;
type::type_ptr return_type;
ast::statement_list_ptr statements;
};
using value_impl = std::variant< using value_impl = std::variant<
unit_value, unit_value,
primitive_value, primitive_value,
array_value, array_value,
struct_value struct_value,
function_value
>; >;
struct value struct value

View file

@ -491,6 +491,11 @@ namespace pslang::interpreter
throw std::runtime_error("Cannot cast struct type to anything"); throw std::runtime_error("Cannot cast struct type to anything");
} }
value cast_impl(function_value const &, type::type const &)
{
throw std::runtime_error("Cannot cast function type to anything");
}
value eval_impl(context & context, ast::cast_operation const & cast_operation) value eval_impl(context & context, ast::cast_operation const & cast_operation)
{ {
auto arg = eval(context, cast_operation.expression); auto arg = eval(context, cast_operation.expression);
@ -500,103 +505,109 @@ namespace pslang::interpreter
value eval_impl(context & context, ast::function_call const & function_call) value eval_impl(context & context, ast::function_call const & function_call)
{ {
auto identifier = std::get_if<ast::identifier>(function_call.function.get()); auto identifier = std::get_if<ast::identifier>(function_call.function.get());
if (!identifier) if (identifier)
throw std::runtime_error("Calling non-identifier functions is not implemented");
for (auto it = context.scope_stack.rbegin(); it != context.scope_stack.rend(); ++it)
{ {
if (auto jt = it->functions.find(identifier->name); jt != it->functions.end()) for (auto it = context.scope_stack.rbegin(); it != context.scope_stack.rend(); ++it)
{ {
if (jt->second.arguments.size() != function_call.arguments.size()) if (auto jt = it->structs.find(identifier->name); jt != it->structs.end())
{ {
std::ostringstream os; if (jt->second.fields.size() != function_call.arguments.size())
os << "Cannot call function \"" << identifier->name << "\": expected " << jt->second.arguments.size() << " arguments, got " << function_call.arguments.size();
throw std::runtime_error(os.str());
}
std::vector<value> args;
for (auto const & expression : function_call.arguments)
args.push_back(eval(context, expression));
for (std::size_t i = 0; i < args.size(); ++i)
{
auto actual_type = type_of(args[i]);
if (!type::equal(actual_type, *jt->second.arguments[i].type))
{ {
std::ostringstream os; std::ostringstream os;
os << "Cannot call function \"" << identifier->name << "\": argument #" << (i + 1) << " expects type "; os << "Cannot create struct \"" << identifier->name << "\": expected " << jt->second.fields.size() << " fields, got " << function_call.arguments.size();
type::print(os, *jt->second.arguments[i].type);
os << " but actual type is ";
type::print(os, actual_type);
throw std::runtime_error(os.str());
}
}
auto & function_scope = context.scope_stack.emplace_back();
function_scope.is_function_scope = true;
for (std::size_t i = 0; i < args.size(); ++i)
function_scope.variables[jt->second.arguments[i].name] = {.category = ast::value_category::constant, .value = std::move(args[i])};
auto expected_return_type = jt->second.return_type;
exec(context, jt->second.statements);
auto actual_return_type = type_of(context.scope_stack.back().return_value);
if (!type::equal(actual_return_type, *expected_return_type))
{
std::ostringstream os;
os << "Error returning from function \"" << identifier->name << "\": expected return type is ";
type::print(os, *expected_return_type);
os << " but actual type is ";
type::print(os, actual_return_type);
throw std::runtime_error(os.str());
}
auto result = std::move(context.scope_stack.back().return_value);
context.scope_stack.pop_back();
return result;
}
if (auto jt = it->structs.find(identifier->name); jt != it->structs.end())
{
if (jt->second.fields.size() != function_call.arguments.size())
{
std::ostringstream os;
os << "Cannot create struct \"" << identifier->name << "\": expected " << jt->second.fields.size() << " fields, got " << function_call.arguments.size();
throw std::runtime_error(os.str());
}
std::vector<value> args;
for (auto const & expression : function_call.arguments)
args.push_back(eval(context, expression));
std::unordered_map<std::string, value_ptr> fields;
for (std::size_t i = 0; i < args.size(); ++i)
{
auto actual_type = type_of(args[i]);
if (!type::equal(actual_type, *jt->second.fields[i].type))
{
std::ostringstream os;
os << "Cannot create struct \"" << identifier->name << "\": field " << jt->second.fields[i].name << " expects type ";
type::print(os, *jt->second.fields[i].type);
os << " but actual type is ";
type::print(os, actual_type);
throw std::runtime_error(os.str()); throw std::runtime_error(os.str());
} }
fields[jt->second.fields[i].name] = std::make_unique<value>(std::move(args[i])); std::vector<value> args;
} for (auto const & expression : function_call.arguments)
args.push_back(eval(context, expression));
return struct_value{ std::unordered_map<std::string, value_ptr> fields;
.struct_type = std::make_unique<type::type>(resolve_type(context, type::identifier{identifier->name})),
.fields = std::move(fields), for (std::size_t i = 0; i < args.size(); ++i)
}; {
auto actual_type = type_of(args[i]);
if (!type::equal(actual_type, *jt->second.fields[i].type))
{
std::ostringstream os;
os << "Cannot create struct \"" << identifier->name << "\": field " << jt->second.fields[i].name << " expects type ";
type::print(os, *jt->second.fields[i].type);
os << " but actual type is ";
type::print(os, actual_type);
throw std::runtime_error(os.str());
}
fields[jt->second.fields[i].name] = std::make_unique<value>(std::move(args[i]));
}
return struct_value{
.struct_type = std::make_unique<type::type>(resolve_type(context, type::identifier{identifier->name})),
.fields = std::move(fields),
};
}
} }
} }
throw std::runtime_error("Function \"" + identifier->name + "\" is not defined"); auto lvalue = eval(context, function_call.function);
auto fvalue = std::get_if<function_value>(&lvalue);
if (fvalue)
{
if (fvalue->arguments.size() != function_call.arguments.size())
{
std::ostringstream os;
os << "Cannot call function: expected " << fvalue->arguments.size() << " arguments, got " << function_call.arguments.size();
throw std::runtime_error(os.str());
}
std::vector<value> args;
for (auto const & expression : function_call.arguments)
args.push_back(eval(context, expression));
for (std::size_t i = 0; i < args.size(); ++i)
{
auto actual_type = type_of(args[i]);
if (!type::equal(actual_type, *fvalue->arguments[i].type))
{
std::ostringstream os;
os << "Cannot call function: argument #" << (i + 1) << " expects type ";
type::print(os, *fvalue->arguments[i].type);
os << " but actual type is ";
type::print(os, actual_type);
throw std::runtime_error(os.str());
}
}
auto & function_scope = context.scope_stack.emplace_back();
function_scope.is_function_scope = true;
for (std::size_t i = 0; i < args.size(); ++i)
function_scope.variables[fvalue->arguments[i].name] = {.category = ast::value_category::constant, .value = std::move(args[i])};
auto expected_return_type = fvalue->return_type;
exec(context, fvalue->statements);
auto actual_return_type = type_of(context.scope_stack.back().return_value);
if (!type::equal(actual_return_type, *expected_return_type))
{
std::ostringstream os;
os << "Error returning from function: expected return type is ";
type::print(os, *expected_return_type);
os << " but actual type is ";
type::print(os, actual_return_type);
throw std::runtime_error(os.str());
}
auto result = std::move(context.scope_stack.back().return_value);
context.scope_stack.pop_back();
return result;
}
std::ostringstream os;
os << "Cannot call ";
print(os, lvalue);
os << ": not a function";
throw std::runtime_error(os.str());
} }
value eval_impl(context & context, ast::array const & array) value eval_impl(context & context, ast::array const & array)

View file

@ -159,13 +159,14 @@ namespace pslang::interpreter
if (scope.contains(function_definition.name)) if (scope.contains(function_definition.name))
throw std::runtime_error("Identifier \"" + function_definition.name + "\" is already defined in this scope"); throw std::runtime_error("Identifier \"" + function_definition.name + "\" is already defined in this scope");
auto & function = scope.functions[function_definition.name]; function_value value;
for (auto const & argument : function_definition.arguments) for (auto const & argument : function_definition.arguments)
function.arguments.push_back({.name = argument.name, .type = std::make_unique<type::type>(resolve_type(context, *argument.type))}); value.arguments.push_back({.name = argument.name, .type = std::make_unique<type::type>(resolve_type(context, *argument.type))});
value.return_type = std::make_unique<type::type>(resolve_type(context, *function_definition.return_type));
value.statements = function_definition.statements;
function.return_type = std::make_unique<type::type>(resolve_type(context, *function_definition.return_type)); scope.variables[function_definition.name] = {.category = ast::value_category::constant, .value = std::move(value)};
function.statements = function_definition.statements;
} }
void exec_impl(context & context, ast::return_statement const & return_statement) void exec_impl(context & context, ast::return_statement const & return_statement)

View file

@ -1,4 +1,5 @@
#include <pslang/interpreter/value.hpp> #include <pslang/interpreter/value.hpp>
#include <pslang/type/print.hpp>
#include <iomanip> #include <iomanip>
@ -34,6 +35,15 @@ namespace pslang::interpreter
return *value.struct_type; return *value.struct_type;
} }
type::type type_of_impl(function_value const & value)
{
type::function_type result;
for (auto const & argument : value.arguments)
result.arguments.push_back(argument.type);
result.result = value.return_type;
return result;
}
type::type type_of_impl(value const & value) type::type type_of_impl(value const & value)
{ {
return std::visit([](auto const & value){ return type_of_impl(value); }, value); return std::visit([](auto const & value){ return type_of_impl(value); }, value);
@ -105,6 +115,11 @@ namespace pslang::interpreter
out << "}"; out << "}";
} }
void print_impl(std::ostream & out, function_value const & value)
{
out << "func";
}
void print_impl(std::ostream & out, value const & value) void print_impl(std::ostream & out, value const & value)
{ {
std::visit([&](auto const & value){ return print_impl(out, value); }, value); std::visit([&](auto const & value){ return print_impl(out, value); }, value);