Validate functions with no return & automatically add return in the end of a function returning unit

This commit is contained in:
Nikita Lisitsa 2026-03-13 13:40:09 +03:00
parent 11656fb296
commit 51c78169b3
6 changed files with 119 additions and 5 deletions

View file

@ -125,6 +125,7 @@ int main(int argc, char ** argv)
auto ast = parser::parse(filenames.back());
ast::resolve_identifiers(ast);
ast::check_and_infer_types(ast);
ast::validate(ast);
parsed.push_back(std::move(ast));
}
catch (ast::parse_error const & error)
@ -145,6 +146,12 @@ int main(int argc, char ** argv)
print_error_context(argv[arg], error.location());
return EXIT_FAILURE;
}
catch (ast::validation_error const & error)
{
std::cerr << "Validation error at " << error.location() << ":\n " << error.what() << std::endl;
print_error_context(argv[arg], error.location());
return EXIT_FAILURE;
}
catch (interpreter::internal_error const & error)
{
std::cerr << "Internal error at " << error.location() << ":\n " << error.what() << std::endl;

View file

@ -10,5 +10,29 @@ func add_or_sub(add : bool) -> (u32 -> u32):
else:
return sub1
func test(x: i32) -> i32:
return 10 / 0
foreign func putchar(c: i32) -> i32
func print(c: u8):
putchar(c as i32)
func test():
print('H')
print('e')
print('l')
print('l')
print('o')
print(',')
print(' ')
print('w')
print('o')
print('r')
print('l')
print('d')
print('!')
print('\n')
//func test1():
// let str = ['H', 'e', 'l', 'l', 'o', ',', ' ', 'w', 'o', 'r', 'l', 'd', '!', '\n']
// mut i = 0
// while i < 14:
// print(str[i])

View file

@ -53,5 +53,10 @@ namespace pslang::ast
using error::error;
};
struct validation_error
: error
{
using error::error;
};
}

View file

@ -7,5 +7,6 @@ namespace pslang::ast
void resolve_identifiers(statement_list_ptr & statements);
void check_and_infer_types(statement_list_ptr & statements);
void validate(statement_list_ptr & statements);
}

View file

@ -0,0 +1,69 @@
#include <pslang/ast/preprocess.hpp>
#include <pslang/ast/statement_visitor.hpp>
#include <pslang/ast/error.hpp>
#include <pslang/types/type.hpp>
#include <iostream>
namespace pslang::ast
{
namespace
{
struct validate_visitor
: const_statement_visitor<validate_visitor>
{
using const_statement_visitor::apply;
void apply(expression_ptr const &) {}
void apply(assignment const &) {}
void apply(variable_declaration const &) {}
void apply(if_block const &) {}
void apply(else_block const &) {}
void apply(else_if_block const &) {}
void apply(if_chain const & node)
{
for (auto const & block : node.blocks)
apply(*block.statements);
}
void apply(while_block const & node)
{
apply(*node.statements);
}
void apply(function_definition const & node)
{
if (!types::equal(*get_type(*node.return_type), types::unit_type{}))
if (node.statements->statements.empty() || (true
&& !std::get_if<return_statement>(node.statements->statements.back().get())
&& !std::get_if<if_chain>(node.statements->statements.back().get())
&& !std::get_if<while_block>(node.statements->statements.back().get())
))
throw validation_error("Function returning non-unit is missing a return statement in the end", node.location);
}
void apply(foreign_function_declaration const &) {}
void apply(return_statement const &) {}
void apply(field_definition const &) {}
void apply(struct_definition const &) {}
};
}
void validate(statement_list_ptr & statements)
{
validate_visitor{}.apply(*statements);
}
}

View file

@ -825,9 +825,7 @@ namespace pslang::jit::aarch64
{
if (node.value)
apply(*node.value);
if (stack_offset > 0)
builder.add_imm(31, 31, stack_offset);
builder.ret();
do_return();
}
void apply(ast::function_definition const &)
@ -882,6 +880,13 @@ namespace pslang::jit::aarch64
scopes.pop_back();
}
void do_return()
{
if (stack_offset > 0)
builder.add_imm(31, 31, stack_offset);
builder.ret();
}
private:
void push(std::uint8_t reg)
{
@ -954,6 +959,9 @@ namespace pslang::jit::aarch64
pcontext.symbols[node.name] = pcontext.code.size();
compile_function_visitor visitor{{}, {}, pcontext, lcontext};
visitor.do_apply(node);
if (node.statements->statements.empty() || !std::get_if<ast::return_statement>(node.statements->statements.back().get()))
if (types::equal(*ast::get_type(*node.return_type), types::unit_type{}))
visitor.do_return();
}
void apply(ast::foreign_function_declaration const &)