From 51c78169b3e10bdb34c2e839ba54da765b7de2ec Mon Sep 17 00:00:00 2001 From: lisyarus Date: Fri, 13 Mar 2026 13:40:09 +0300 Subject: [PATCH] Validate functions with no return & automatically add return in the end of a function returning unit --- apps/interpreter/source/main.cpp | 7 +++ examples/jit_test.psl | 28 ++++++++- libs/ast/include/pslang/ast/error.hpp | 5 ++ libs/ast/include/pslang/ast/preprocess.hpp | 1 + libs/ast/source/validate.cpp | 69 ++++++++++++++++++++++ libs/jit/source/arch/aarch64/compiler.cpp | 14 ++++- 6 files changed, 119 insertions(+), 5 deletions(-) create mode 100644 libs/ast/source/validate.cpp diff --git a/apps/interpreter/source/main.cpp b/apps/interpreter/source/main.cpp index 5f3baec..6531c80 100644 --- a/apps/interpreter/source/main.cpp +++ b/apps/interpreter/source/main.cpp @@ -125,6 +125,7 @@ int main(int argc, char ** argv) auto ast = parser::parse(filenames.back()); ast::resolve_identifiers(ast); ast::check_and_infer_types(ast); + ast::validate(ast); parsed.push_back(std::move(ast)); } catch (ast::parse_error const & error) @@ -145,6 +146,12 @@ int main(int argc, char ** argv) print_error_context(argv[arg], error.location()); return EXIT_FAILURE; } + catch (ast::validation_error const & error) + { + std::cerr << "Validation error at " << error.location() << ":\n " << error.what() << std::endl; + print_error_context(argv[arg], error.location()); + return EXIT_FAILURE; + } catch (interpreter::internal_error const & error) { std::cerr << "Internal error at " << error.location() << ":\n " << error.what() << std::endl; diff --git a/examples/jit_test.psl b/examples/jit_test.psl index 59f768c..199f236 100644 --- a/examples/jit_test.psl +++ b/examples/jit_test.psl @@ -10,5 +10,29 @@ func add_or_sub(add : bool) -> (u32 -> u32): else: return sub1 -func test(x: i32) -> i32: - return 10 / 0 +foreign func putchar(c: i32) -> i32 + +func print(c: u8): + putchar(c as i32) + +func test(): + print('H') + print('e') + print('l') + print('l') + print('o') + print(',') + print(' ') + print('w') + print('o') + print('r') + print('l') + print('d') + print('!') + print('\n') + +//func test1(): +// let str = ['H', 'e', 'l', 'l', 'o', ',', ' ', 'w', 'o', 'r', 'l', 'd', '!', '\n'] +// mut i = 0 +// while i < 14: +// print(str[i]) diff --git a/libs/ast/include/pslang/ast/error.hpp b/libs/ast/include/pslang/ast/error.hpp index e0caeec..f0d3b49 100644 --- a/libs/ast/include/pslang/ast/error.hpp +++ b/libs/ast/include/pslang/ast/error.hpp @@ -53,5 +53,10 @@ namespace pslang::ast using error::error; }; + struct validation_error + : error + { + using error::error; + }; } diff --git a/libs/ast/include/pslang/ast/preprocess.hpp b/libs/ast/include/pslang/ast/preprocess.hpp index 6aaeb24..8899676 100644 --- a/libs/ast/include/pslang/ast/preprocess.hpp +++ b/libs/ast/include/pslang/ast/preprocess.hpp @@ -7,5 +7,6 @@ namespace pslang::ast void resolve_identifiers(statement_list_ptr & statements); void check_and_infer_types(statement_list_ptr & statements); + void validate(statement_list_ptr & statements); } diff --git a/libs/ast/source/validate.cpp b/libs/ast/source/validate.cpp new file mode 100644 index 0000000..3fca560 --- /dev/null +++ b/libs/ast/source/validate.cpp @@ -0,0 +1,69 @@ +#include +#include +#include +#include + +#include + +namespace pslang::ast +{ + + namespace + { + + struct validate_visitor + : const_statement_visitor + { + using const_statement_visitor::apply; + + void apply(expression_ptr const &) {} + + void apply(assignment const &) {} + + void apply(variable_declaration const &) {} + + void apply(if_block const &) {} + + void apply(else_block const &) {} + + void apply(else_if_block const &) {} + + void apply(if_chain const & node) + { + for (auto const & block : node.blocks) + apply(*block.statements); + } + + void apply(while_block const & node) + { + apply(*node.statements); + } + + void apply(function_definition const & node) + { + if (!types::equal(*get_type(*node.return_type), types::unit_type{})) + if (node.statements->statements.empty() || (true + && !std::get_if(node.statements->statements.back().get()) + && !std::get_if(node.statements->statements.back().get()) + && !std::get_if(node.statements->statements.back().get()) + )) + throw validation_error("Function returning non-unit is missing a return statement in the end", node.location); + } + + void apply(foreign_function_declaration const &) {} + + void apply(return_statement const &) {} + + void apply(field_definition const &) {} + + void apply(struct_definition const &) {} + }; + + } + + void validate(statement_list_ptr & statements) + { + validate_visitor{}.apply(*statements); + } + +} \ No newline at end of file diff --git a/libs/jit/source/arch/aarch64/compiler.cpp b/libs/jit/source/arch/aarch64/compiler.cpp index 7b43f25..e3d6206 100644 --- a/libs/jit/source/arch/aarch64/compiler.cpp +++ b/libs/jit/source/arch/aarch64/compiler.cpp @@ -825,9 +825,7 @@ namespace pslang::jit::aarch64 { if (node.value) apply(*node.value); - if (stack_offset > 0) - builder.add_imm(31, 31, stack_offset); - builder.ret(); + do_return(); } void apply(ast::function_definition const &) @@ -882,6 +880,13 @@ namespace pslang::jit::aarch64 scopes.pop_back(); } + void do_return() + { + if (stack_offset > 0) + builder.add_imm(31, 31, stack_offset); + builder.ret(); + } + private: void push(std::uint8_t reg) { @@ -954,6 +959,9 @@ namespace pslang::jit::aarch64 pcontext.symbols[node.name] = pcontext.code.size(); compile_function_visitor visitor{{}, {}, pcontext, lcontext}; visitor.do_apply(node); + if (node.statements->statements.empty() || !std::get_if(node.statements->statements.back().get())) + if (types::equal(*ast::get_type(*node.return_type), types::unit_type{})) + visitor.do_return(); } void apply(ast::foreign_function_declaration const &)