Functions can be any expressions and not just identifiers in function calls (AST only, not implemented in interpreter yet)

This commit is contained in:
Nikita Lisitsa 2025-12-18 14:37:06 +03:00
parent 4867c970d8
commit 53d4e12a09
4 changed files with 18 additions and 12 deletions

View file

@ -29,7 +29,7 @@ namespace pslang::ast
struct function_call struct function_call
{ {
std::string name; expression_ptr function;
std::vector<expression_ptr> arguments; std::vector<expression_ptr> arguments;
ast::location location; ast::location location;
}; };

View file

@ -125,7 +125,8 @@ namespace pslang::ast
void print_impl(std::ostream & out, function_call const & node, print_options const & options) void print_impl(std::ostream & out, function_call const & node, print_options const & options)
{ {
put_indent(out, options); put_indent(out, options);
out << "call " << node.name << '\n'; out << "call\n";
print(out, node.function, child(options));
for (auto const & argument : node.arguments) for (auto const & argument : node.arguments)
print(out, argument, child(options)); print(out, argument, child(options));
} }

View file

@ -490,14 +490,18 @@ namespace pslang::interpreter
value eval_impl(context & context, ast::function_call const & function_call) value eval_impl(context & context, ast::function_call const & function_call)
{ {
auto identifier = std::get_if<ast::identifier>(function_call.function.get());
if (!identifier)
throw std::runtime_error("Calling non-identifier functions is not implemented");
for (auto it = context.scope_stack.rbegin(); it != context.scope_stack.rend(); ++it) for (auto it = context.scope_stack.rbegin(); it != context.scope_stack.rend(); ++it)
{ {
if (auto jt = it->functions.find(function_call.name); jt != it->functions.end()) if (auto jt = it->functions.find(identifier->name); jt != it->functions.end())
{ {
if (jt->second.arguments.size() != function_call.arguments.size()) if (jt->second.arguments.size() != function_call.arguments.size())
{ {
std::ostringstream os; std::ostringstream os;
os << "Cannot call function \"" << function_call.name << "\": expected " << jt->second.arguments.size() << " arguments, got " << function_call.arguments.size(); os << "Cannot call function \"" << identifier->name << "\": expected " << jt->second.arguments.size() << " arguments, got " << function_call.arguments.size();
throw std::runtime_error(os.str()); throw std::runtime_error(os.str());
} }
@ -511,7 +515,7 @@ namespace pslang::interpreter
if (!type::equal(actual_type, *jt->second.arguments[i].type)) if (!type::equal(actual_type, *jt->second.arguments[i].type))
{ {
std::ostringstream os; std::ostringstream os;
os << "Cannot call function \"" << function_call.name << "\": argument #" << (i + 1) << " expects type "; os << "Cannot call function \"" << identifier->name << "\": argument #" << (i + 1) << " expects type ";
type::print(os, *jt->second.arguments[i].type); type::print(os, *jt->second.arguments[i].type);
os << " but actual type is "; os << " but actual type is ";
type::print(os, actual_type); type::print(os, actual_type);
@ -533,7 +537,7 @@ namespace pslang::interpreter
if (!type::equal(actual_return_type, *expected_return_type)) if (!type::equal(actual_return_type, *expected_return_type))
{ {
std::ostringstream os; std::ostringstream os;
os << "Error returning from function \"" << function_call.name << "\": expected return type is "; os << "Error returning from function \"" << identifier->name << "\": expected return type is ";
type::print(os, *expected_return_type); type::print(os, *expected_return_type);
os << " but actual type is "; os << " but actual type is ";
type::print(os, actual_return_type); type::print(os, actual_return_type);
@ -545,12 +549,12 @@ namespace pslang::interpreter
return result; return result;
} }
if (auto jt = it->structs.find(function_call.name); jt != it->structs.end()) if (auto jt = it->structs.find(identifier->name); jt != it->structs.end())
{ {
if (jt->second.fields.size() != function_call.arguments.size()) if (jt->second.fields.size() != function_call.arguments.size())
{ {
std::ostringstream os; std::ostringstream os;
os << "Cannot create struct \"" << function_call.name << "\": expected " << jt->second.fields.size() << " fields, got " << function_call.arguments.size(); os << "Cannot create struct \"" << identifier->name << "\": expected " << jt->second.fields.size() << " fields, got " << function_call.arguments.size();
throw std::runtime_error(os.str()); throw std::runtime_error(os.str());
} }
@ -566,7 +570,7 @@ namespace pslang::interpreter
if (!type::equal(actual_type, *jt->second.fields[i].type)) if (!type::equal(actual_type, *jt->second.fields[i].type))
{ {
std::ostringstream os; std::ostringstream os;
os << "Cannot create struct \"" << function_call.name << "\": field " << jt->second.fields[i].name << " expects type "; os << "Cannot create struct \"" << identifier->name << "\": field " << jt->second.fields[i].name << " expects type ";
type::print(os, *jt->second.fields[i].type); type::print(os, *jt->second.fields[i].type);
os << " but actual type is "; os << " but actual type is ";
type::print(os, actual_type); type::print(os, actual_type);
@ -577,13 +581,13 @@ namespace pslang::interpreter
} }
return struct_value{ return struct_value{
.struct_type = std::make_unique<type::type>(resolve_type(context, type::identifier{function_call.name})), .struct_type = std::make_unique<type::type>(resolve_type(context, type::identifier{identifier->name})),
.fields = std::move(fields), .fields = std::move(fields),
}; };
} }
} }
throw std::runtime_error("Function \"" + function_call.name + "\" is not defined"); throw std::runtime_error("Function \"" + identifier->name + "\" is not defined");
} }
value eval_impl(context & context, ast::array const & array) value eval_impl(context & context, ast::array const & array)

View file

@ -260,13 +260,14 @@ postfix_expression
: base_expression : base_expression
| postfix_expression lbracket expression rbracket { $$ = ast::array_access{std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3), @$}; } | postfix_expression lbracket expression rbracket { $$ = ast::array_access{std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3), @$}; }
| postfix_expression dot name { $$ = ast::field_access{std::make_unique<ast::expression>($1), $3, @$}; } | postfix_expression dot name { $$ = ast::field_access{std::make_unique<ast::expression>($1), $3, @$}; }
//| postfix_expression dot name lparen comma_separated_expression_list rparen { auto args = $5; args.insert(args.begin(), std::make_unique<ast::expression>($1)); $$ = ast::function_call{$3, std::move(args), @$}; }
| postfix_expression lparen comma_separated_expression_list rparen { $$ = ast::function_call{std::make_unique<ast::expression>($1), $3, @$}; }
; ;
base_expression base_expression
: literal : literal
| name { $$ = ast::identifier{$1, ctx.location}; } | name { $$ = ast::identifier{$1, ctx.location}; }
| lparen expression rparen { $$ = $2; } | lparen expression rparen { $$ = $2; }
| name lparen comma_separated_expression_list rparen { $$ = ast::function_call{$1, $3, @$}; }
| lbracket nonempty_comma_separated_expression_list rbracket { $$ = ast::array{$2, @$}; } | lbracket nonempty_comma_separated_expression_list rbracket { $$ = ast::array{$2, @$}; }
; ;