Explicitly mark constructor AST nodes

This commit is contained in:
Nikita Lisitsa 2025-12-23 13:43:54 +03:00
parent 46a1031a08
commit dea5c18cfd
8 changed files with 221 additions and 179 deletions

View file

@ -38,7 +38,10 @@ namespace pslang::ast
struct function_call struct function_call
{ {
expression_ptr function; // Exactly one of these 2 is non-null
expression_ptr function; // function call
type_ptr type; // type constructor
std::vector<expression_ptr> arguments; std::vector<expression_ptr> arguments;
ast::location location; ast::location location;
types::type_ptr inferred_type = nullptr; types::type_ptr inferred_type = nullptr;

View file

@ -193,8 +193,17 @@ namespace pslang::ast
void apply(function_call const & node) void apply(function_call const & node)
{ {
put_indent(out, options); put_indent(out, options);
out << "call\n"; if (node.function)
child(*node.function); {
out << "call\n";
child(*node.function);
}
if (node.type)
{
out << "constructor { type = ";
print(out, *node.type);
out << " }\n";
}
for (auto const & argument : node.arguments) for (auto const & argument : node.arguments)
child(*argument); child(*argument);
} }

View file

@ -4,6 +4,7 @@
#include <pslang/ast/expression_visitor.hpp> #include <pslang/ast/expression_visitor.hpp>
#include <pslang/ast/statement_visitor.hpp> #include <pslang/ast/statement_visitor.hpp>
#include <pslang/ast/error.hpp> #include <pslang/ast/error.hpp>
#include <pslang/types/type.hpp>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
@ -138,11 +139,34 @@ namespace pslang::ast
apply(*cast_operation.type); apply(*cast_operation.type);
} }
void apply(function_call const & function_call) void apply(function_call & function_call)
{ {
apply(*function_call.function); if (function_call.function)
apply(*function_call.function);
if (function_call.type)
apply(*function_call.type);
for (auto const & argument : function_call.arguments) for (auto const & argument : function_call.arguments)
apply(*argument); apply(*argument);
if (auto id = std::get_if<identifier>(function_call.function.get()))
{
if (auto type = types::builtin_type(id->name))
{
if (auto unit_type = std::get_if<types::unit_type>(type.get()))
function_call.type = std::make_unique<ast::type>(*unit_type);
else if (auto primitive_type = std::get_if<types::primitive_type>(type.get()))
function_call.type = std::make_unique<ast::type>(*primitive_type);
else
throw invalid_ast_error("Unknown built-in type \"" + id->name + "\"", get_location(*function_call.function));
function_call.function = nullptr;
}
else if (scopes.at(id->level).structs.contains(id->name))
{
function_call.type = std::make_unique<ast::type>(type_identifier{.name = id->name, .location = id->location, .level = id->level});
function_call.function = nullptr;
}
}
} }
void apply(array const & array) void apply(array const & array)

View file

