Support mutually recursive functions in identifier resolution & type checker

This commit is contained in:
Nikita Lisitsa 2026-01-16 15:23:48 +03:00
parent aab10621cc
commit 7a3c7cca5d
4 changed files with 194 additions and 43 deletions

View file

@ -173,9 +173,9 @@ int main(int argc, char ** argv)
{ {
// TODO: remove, testing-only code; should execute entry point instead // TODO: remove, testing-only code; should execute entry point instead
auto offset = pcontext.symbols.at("sqr"); auto offset = pcontext.symbols.at("foo");
auto fptr = (unsigned(*)(unsigned))(executable.data.get() + offset); auto fptr = (uint32_t(*)(uint32_t))(executable.data.get() + offset);
auto x = fptr(30); auto x = fptr(10u);
std::cout << "Result: " << std::boolalpha << x << std::endl; std::cout << "Result: " << std::boolalpha << x << std::endl;
} }
} }

View file

@ -1,5 +1,7 @@
func sqr(x : f32) -> f32: func foo(x : u32) -> u32:
return x * x if x == 0u:
return 42u
return 1u + bar(x - 1u)
func smoothstep(x : f32) -> f32: func bar(x : u32) -> u32:
return sqr(x) * (3.0 - 2.0 * x) return foo(x)

View file

@ -56,12 +56,66 @@ namespace pslang::ast
bool is_global_scope = false; bool is_global_scope = false;
}; };
struct populate_globals_visitor
: statement_visitor<populate_globals_visitor>
{
std::vector<scope> & scopes;
using statement_visitor::apply;
void apply(expression_ptr const &)
{}
void apply(assignment const &)
{}
void apply(variable_declaration const &)
{}
void apply(if_block const &)
{}
void apply(else_if_block const &)
{}
void apply(else_block const &)
{}
void apply(if_chain const &)
{}
void apply(while_block const &)
{}
void apply(function_definition const & function_definition)
{
if (scopes.back().contains(function_definition.name))
throw parse_error("Identifier \"" + function_definition.name + "\" is already defined at this scope", function_definition.location);
scopes.back().functions.insert(function_definition.name);
}
void apply(return_statement const &)
{}
void apply(field_definition const &)
{}
void apply(struct_definition const & struct_definition)
{
if (scopes.back().contains(struct_definition.name))
throw parse_error("Identifier \"" + struct_definition.name + "\" is already defined at this scope", struct_definition.location);
scopes.back().structs.insert(struct_definition.name);
}
};
struct resolve_identifiers_visitor struct resolve_identifiers_visitor
: type_visitor<resolve_identifiers_visitor> : type_visitor<resolve_identifiers_visitor>
, expression_visitor<resolve_identifiers_visitor> , expression_visitor<resolve_identifiers_visitor>
, statement_visitor<resolve_identifiers_visitor> , statement_visitor<resolve_identifiers_visitor>
{ {
std::vector<scope> scopes; std::vector<scope> & scopes;
using type_visitor::apply; using type_visitor::apply;
using expression_visitor::apply; using expression_visitor::apply;
@ -246,8 +300,7 @@ namespace pslang::ast
void apply(function_definition const & function_definition) void apply(function_definition const & function_definition)
{ {
if (scopes.back().contains(function_definition.name)) // Already added to scope by populate_globals_visitor
throw parse_error("Identifier \"" + function_definition.name + "\" is already defined at this scope", function_definition.location);
std::unordered_set<std::string> argument_names; std::unordered_set<std::string> argument_names;
for (auto const & argument : function_definition.arguments) for (auto const & argument : function_definition.arguments)
@ -282,8 +335,7 @@ namespace pslang::ast
void apply(struct_definition const & struct_definition) void apply(struct_definition const & struct_definition)
{ {
if (scopes.back().contains(struct_definition.name)) // Already added to scope by populate_globals_visitor
throw parse_error("Identifier \"" + struct_definition.name + "\" is already defined at this scope", struct_definition.location);
for (auto const & field : struct_definition.fields) for (auto const & field : struct_definition.fields)
apply(field); apply(field);
@ -291,14 +343,23 @@ namespace pslang::ast
scopes.back().structs.insert(struct_definition.name); scopes.back().structs.insert(struct_definition.name);
} }
void apply(statement_list & statement_list)
{
populate_globals_visitor populate_globals_visitor{{}, scopes};
populate_globals_visitor.apply(statement_list);
for (auto const & statement : statement_list.statements)
apply(*statement);
}
}; };
} }
void resolve_identifiers(statement_list_ptr & statements) void resolve_identifiers(statement_list_ptr & statements)
{ {
resolve_identifiers_visitor visitor; std::vector<scope> scopes;
visitor.scopes.emplace_back().is_global_scope = true; scopes.emplace_back().is_global_scope = true;
resolve_identifiers_visitor visitor{{}, {}, {}, scopes};
visitor.apply(*statements); visitor.apply(*statements);
} }

View file

