diff --git a/libs/ast/include/pslang/ast/function.hpp b/libs/ast/include/pslang/ast/function.hpp index 043cae4..7171a27 100644 --- a/libs/ast/include/pslang/ast/function.hpp +++ b/libs/ast/include/pslang/ast/function.hpp @@ -43,7 +43,8 @@ namespace pslang::ast // can be null, which means "return unit" expression_ptr value; ast::location location; - std::size_t level = 0; + + function_definition * node = nullptr; }; struct function_call diff --git a/libs/ast/source/type_check.cpp b/libs/ast/source/type_check.cpp index cf40ab1..49de951 100644 --- a/libs/ast/source/type_check.cpp +++ b/libs/ast/source/type_check.cpp @@ -771,17 +771,15 @@ namespace pslang::ast actual_type = std::make_unique(types::unit_type{}); } - auto & return_scope = scopes.at(node.level); - if (!return_scope.expected_return_type) - throw invalid_ast_error("Unexpected return level", node.location); + auto & return_node = *node.node; - if (!types::equal(*return_scope.expected_return_type, *actual_type)) + if (!types::equal(*return_node.inferred_result_type, *actual_type)) { std::ostringstream os; os << "Returning value of type "; print(os, *actual_type); os << " from a function returning "; - print(os, *return_scope.expected_return_type); + print(os, *return_node.inferred_result_type); throw type_error(os.str(), node.location); } } diff --git a/libs/parser/source/finalize.cpp b/libs/parser/source/finalize.cpp index af1f442..c7e4a53 100644 --- a/libs/parser/source/finalize.cpp +++ b/libs/parser/source/finalize.cpp @@ -102,7 +102,7 @@ namespace pslang::parser std::vector stack; stack.push_back(result.get()); std::size_t current_indent = 0; - std::size_t in_function_scope = 0; + std::vector function_stack; auto current_statement_list = [&](ast::location const & location) -> ast::statement_list * { @@ -135,16 +135,17 @@ namespace pslang::parser while (statement.indentation < current_indent) { + if (stack.empty()) + throw ast::invalid_ast_error("Unexpected empty indent stack", ast::get_location(*statement.statement)); + if (!function_stack.empty() && std::holds_alternative(stack.back()) && function_stack.back()->statements.get() == std::get(stack.back())) + function_stack.pop_back(); stack.pop_back(); --current_indent; - if (in_function_scope > 0) - --in_function_scope; } // Now statement.indentation == current_indent ast::statement_list * list = nullptr; - bool is_function_definition = false; if (auto if_block = std::get_if(statement.statement.get())) { @@ -185,9 +186,11 @@ namespace pslang::parser else if (auto function_definition = std::get_if(statement.statement.get())) { function_definition->statements = std::make_unique(); - list = function_definition->statements.get(); - current_statement_list(location)->statements.push_back(std::make_unique(std::move(*function_definition))); - is_function_definition = true; + auto statement = std::make_unique(std::move(*function_definition)); + auto function_definition_ptr = std::get_if(statement.get()); + current_statement_list(location)->statements.push_back(std::move(statement)); + list = function_definition_ptr->statements.get(); + function_stack.push_back(function_definition_ptr); } else if (auto field_definition = std::get_if(statement.statement.get())) { @@ -202,14 +205,12 @@ namespace pslang::parser current_statement_list(location)->statements.push_back(std::make_unique(std::move(*struct_definition))); stack.push_back(std::get_if(current_statement_list(location)->statements.back().get())); ++current_indent; - if (in_function_scope > 0) - ++in_function_scope; } else if (auto return_statement = std::get_if(statement.statement.get())) { - if (in_function_scope == 0) + if (function_stack.empty()) throw parse_error("Return statement outside of function scope", return_statement->location); - return_statement->level = stack.size() - in_function_scope; + return_statement->node = function_stack.back(); current_statement_list(location)->statements.push_back(std::make_unique(std::move(*return_statement))); } else if (auto expression_ptr = std::get_if(statement.statement.get())) @@ -237,8 +238,6 @@ namespace pslang::parser { stack.push_back(list); ++current_indent; - if (in_function_scope > 0 || is_function_definition) - ++in_function_scope; } }