@ -300,19 +300,67 @@ namespace pslang::ast
void apply(function_call & node) void apply(function_call & node)
{ {
apply(*node.function); if (node.function)
apply(*node.function);
if (node.type)
apply(*node.type);
for (auto const & argument : node.arguments) for (auto const & argument : node.arguments)
apply(*argument); apply(*argument);
std::string function_name; if (node.function)
if (auto identifier = std::get_if<ast::identifier>(node.function.get()))
{ {
if (types::type_ptr type = types::builtin_type(identifier->name)) std::string function_name;
if (auto identifier = std::get_if<ast::identifier>(node.function.get()))
function_name = identifier->name + " ";
auto function_type = get_type(*node.function);
auto ftype = std::get_if<types::function_type>(function_type.get());
if (!ftype)
{
std::ostringstream os;
os << "Cannot call a value of a non-function type ";
types::print(os, *function_type);
throw type_error(os.str(), node.location);
}
if (ftype->arguments.size() != node.arguments.size())
{
std::ostringstream os;
os << "Cannot call function " << function_name << " of type ";
types::print(os, *function_type);
os << ": expected " << ftype->arguments.size() << " arguments, but got " << node.arguments.size();
throw type_error(os.str(), node.location);
}
for (std::size_t i = 0; i < node.arguments.size(); ++i)
{
auto arg_type = get_type(*node.arguments[i]);
if (!types::equal(*arg_type, *ftype->arguments[i]))
{
std::ostringstream os;
os << "Cannot call function " << function_name << " of type ";
types::print(os, *function_type);
os << ": argument #" << i << " expected to have type ";
types::print(os, *ftype->arguments[i]);
os << " but got type ";
types::print(os, *arg_type);
throw type_error(os.str(), node.location);
}
}
node.inferred_type = ftype->result;
}
else if (node.type)
{
auto type = get_type(*node.type);
if (types::is_builtin_type(*type))
{ {
if (node.arguments.empty()) if (node.arguments.empty())
{ {
node.inferred_type = type; node.inferred_type = get_type(*node.type);
return; return;
} }
@ -322,27 +370,28 @@ namespace pslang::ast
os << ": expected 0 arguments, but got " << node.arguments.size(); os << ": expected 0 arguments, but got " << node.arguments.size();
throw std::runtime_error(os.str()); throw std::runtime_error(os.str());
} }
else if (auto named_type = std::get_if<types::named_type>(type.get()))
auto & scope = scopes.at(identifier->level);
if (auto it = scope.structs.find(identifier->name); it != scope.structs.end())
{ {
auto const & scope = scopes.at(named_type->level);
auto const & data = scope.structs.at(named_type->name);
if (!node.arguments.empty()) if (!node.arguments.empty())
{ {
if (node.arguments.size() != it->second.fields.size()) if (node.arguments.size() != data.fields.size())
{ {
std::ostringstream os; std::ostringstream os;
os << "Cannot create struct " << identifier->name << ": expected " << it->second.fields.size() << " arguments, but got " << node.arguments.size(); os << "Cannot create struct " << named_type->name << ": expected " << data.fields.size() << " arguments, but got " << node.arguments.size();
throw type_error(os.str(), node.location); throw type_error(os.str(), node.location);
} }
for (std::size_t i = 0; i < node.arguments.size(); ++i) for (std::size_t i = 0; i < node.arguments.size(); ++i)
{ {
auto arg_type = get_type(*node.arguments[i]); auto arg_type = get_type(*node.arguments[i]);
if (!types::equal(*arg_type, *it->second.fields[i].type)) if (!types::equal(*arg_type, *data.fields[i].type))
{ {
std::ostringstream os; std::ostringstream os;
os << "Cannot create struct " << identifier->name << ": argument #" << i << " expected to have type "; os << "Cannot create struct " << named_type->name << ": argument #" << i << " expected to have type ";
types::print(os, *it->second.fields[i].type); types::print(os, *data.fields[i].type);
os << " but got type "; os << " but got type ";
types::print(os, *arg_type); types::print(os, *arg_type);
throw type_error(os.str(), node.location); throw type_error(os.str(), node.location);
@ -350,53 +399,12 @@ namespace pslang::ast
} }
} }
types::named_type type; node.inferred_type = type;
type.name = identifier->name;
type.level = identifier->level;
node.inferred_type = std::make_unique<types::type>(std::move(type));
return; return;
} }
function_name = identifier->name + " ";
} }
else
auto function_type = get_type(*node.function); throw invalid_ast_error("Function call node has neither function nor type", node.location);
auto ftype = std::get_if<types::function_type>(function_type.get());
if (!ftype)
{
std::ostringstream os;
os << "Cannot call a value of a non-function type ";
types::print(os, *function_type);
throw type_error(os.str(), node.location);
}
if (ftype->arguments.size() != node.arguments.size())
{
std::ostringstream os;
os << "Cannot call function " << function_name << " of type ";
types::print(os, *function_type);
os << ": expected " << ftype->arguments.size() << " arguments, but got " << node.arguments.size();
throw type_error(os.str(), node.location);
}
for (std::size_t i = 0; i < node.arguments.size(); ++i)
{
auto arg_type = get_type(*node.arguments[i]);
if (!types::equal(*arg_type, *ftype->arguments[i]))
{
std::ostringstream os;
os << "Cannot call function " << function_name << " of type ";
types::print(os, *function_type);
os << ": argument #" << i << " expected to have type ";
types::print(os, *ftype->arguments[i]);
os << " but got type ";
types::print(os, *arg_type);
throw type_error(os.str(), node.location);
}
}
node.inferred_type = ftype->result;
} }
void apply(array & node) void apply(array & node)

View file

@ -507,10 +507,73 @@ namespace pslang::interpreter
value eval_impl(context & context, ast::function_call const & function_call) value eval_impl(context & context, ast::function_call const & function_call)
{ {
auto identifier = std::get_if<ast::identifier>(function_call.function.get()); if (function_call.function)
if (identifier)
{ {
if (auto type = types::builtin_type(identifier->name)) auto lvalue = eval(context, function_call.function);
auto fvalue = std::get_if<function_value>(&lvalue);
if (fvalue)
{
if (fvalue->arguments.size() != function_call.arguments.size())
{
std::ostringstream os;
os << "Cannot call function: expected " << fvalue->arguments.size() << " arguments, got " << function_call.arguments.size();
throw std::runtime_error(os.str());
}
std::vector<value> args;
for (auto const & expression : function_call.arguments)
args.push_back(eval(context, expression));
for (std::size_t i = 0; i < args.size(); ++i)
{
auto actual_type = type_of(args[i]);
if (!types::equal(actual_type, *fvalue->arguments[i].type))
{
std::ostringstream os;
os << "Cannot call function: argument #" << (i + 1) << " expects type ";
types::print(os, *fvalue->arguments[i].type);
os << " but actual type is ";
types::print(os, actual_type);
throw std::runtime_error(os.str());
}
}
auto & function_scope = context.scope_stack.emplace_back();
function_scope.is_function_scope = true;
for (std::size_t i = 0; i < args.size(); ++i)
function_scope.variables[fvalue->arguments[i].name] = {.category = ast::value_category::constant, .value = std::move(args[i])};
auto expected_return_type = fvalue->return_type;
exec(context, fvalue->statements);
auto actual_return_type = type_of(context.scope_stack.back().return_value);
if (!types::equal(actual_return_type, *expected_return_type))
{
std::ostringstream os;
os << "Error returning from function: expected return type is ";
types::print(os, *expected_return_type);
os << " but actual type is ";
types::print(os, actual_return_type);
throw std::runtime_error(os.str());
}
auto result = std::move(context.scope_stack.back().return_value);
context.scope_stack.pop_back();
return result;
}
std::ostringstream os;
os << "Cannot call ";
print(os, lvalue);
os << ": not a function";
throw std::runtime_error(os.str());
}
else if (function_call.type)
{
auto type = get_type(*function_call.type);
if (types::is_builtin_type(*type))
{ {
if (function_call.arguments.empty()) if (function_call.arguments.empty())
return zero_value(context, *type); return zero_value(context, *type);
@ -521,111 +584,35 @@ namespace pslang::interpreter
os << ": expected 0 arguments, but got " << function_call.arguments.size(); os << ": expected 0 arguments, but got " << function_call.arguments.size();
throw std::runtime_error(os.str()); throw std::runtime_error(os.str());
} }
else if (auto named_type = std::get_if<types::named_type>(type.get()))
for (auto it = context.scope_stack.rbegin(); it != context.scope_stack.rend(); ++it)
{ {
if (auto jt = it->structs.find(identifier->name); jt != it->structs.end()) auto const & scope = context.scope_stack.at(named_type->level);
auto const & data = scope.structs.at(named_type->name);
if (function_call.arguments.empty())
return zero_value(context, *type);
std::vector<value> args;
for (auto const & expression : function_call.arguments)
args.push_back(eval(context, expression));
std::unordered_map<std::string, value_ptr> fields;
for (std::size_t i = 0; i < args.size(); ++i)
{ {
if (function_call.arguments.empty()) fields[data.fields[i].name] = std::make_unique<value>(std::move(args[i]));
return zero_value(context, types::named_type{.name = identifier->name, .level = identifier->level});
if (jt->second.fields.size() != function_call.arguments.size())
{
std::ostringstream os;
os << "Cannot create struct \"" << identifier->name << "\": expected " << jt->second.fields.size() << " fields, got " << function_call.arguments.size();
throw std::runtime_error(os.str());
}
std::vector<value> args;
for (auto const & expression : function_call.arguments)
args.push_back(eval(context, expression));
std::unordered_map<std::string, value_ptr> fields;
for (std::size_t i = 0; i < args.size(); ++i)
{
auto actual_type = type_of(args[i]);
if (!types::equal(actual_type, *jt->second.fields[i].type))
{
std::ostringstream os;
os << "Cannot create struct \"" << identifier->name << "\": field " << jt->second.fields[i].name << " expects type ";
types::print(os, *jt->second.fields[i].type);
os << " but actual type is ";
types::print(os, actual_type);
throw std::runtime_error(os.str());
}
fields[jt->second.fields[i].name] = std::make_unique<value>(std::move(args[i]));
}
return struct_value{
.struct_type = std::make_unique<types::type>(types::named_type{.name = identifier->name, .level = identifier->level}),
.fields = std::move(fields),
};
} }
return struct_value{
.struct_type = type,
.fields = std::move(fields),
};
} }
else
throw std::runtime_error("Unknown type in constructor");
} }
else
auto lvalue = eval(context, function_call.function); throw std::runtime_error("Function call node has neither function nor type");
auto fvalue = std::get_if<function_value>(&lvalue);
if (fvalue)
{
if (fvalue->arguments.size() != function_call.arguments.size())
{
std::ostringstream os;
os << "Cannot call function: expected " << fvalue->arguments.size() << " arguments, got " << function_call.arguments.size();
throw std::runtime_error(os.str());
}
std::vector<value> args;
for (auto const & expression : function_call.arguments)
args.push_back(eval(context, expression));
for (std::size_t i = 0; i < args.size(); ++i)
{
auto actual_type = type_of(args[i]);
if (!types::equal(actual_type, *fvalue->arguments[i].type))
{
std::ostringstream os;
os << "Cannot call function: argument #" << (i + 1) << " expects type ";
types::print(os, *fvalue->arguments[i].type);
os << " but actual type is ";
types::print(os, actual_type);
throw std::runtime_error(os.str());
}
}
auto & function_scope = context.scope_stack.emplace_back();
function_scope.is_function_scope = true;
for (std::size_t i = 0; i < args.size(); ++i)
function_scope.variables[fvalue->arguments[i].name] = {.category = ast::value_category::constant, .value = std::move(args[i])};
auto expected_return_type = fvalue->return_type;
exec(context, fvalue->statements);
auto actual_return_type = type_of(context.scope_stack.back().return_value);
if (!types::equal(actual_return_type, *expected_return_type))
{
std::ostringstream os;
os << "Error returning from function: expected return type is ";
types::print(os, *expected_return_type);
os << " but actual type is ";
types::print(os, actual_return_type);
throw std::runtime_error(os.str());
}
auto result = std::move(context.scope_stack.back().return_value);
context.scope_stack.pop_back();
return result;
}
std::ostringstream os;
os << "Cannot call ";
print(os, lvalue);
os << ": not a function";
throw std::runtime_error(os.str());
} }
value eval_impl(context & context, ast::array const & array) value eval_impl(context & context, ast::array const & array)

View file

@ -289,19 +289,19 @@ postfix_expression
: base_expression : base_expression
| postfix_expression lbracket expression rbracket { $$ = ast::array_access{std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3), @$}; } | postfix_expression lbracket expression rbracket { $$ = ast::array_access{std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3), @$}; }
| postfix_expression dot name { $$ = ast::field_access{std::make_unique<ast::expression>($1), $3, @$}; } | postfix_expression dot name { $$ = ast::field_access{std::make_unique<ast::expression>($1), $3, @$}; }
| postfix_expression lparen comma_separated_expression_list rparen { $$ = ast::function_call{std::make_unique<ast::expression>($1), $3, @$}; } | postfix_expression lparen comma_separated_expression_list rparen { $$ = ast::function_call{std::make_unique<ast::expression>($1), nullptr, $3, @$}; }
| unit lparen comma_separated_expression_list rparen { $$ = ast::function_call{std::make_unique<ast::expression>(ast::identifier{"unit"}), $3, @$}; } | unit lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::unit_type{}), $3, @$}; }
| bool lparen comma_separated_expression_list rparen { $$ = ast::function_call{std::make_unique<ast::expression>(ast::identifier{"bool"}), $3, @$}; } | bool lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::bool_type{}), $3, @$}; }
| i8 lparen comma_separated_expression_list rparen { $$ = ast::function_call{std::make_unique<ast::expression>(ast::identifier{"i8"}), $3, @$}; } | i8 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::i8_type{}), $3, @$}; }
| u8 lparen comma_separated_expression_list rparen { $$ = ast::function_call{std::make_unique<ast::expression>(ast::identifier{"u8"}), $3, @$}; } | u8 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::u8_type{}), $3, @$}; }
| i16 lparen comma_separated_expression_list rparen { $$ = ast::function_call{std::make_unique<ast::expression>(ast::identifier{"i16"}), $3, @$}; } | i16 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::i16_type{}), $3, @$}; }
| u16 lparen comma_separated_expression_list rparen { $$ = ast::function_call{std::make_unique<ast::expression>(ast::identifier{"u16"}), $3, @$}; } | u16 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::u16_type{}), $3, @$}; }
| i32 lparen comma_separated_expression_list rparen { $$ = ast::function_call{std::make_unique<ast::expression>(ast::identifier{"i32"}), $3, @$}; } | i32 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::i32_type{}), $3, @$}; }
| u32 lparen comma_separated_expression_list rparen { $$ = ast::function_call{std::make_unique<ast::expression>(ast::identifier{"u32"}), $3, @$}; } | u32 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::u32_type{}), $3, @$}; }
| i64 lparen comma_separated_expression_list rparen { $$ = ast::function_call{std::make_unique<ast::expression>(ast::identifier{"i64"}), $3, @$}; } | i64 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::i64_type{}), $3, @$}; }
| u64 lparen comma_separated_expression_list rparen { $$ = ast::function_call{std::make_unique<ast::expression>(ast::identifier{"u64"}), $3, @$}; } | u64 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::u64_type{}), $3, @$}; }
| f32 lparen comma_separated_expression_list rparen { $$ = ast::function_call{std::make_unique<ast::expression>(ast::identifier{"f32"}), $3, @$}; } | f32 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::f32_type{}), $3, @$}; }
| f64 lparen comma_separated_expression_list rparen { $$ = ast::function_call{std::make_unique<ast::expression>(ast::identifier{"f64"}), $3, @$}; } | f64 lparen comma_separated_expression_list rparen { $$ = ast::function_call{nullptr, std::make_unique<ast::type>(types::f64_type{}), $3, @$}; }
; ;
base_expression base_expression

View file

@ -21,5 +21,6 @@ namespace pslang::types
bool is_unsigned_integer_type(type const & type); bool is_unsigned_integer_type(type const & type);
bool is_floating_point_type(type const & type); bool is_floating_point_type(type const & type);
bool is_numeric_type(type const & type); bool is_numeric_type(type const & type);
bool is_builtin_type(type const & type);
} }

View file

@ -113,4 +113,14 @@ namespace pslang::types
return false; return false;
} }
bool is_builtin_type(type const & type)
{
if (std::get_if<unit_type>(&type))
return true;
if (std::get_if<primitive_type>(&type))
return true;
return false;
}
} }