Functions can be any expressions and not just identifiers in function calls (AST only, not implemented in interpreter yet)

This commit is contained in:
Nikita Lisitsa 2025-12-18 14:37:06 +03:00
parent 4867c970d8
commit 53d4e12a09
4 changed files with 18 additions and 12 deletions

View file

@ -29,7 +29,7 @@ namespace pslang::ast
struct function_call
{
std::string name;
expression_ptr function;
std::vector<expression_ptr> arguments;
ast::location location;
};

View file

@ -125,7 +125,8 @@ namespace pslang::ast
void print_impl(std::ostream & out, function_call const & node, print_options const & options)
{
put_indent(out, options);
out << "call " << node.name << '\n';
out << "call\n";
print(out, node.function, child(options));
for (auto const & argument : node.arguments)
print(out, argument, child(options));
}

View file

@ -490,14 +490,18 @@ namespace pslang::interpreter
value eval_impl(context & context, ast::function_call const & function_call)
{
auto identifier = std::get_if<ast::identifier>(function_call.function.get());
if (!identifier)
throw std::runtime_error("Calling non-identifier functions is not implemented");
for (auto it = context.scope_stack.rbegin(); it != context.scope_stack.rend(); ++it)
{
if (auto jt = it->functions.find(function_call.name); jt != it->functions.end())
if (auto jt = it->functions.find(identifier->name); jt != it->functions.end())
{
if (jt->second.arguments.size() != function_call.arguments.size())
{
std::ostringstream os;
os << "Cannot call function \"" << function_call.name << "\": expected " << jt->second.arguments.size() << " arguments, got " << function_call.arguments.size();
os << "Cannot call function \"" << identifier->name << "\": expected " << jt->second.arguments.size() << " arguments, got " << function_call.arguments.size();
throw std::runtime_error(os.str());
}
@ -511,7 +515,7 @@ namespace pslang::interpreter
if (!type::equal(actual_type, *jt->second.arguments[i].type))
{
std::ostringstream os;
os << "Cannot call function \"" << function_call.name << "\": argument #" << (i + 1) << " expects type ";
os << "Cannot call function \"" << identifier->name << "\": argument #" << (i + 1) << " expects type ";
type::print(os, *jt->second.arguments[i].type);
os << " but actual type is ";
type::print(os, actual_type);
@ -533,7 +537,7 @@ namespace pslang::interpreter
if (!type::equal(actual_return_type, *expected_return_type))
{
std::ostringstream os;
os << "Error returning from function \"" << function_call.name << "\": expected return type is ";
os << "Error returning from function \"" << identifier->name << "\": expected return type is ";
type::print(os, *expected_return_type);
os << " but actual type is ";
type::print(os, actual_return_type);
@ -545,12 +549,12 @@ namespace pslang::interpreter
return result;
}
if (auto jt = it->structs.find(function_call.name); jt != it->structs.end())
if (auto jt = it->structs.find(identifier->name); jt != it->structs.end())
{
if (jt->second.fields.size() != function_call.arguments.size())
{
std::ostringstream os;
os << "Cannot create struct \"" << function_call.name << "\": expected " << jt->second.fields.size() << " fields, got " << function_call.arguments.size();
os << "Cannot create struct \"" << identifier->name << "\": expected " << jt->second.fields.size() << " fields, got " << function_call.arguments.size();
throw std::runtime_error(os.str());
}
@ -566,7 +570,7 @@ namespace pslang::interpreter
if (!type::equal(actual_type, *jt->second.fields[i].type))
{
std::ostringstream os;
os << "Cannot create struct \"" << function_call.name << "\": field " << jt->second.fields[i].name << " expects type ";
os << "Cannot create struct \"" << identifier->name << "\": field " << jt->second.fields[i].name << " expects type ";
type::print(os, *jt->second.fields[i].type);
os << " but actual type is ";
type::print(os, actual_type);
@ -577,13 +581,13 @@ namespace pslang::interpreter
}
return struct_value{
.struct_type = std::make_unique<type::type>(resolve_type(context, type::identifier{function_call.name})),
.struct_type = std::make_unique<type::type>(resolve_type(context, type::identifier{identifier->name})),
.fields = std::move(fields),
};
}
}
throw std::runtime_error("Function \"" + function_call.name + "\" is not defined");
throw std::runtime_error("Function \"" + identifier->name + "\" is not defined");
}
value eval_impl(context & context, ast::array const & array)

View file

@ -260,13 +260,14 @@ postfix_expression
: base_expression
| postfix_expression lbracket expression rbracket { $$ = ast::array_access{std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3), @$}; }
| postfix_expression dot name { $$ = ast::field_access{std::make_unique<ast::expression>($1), $3, @$}; }
//| postfix_expression dot name lparen comma_separated_expression_list rparen { auto args = $5; args.insert(args.begin(), std::make_unique<ast::expression>($1)); $$ = ast::function_call{$3, std::move(args), @$}; }
| postfix_expression lparen comma_separated_expression_list rparen { $$ = ast::function_call{std::make_unique<ast::expression>($1), $3, @$}; }
;
base_expression
: literal
| name { $$ = ast::identifier{$1, ctx.location}; }
| lparen expression rparen { $$ = $2; }
| name lparen comma_separated_expression_list rparen { $$ = ast::function_call{$1, $3, @$}; }
| lbracket nonempty_comma_separated_expression_list rbracket { $$ = ast::array{$2, @$}; }
;