Refactor return AST node: store direct function node reference instead of relative scope level

This commit is contained in:
Nikita Lisitsa 2026-03-22 20:18:39 +03:00
parent 5d00f1ddb7
commit c4d1252462
3 changed files with 17 additions and 19 deletions

View file

@ -43,7 +43,8 @@ namespace pslang::ast
// can be null, which means "return unit" // can be null, which means "return unit"
expression_ptr value; expression_ptr value;
ast::location location; ast::location location;
std::size_t level = 0;
function_definition * node = nullptr;
}; };
struct function_call struct function_call

View file

@ -771,17 +771,15 @@ namespace pslang::ast
actual_type = std::make_unique<types::type>(types::unit_type{}); actual_type = std::make_unique<types::type>(types::unit_type{});
} }
auto & return_scope = scopes.at(node.level); auto & return_node = *node.node;
if (!return_scope.expected_return_type)
throw invalid_ast_error("Unexpected return level", node.location);
if (!types::equal(*return_scope.expected_return_type, *actual_type)) if (!types::equal(*return_node.inferred_result_type, *actual_type))
{ {
std::ostringstream os; std::ostringstream os;
os << "Returning value of type "; os << "Returning value of type ";
print(os, *actual_type); print(os, *actual_type);
os << " from a function returning "; os << " from a function returning ";
print(os, *return_scope.expected_return_type); print(os, *return_node.inferred_result_type);
throw type_error(os.str(), node.location); throw type_error(os.str(), node.location);
} }
} }

View file

@ -102,7 +102,7 @@ namespace pslang::parser
std::vector<stack_entry> stack; std::vector<stack_entry> stack;
stack.push_back(result.get()); stack.push_back(result.get());
std::size_t current_indent = 0; std::size_t current_indent = 0;
std::size_t in_function_scope = 0; std::vector<ast::function_definition *> function_stack;
auto current_statement_list = [&](ast::location const & location) -> ast::statement_list * auto current_statement_list = [&](ast::location const & location) -> ast::statement_list *
{ {
@ -135,16 +135,17 @@ namespace pslang::parser
while (statement.indentation < current_indent) while (statement.indentation < current_indent)
{ {
if (stack.empty())
throw ast::invalid_ast_error("Unexpected empty indent stack", ast::get_location(*statement.statement));
if (!function_stack.empty() && std::holds_alternative<ast::statement_list *>(stack.back()) && function_stack.back()->statements.get() == std::get<ast::statement_list *>(stack.back()))
function_stack.pop_back();
stack.pop_back(); stack.pop_back();
--current_indent; --current_indent;
if (in_function_scope > 0)
--in_function_scope;
} }
// Now statement.indentation == current_indent // Now statement.indentation == current_indent
ast::statement_list * list = nullptr; ast::statement_list * list = nullptr;
bool is_function_definition = false;
if (auto if_block = std::get_if<ast::if_block>(statement.statement.get())) if (auto if_block = std::get_if<ast::if_block>(statement.statement.get()))
{ {
@ -185,9 +186,11 @@ namespace pslang::parser
else if (auto function_definition = std::get_if<ast::function_definition>(statement.statement.get())) else if (auto function_definition = std::get_if<ast::function_definition>(statement.statement.get()))
{ {
function_definition->statements = std::make_unique<ast::statement_list>(); function_definition->statements = std::make_unique<ast::statement_list>();
list = function_definition->statements.get(); auto statement = std::make_unique<ast::statement>(std::move(*function_definition));
current_statement_list(location)->statements.push_back(std::make_unique<ast::statement>(std::move(*function_definition))); auto function_definition_ptr = std::get_if<ast::function_definition>(statement.get());
is_function_definition = true; current_statement_list(location)->statements.push_back(std::move(statement));
list = function_definition_ptr->statements.get();
function_stack.push_back(function_definition_ptr);
} }
else if (auto field_definition = std::get_if<ast::field_definition>(statement.statement.get())) else if (auto field_definition = std::get_if<ast::field_definition>(statement.statement.get()))
{ {
@ -202,14 +205,12 @@ namespace pslang::parser
current_statement_list(location)->statements.push_back(std::make_unique<ast::statement>(std::move(*struct_definition))); current_statement_list(location)->statements.push_back(std::make_unique<ast::statement>(std::move(*struct_definition)));
stack.push_back(std::get_if<ast::struct_definition>(current_statement_list(location)->statements.back().get())); stack.push_back(std::get_if<ast::struct_definition>(current_statement_list(location)->statements.back().get()));
++current_indent; ++current_indent;
if (in_function_scope > 0)
++in_function_scope;
} }
else if (auto return_statement = std::get_if<ast::return_statement>(statement.statement.get())) else if (auto return_statement = std::get_if<ast::return_statement>(statement.statement.get()))
{ {
if (in_function_scope == 0) if (function_stack.empty())
throw parse_error("Return statement outside of function scope", return_statement->location); throw parse_error("Return statement outside of function scope", return_statement->location);
return_statement->level = stack.size() - in_function_scope; return_statement->node = function_stack.back();
current_statement_list(location)->statements.push_back(std::make_unique<ast::statement>(std::move(*return_statement))); current_statement_list(location)->statements.push_back(std::make_unique<ast::statement>(std::move(*return_statement)));
} }
else if (auto expression_ptr = std::get_if<ast::expression_ptr>(statement.statement.get())) else if (auto expression_ptr = std::get_if<ast::expression_ptr>(statement.statement.get()))
@ -237,8 +238,6 @@ namespace pslang::parser
{ {
stack.push_back(list); stack.push_back(list);
++current_indent; ++current_indent;
if (in_function_scope > 0 || is_function_definition)
++in_function_scope;
} }
} }