diff --git a/libs/ast/include/pslang/ast/function.hpp b/libs/ast/include/pslang/ast/function.hpp index b540dfe..c832577 100644 --- a/libs/ast/include/pslang/ast/function.hpp +++ b/libs/ast/include/pslang/ast/function.hpp @@ -29,7 +29,7 @@ namespace pslang::ast struct function_call { - std::string name; + expression_ptr function; std::vector arguments; ast::location location; }; diff --git a/libs/ast/source/print.cpp b/libs/ast/source/print.cpp index d9eb338..c5f5033 100644 --- a/libs/ast/source/print.cpp +++ b/libs/ast/source/print.cpp @@ -125,7 +125,8 @@ namespace pslang::ast void print_impl(std::ostream & out, function_call const & node, print_options const & options) { put_indent(out, options); - out << "call " << node.name << '\n'; + out << "call\n"; + print(out, node.function, child(options)); for (auto const & argument : node.arguments) print(out, argument, child(options)); } diff --git a/libs/interpreter/source/eval.cpp b/libs/interpreter/source/eval.cpp index fa68140..8e895ba 100644 --- a/libs/interpreter/source/eval.cpp +++ b/libs/interpreter/source/eval.cpp @@ -490,14 +490,18 @@ namespace pslang::interpreter value eval_impl(context & context, ast::function_call const & function_call) { + auto identifier = std::get_if(function_call.function.get()); + if (!identifier) + throw std::runtime_error("Calling non-identifier functions is not implemented"); + for (auto it = context.scope_stack.rbegin(); it != context.scope_stack.rend(); ++it) { - if (auto jt = it->functions.find(function_call.name); jt != it->functions.end()) + if (auto jt = it->functions.find(identifier->name); jt != it->functions.end()) { if (jt->second.arguments.size() != function_call.arguments.size()) { std::ostringstream os; - os << "Cannot call function \"" << function_call.name << "\": expected " << jt->second.arguments.size() << " arguments, got " << function_call.arguments.size(); + os << "Cannot call function \"" << identifier->name << "\": expected " << jt->second.arguments.size() << " arguments, got " << function_call.arguments.size(); throw std::runtime_error(os.str()); } @@ -511,7 +515,7 @@ namespace pslang::interpreter if (!type::equal(actual_type, *jt->second.arguments[i].type)) { std::ostringstream os; - os << "Cannot call function \"" << function_call.name << "\": argument #" << (i + 1) << " expects type "; + os << "Cannot call function \"" << identifier->name << "\": argument #" << (i + 1) << " expects type "; type::print(os, *jt->second.arguments[i].type); os << " but actual type is "; type::print(os, actual_type); @@ -533,7 +537,7 @@ namespace pslang::interpreter if (!type::equal(actual_return_type, *expected_return_type)) { std::ostringstream os; - os << "Error returning from function \"" << function_call.name << "\": expected return type is "; + os << "Error returning from function \"" << identifier->name << "\": expected return type is "; type::print(os, *expected_return_type); os << " but actual type is "; type::print(os, actual_return_type); @@ -545,12 +549,12 @@ namespace pslang::interpreter return result; } - if (auto jt = it->structs.find(function_call.name); jt != it->structs.end()) + if (auto jt = it->structs.find(identifier->name); jt != it->structs.end()) { if (jt->second.fields.size() != function_call.arguments.size()) { std::ostringstream os; - os << "Cannot create struct \"" << function_call.name << "\": expected " << jt->second.fields.size() << " fields, got " << function_call.arguments.size(); + os << "Cannot create struct \"" << identifier->name << "\": expected " << jt->second.fields.size() << " fields, got " << function_call.arguments.size(); throw std::runtime_error(os.str()); } @@ -566,7 +570,7 @@ namespace pslang::interpreter if (!type::equal(actual_type, *jt->second.fields[i].type)) { std::ostringstream os; - os << "Cannot create struct \"" << function_call.name << "\": field " << jt->second.fields[i].name << " expects type "; + os << "Cannot create struct \"" << identifier->name << "\": field " << jt->second.fields[i].name << " expects type "; type::print(os, *jt->second.fields[i].type); os << " but actual type is "; type::print(os, actual_type); @@ -577,13 +581,13 @@ namespace pslang::interpreter } return struct_value{ - .struct_type = std::make_unique(resolve_type(context, type::identifier{function_call.name})), + .struct_type = std::make_unique(resolve_type(context, type::identifier{identifier->name})), .fields = std::move(fields), }; } } - throw std::runtime_error("Function \"" + function_call.name + "\" is not defined"); + throw std::runtime_error("Function \"" + identifier->name + "\" is not defined"); } value eval_impl(context & context, ast::array const & array) diff --git a/libs/parser/rules/pslang.y b/libs/parser/rules/pslang.y index 4729a1f..1e9deb4 100644 --- a/libs/parser/rules/pslang.y +++ b/libs/parser/rules/pslang.y @@ -260,13 +260,14 @@ postfix_expression : base_expression | postfix_expression lbracket expression rbracket { $$ = ast::array_access{std::make_unique($1), std::make_unique($3), @$}; } | postfix_expression dot name { $$ = ast::field_access{std::make_unique($1), $3, @$}; } +//| postfix_expression dot name lparen comma_separated_expression_list rparen { auto args = $5; args.insert(args.begin(), std::make_unique($1)); $$ = ast::function_call{$3, std::move(args), @$}; } +| postfix_expression lparen comma_separated_expression_list rparen { $$ = ast::function_call{std::make_unique($1), $3, @$}; } ; base_expression : literal | name { $$ = ast::identifier{$1, ctx.location}; } | lparen expression rparen { $$ = $2; } -| name lparen comma_separated_expression_list rparen { $$ = ast::function_call{$1, $3, @$}; } | lbracket nonempty_comma_separated_expression_list rbracket { $$ = ast::array{$2, @$}; } ;