@ -50,16 +50,12 @@ namespace pslang::ast
types::type_ptr expected_return_type = nullptr; types::type_ptr expected_return_type = nullptr;
}; };
struct check_visitor struct resolve_types_visitor
: type_visitor<check_visitor> : type_visitor<resolve_types_visitor>
, expression_visitor<check_visitor>
, const_statement_visitor<check_visitor>
{ {
std::vector<scope> scopes; std::vector<scope> & scopes;
using type_visitor::apply; using type_visitor::apply;
using expression_visitor::apply;
using const_statement_visitor::apply;
void apply(types::unit_type const &) void apply(types::unit_type const &)
{} {}
@ -98,6 +94,104 @@ namespace pslang::ast
type.level = node.level; type.level = node.level;
node.inferred_type = std::make_unique<types::type>(std::move(type)); node.inferred_type = std::make_unique<types::type>(std::move(type));
} }
};
void resolve_types(std::vector<scope> & scopes, ast::type & type)
{
resolve_types_visitor visitor{{}, scopes};
visitor.apply(type);
}
struct populate_globals_visitor
: statement_visitor<populate_globals_visitor>
{
std::vector<scope> & scopes;
using statement_visitor::apply;
void apply(expression_ptr const &)
{}
void apply(assignment const &)
{}
void apply(variable_declaration const &)
{}
void apply(if_block const &)
{}
void apply(else_if_block const &)
{}
void apply(else_block const &)
{}
void apply(if_chain const &)
{}
void apply(while_block const &)
{}
void apply(function_definition const & node)
{
for (auto const & argument : node.arguments)
resolve_types(scopes, *argument.type);
resolve_types(scopes, *node.return_type);
auto & data = scopes.back().functions[node.name];
for (auto const & argument : node.arguments)
data.arguments.push_back(get_type(*argument.type));
data.result_type = get_type(*node.return_type);
scopes.emplace_back().is_function_scope = true;
scopes.back().expected_return_type = get_type(*node.return_type);
for (auto const & argument : node.arguments)
{
scopes.back().variables[argument.name] = {
.category = value_category::constant,
.type = get_type(*argument.type),
};
}
apply(*node.statements);
scopes.pop_back();
}
void apply(return_statement const &)
{}
void apply(field_definition const & node)
{
resolve_types(scopes, *node.type);
}
void apply(struct_definition const & node)
{
for (auto const & field : node.fields)
apply(field);
auto & data = scopes.back().structs[node.name];
for (auto const & field : node.fields)
data.fields.push_back({.name = field.name, .type = get_type(*field.type)});
}
};
void populate_globals(std::vector<scope> & scopes, statement_list & statement_list)
{
populate_globals_visitor visitor{{}, scopes};
visitor.apply(statement_list);
}
struct check_visitor
: expression_visitor<check_visitor>
, const_statement_visitor<check_visitor>
{
std::vector<scope> & scopes;
using expression_visitor::apply;
using const_statement_visitor::apply;
void apply(literal &) void apply(literal &)
{} {}
@ -277,7 +371,7 @@ namespace pslang::ast
void apply(cast_operation & node) void apply(cast_operation & node)
{ {
apply(*node.expression); apply(*node.expression);
apply(*node.type); resolve_types(scopes, *node.type);
auto source_type = get_type(*node.expression); auto source_type = get_type(*node.expression);
auto target_type = get_type(*node.type); auto target_type = get_type(*node.type);
@ -304,7 +398,7 @@ namespace pslang::ast
if (node.function) if (node.function)
apply(*node.function); apply(*node.function);
if (node.type) if (node.type)
apply(*node.type); resolve_types(scopes, *node.type);
for (auto const & argument : node.arguments) for (auto const & argument : node.arguments)
apply(*argument); apply(*argument);
@ -530,7 +624,7 @@ namespace pslang::ast
auto actual_type = get_type(*node.initializer); auto actual_type = get_type(*node.initializer);
if (node.type) if (node.type)
{ {
apply(*node.type); resolve_types(scopes, *node.type);
auto expected_type = get_type(*node.type); auto expected_type = get_type(*node.type);
if (!types::equal(*expected_type, *actual_type)) if (!types::equal(*expected_type, *actual_type))
{ {
@ -606,14 +700,7 @@ namespace pslang::ast
void apply(function_definition const & node) void apply(function_definition const & node)
{ {
for (auto const & argument : node.arguments) // Already added to scope by populate_globals_visitor
apply(*argument.type);
apply(*node.return_type);
auto & data = scopes.back().functions[node.name];
for (auto const & argument : node.arguments)
data.arguments.push_back(get_type(*argument.type));
data.result_type = get_type(*node.return_type);
scopes.emplace_back().is_function_scope = true; scopes.emplace_back().is_function_scope = true;
scopes.back().expected_return_type = get_type(*node.return_type); scopes.back().expected_return_type = get_type(*node.return_type);
@ -658,19 +745,19 @@ namespace pslang::ast
} }
} }
void apply(field_definition const & node) void apply(field_definition const &)
{ {}
apply(*node.type);
}
void apply(struct_definition const & node) void apply(struct_definition const & node)
{ {
for (auto const & field : node.fields) // Already added to scope by populate_globals_visitor
apply(field); }
auto & data = scopes.back().structs[node.name]; void apply(statement_list & node)
for (auto const & field : node.fields) {
data.fields.push_back({.name = field.name, .type = get_type(*field.type)}); populate_globals(scopes, node);
for (auto const & statement : node.statements)
apply(*statement);
} }
}; };
@ -678,8 +765,9 @@ namespace pslang::ast
void check_and_infer_types(statement_list_ptr & statements) void check_and_infer_types(statement_list_ptr & statements)
{ {
check_visitor visitor; std::vector<scope> scopes;
visitor.scopes.emplace_back().is_global_scope = true; scopes.emplace_back().is_global_scope = true;
check_visitor visitor{{}, {}, scopes};
visitor.apply(*statements); visitor.apply(*statements);
} }