Support mutually recursive functions in identifier resolution & type checker
This commit is contained in:
parent
aab10621cc
commit
7a3c7cca5d
4 changed files with 194 additions and 43 deletions
|
|
@ -173,9 +173,9 @@ int main(int argc, char ** argv)
|
||||||
|
|
||||||
{
|
{
|
||||||
// TODO: remove, testing-only code; should execute entry point instead
|
// TODO: remove, testing-only code; should execute entry point instead
|
||||||
auto offset = pcontext.symbols.at("sqr");
|
auto offset = pcontext.symbols.at("foo");
|
||||||
auto fptr = (unsigned(*)(unsigned))(executable.data.get() + offset);
|
auto fptr = (uint32_t(*)(uint32_t))(executable.data.get() + offset);
|
||||||
auto x = fptr(30);
|
auto x = fptr(10u);
|
||||||
std::cout << "Result: " << std::boolalpha << x << std::endl;
|
std::cout << "Result: " << std::boolalpha << x << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
func sqr(x : f32) -> f32:
|
func foo(x : u32) -> u32:
|
||||||
return x * x
|
if x == 0u:
|
||||||
|
return 42u
|
||||||
|
return 1u + bar(x - 1u)
|
||||||
|
|
||||||
func smoothstep(x : f32) -> f32:
|
func bar(x : u32) -> u32:
|
||||||
return sqr(x) * (3.0 - 2.0 * x)
|
return foo(x)
|
||||||
|
|
|
||||||
|
|
@ -56,12 +56,66 @@ namespace pslang::ast
|
||||||
bool is_global_scope = false;
|
bool is_global_scope = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct populate_globals_visitor
|
||||||
|
: statement_visitor<populate_globals_visitor>
|
||||||
|
{
|
||||||
|
std::vector<scope> & scopes;
|
||||||
|
|
||||||
|
using statement_visitor::apply;
|
||||||
|
|
||||||
|
void apply(expression_ptr const &)
|
||||||
|
{}
|
||||||
|
|
||||||
|
void apply(assignment const &)
|
||||||
|
{}
|
||||||
|
|
||||||
|
void apply(variable_declaration const &)
|
||||||
|
{}
|
||||||
|
|
||||||
|
void apply(if_block const &)
|
||||||
|
{}
|
||||||
|
|
||||||
|
void apply(else_if_block const &)
|
||||||
|
{}
|
||||||
|
|
||||||
|
void apply(else_block const &)
|
||||||
|
{}
|
||||||
|
|
||||||
|
void apply(if_chain const &)
|
||||||
|
{}
|
||||||
|
|
||||||
|
void apply(while_block const &)
|
||||||
|
{}
|
||||||
|
|
||||||
|
void apply(function_definition const & function_definition)
|
||||||
|
{
|
||||||
|
if (scopes.back().contains(function_definition.name))
|
||||||
|
throw parse_error("Identifier \"" + function_definition.name + "\" is already defined at this scope", function_definition.location);
|
||||||
|
|
||||||
|
scopes.back().functions.insert(function_definition.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
void apply(return_statement const &)
|
||||||
|
{}
|
||||||
|
|
||||||
|
void apply(field_definition const &)
|
||||||
|
{}
|
||||||
|
|
||||||
|
void apply(struct_definition const & struct_definition)
|
||||||
|
{
|
||||||
|
if (scopes.back().contains(struct_definition.name))
|
||||||
|
throw parse_error("Identifier \"" + struct_definition.name + "\" is already defined at this scope", struct_definition.location);
|
||||||
|
|
||||||
|
scopes.back().structs.insert(struct_definition.name);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct resolve_identifiers_visitor
|
struct resolve_identifiers_visitor
|
||||||
: type_visitor<resolve_identifiers_visitor>
|
: type_visitor<resolve_identifiers_visitor>
|
||||||
, expression_visitor<resolve_identifiers_visitor>
|
, expression_visitor<resolve_identifiers_visitor>
|
||||||
, statement_visitor<resolve_identifiers_visitor>
|
, statement_visitor<resolve_identifiers_visitor>
|
||||||
{
|
{
|
||||||
std::vector<scope> scopes;
|
std::vector<scope> & scopes;
|
||||||
|
|
||||||
using type_visitor::apply;
|
using type_visitor::apply;
|
||||||
using expression_visitor::apply;
|
using expression_visitor::apply;
|
||||||
|
|
@ -246,8 +300,7 @@ namespace pslang::ast
|
||||||
|
|
||||||
void apply(function_definition const & function_definition)
|
void apply(function_definition const & function_definition)
|
||||||
{
|
{
|
||||||
if (scopes.back().contains(function_definition.name))
|
// Already added to scope by populate_globals_visitor
|
||||||
throw parse_error("Identifier \"" + function_definition.name + "\" is already defined at this scope", function_definition.location);
|
|
||||||
|
|
||||||
std::unordered_set<std::string> argument_names;
|
std::unordered_set<std::string> argument_names;
|
||||||
for (auto const & argument : function_definition.arguments)
|
for (auto const & argument : function_definition.arguments)
|
||||||
|
|
@ -282,8 +335,7 @@ namespace pslang::ast
|
||||||
|
|
||||||
void apply(struct_definition const & struct_definition)
|
void apply(struct_definition const & struct_definition)
|
||||||
{
|
{
|
||||||
if (scopes.back().contains(struct_definition.name))
|
// Already added to scope by populate_globals_visitor
|
||||||
throw parse_error("Identifier \"" + struct_definition.name + "\" is already defined at this scope", struct_definition.location);
|
|
||||||
|
|
||||||
for (auto const & field : struct_definition.fields)
|
for (auto const & field : struct_definition.fields)
|
||||||
apply(field);
|
apply(field);
|
||||||
|
|
@ -291,14 +343,23 @@ namespace pslang::ast
|
||||||
scopes.back().structs.insert(struct_definition.name);
|
scopes.back().structs.insert(struct_definition.name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void apply(statement_list & statement_list)
|
||||||
|
{
|
||||||
|
populate_globals_visitor populate_globals_visitor{{}, scopes};
|
||||||
|
populate_globals_visitor.apply(statement_list);
|
||||||
|
|
||||||
|
for (auto const & statement : statement_list.statements)
|
||||||
|
apply(*statement);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void resolve_identifiers(statement_list_ptr & statements)
|
void resolve_identifiers(statement_list_ptr & statements)
|
||||||
{
|
{
|
||||||
resolve_identifiers_visitor visitor;
|
std::vector<scope> scopes;
|
||||||
visitor.scopes.emplace_back().is_global_scope = true;
|
scopes.emplace_back().is_global_scope = true;
|
||||||
|
resolve_identifiers_visitor visitor{{}, {}, {}, scopes};
|
||||||
visitor.apply(*statements);
|
visitor.apply(*statements);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -50,16 +50,12 @@ namespace pslang::ast
|
||||||
types::type_ptr expected_return_type = nullptr;
|
types::type_ptr expected_return_type = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct check_visitor
|
struct resolve_types_visitor
|
||||||
: type_visitor<check_visitor>
|
: type_visitor<resolve_types_visitor>
|
||||||
, expression_visitor<check_visitor>
|
|
||||||
, const_statement_visitor<check_visitor>
|
|
||||||
{
|
{
|
||||||
std::vector<scope> scopes;
|
std::vector<scope> & scopes;
|
||||||
|
|
||||||
using type_visitor::apply;
|
using type_visitor::apply;
|
||||||
using expression_visitor::apply;
|
|
||||||
using const_statement_visitor::apply;
|
|
||||||
|
|
||||||
void apply(types::unit_type const &)
|
void apply(types::unit_type const &)
|
||||||
{}
|
{}
|
||||||
|
|
@ -98,6 +94,104 @@ namespace pslang::ast
|
||||||
type.level = node.level;
|
type.level = node.level;
|
||||||
node.inferred_type = std::make_unique<types::type>(std::move(type));
|
node.inferred_type = std::make_unique<types::type>(std::move(type));
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void resolve_types(std::vector<scope> & scopes, ast::type & type)
|
||||||
|
{
|
||||||
|
resolve_types_visitor visitor{{}, scopes};
|
||||||
|
visitor.apply(type);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct populate_globals_visitor
|
||||||
|
: statement_visitor<populate_globals_visitor>
|
||||||
|
{
|
||||||
|
std::vector<scope> & scopes;
|
||||||
|
|
||||||
|
using statement_visitor::apply;
|
||||||
|
|
||||||
|
void apply(expression_ptr const &)
|
||||||
|
{}
|
||||||
|
|
||||||
|
void apply(assignment const &)
|
||||||
|
{}
|
||||||
|
|
||||||
|
void apply(variable_declaration const &)
|
||||||
|
{}
|
||||||
|
|
||||||
|
void apply(if_block const &)
|
||||||
|
{}
|
||||||
|
|
||||||
|
void apply(else_if_block const &)
|
||||||
|
{}
|
||||||
|
|
||||||
|
void apply(else_block const &)
|
||||||
|
{}
|
||||||
|
|
||||||
|
void apply(if_chain const &)
|
||||||
|
{}
|
||||||
|
|
||||||
|
void apply(while_block const &)
|
||||||
|
{}
|
||||||
|
|
||||||
|
void apply(function_definition const & node)
|
||||||
|
{
|
||||||
|
for (auto const & argument : node.arguments)
|
||||||
|
resolve_types(scopes, *argument.type);
|
||||||
|
resolve_types(scopes, *node.return_type);
|
||||||
|
|
||||||
|
auto & data = scopes.back().functions[node.name];
|
||||||
|
for (auto const & argument : node.arguments)
|
||||||
|
data.arguments.push_back(get_type(*argument.type));
|
||||||
|
data.result_type = get_type(*node.return_type);
|
||||||
|
|
||||||
|
scopes.emplace_back().is_function_scope = true;
|
||||||
|
scopes.back().expected_return_type = get_type(*node.return_type);
|
||||||
|
|
||||||
|
for (auto const & argument : node.arguments)
|
||||||
|
{
|
||||||
|
scopes.back().variables[argument.name] = {
|
||||||
|
.category = value_category::constant,
|
||||||
|
.type = get_type(*argument.type),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
apply(*node.statements);
|
||||||
|
scopes.pop_back();
|
||||||
|
}
|
||||||
|
|
||||||
|
void apply(return_statement const &)
|
||||||
|
{}
|
||||||
|
|
||||||
|
void apply(field_definition const & node)
|
||||||
|
{
|
||||||
|
resolve_types(scopes, *node.type);
|
||||||
|
}
|
||||||
|
|
||||||
|
void apply(struct_definition const & node)
|
||||||
|
{
|
||||||
|
for (auto const & field : node.fields)
|
||||||
|
apply(field);
|
||||||
|
|
||||||
|
auto & data = scopes.back().structs[node.name];
|
||||||
|
for (auto const & field : node.fields)
|
||||||
|
data.fields.push_back({.name = field.name, .type = get_type(*field.type)});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void populate_globals(std::vector<scope> & scopes, statement_list & statement_list)
|
||||||
|
{
|
||||||
|
populate_globals_visitor visitor{{}, scopes};
|
||||||
|
visitor.apply(statement_list);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct check_visitor
|
||||||
|
: expression_visitor<check_visitor>
|
||||||
|
, const_statement_visitor<check_visitor>
|
||||||
|
{
|
||||||
|
std::vector<scope> & scopes;
|
||||||
|
|
||||||
|
using expression_visitor::apply;
|
||||||
|
using const_statement_visitor::apply;
|
||||||
|
|
||||||
void apply(literal &)
|
void apply(literal &)
|
||||||
{}
|
{}
|
||||||
|
|
@ -277,7 +371,7 @@ namespace pslang::ast
|
||||||
void apply(cast_operation & node)
|
void apply(cast_operation & node)
|
||||||
{
|
{
|
||||||
apply(*node.expression);
|
apply(*node.expression);
|
||||||
apply(*node.type);
|
resolve_types(scopes, *node.type);
|
||||||
|
|
||||||
auto source_type = get_type(*node.expression);
|
auto source_type = get_type(*node.expression);
|
||||||
auto target_type = get_type(*node.type);
|
auto target_type = get_type(*node.type);
|
||||||
|
|
@ -304,7 +398,7 @@ namespace pslang::ast
|
||||||
if (node.function)
|
if (node.function)
|
||||||
apply(*node.function);
|
apply(*node.function);
|
||||||
if (node.type)
|
if (node.type)
|
||||||
apply(*node.type);
|
resolve_types(scopes, *node.type);
|
||||||
for (auto const & argument : node.arguments)
|
for (auto const & argument : node.arguments)
|
||||||
apply(*argument);
|
apply(*argument);
|
||||||
|
|
||||||
|
|
@ -530,7 +624,7 @@ namespace pslang::ast
|
||||||
auto actual_type = get_type(*node.initializer);
|
auto actual_type = get_type(*node.initializer);
|
||||||
if (node.type)
|
if (node.type)
|
||||||
{
|
{
|
||||||
apply(*node.type);
|
resolve_types(scopes, *node.type);
|
||||||
auto expected_type = get_type(*node.type);
|
auto expected_type = get_type(*node.type);
|
||||||
if (!types::equal(*expected_type, *actual_type))
|
if (!types::equal(*expected_type, *actual_type))
|
||||||
{
|
{
|
||||||
|
|
@ -606,14 +700,7 @@ namespace pslang::ast
|
||||||
|
|
||||||
void apply(function_definition const & node)
|
void apply(function_definition const & node)
|
||||||
{
|
{
|
||||||
for (auto const & argument : node.arguments)
|
// Already added to scope by populate_globals_visitor
|
||||||
apply(*argument.type);
|
|
||||||
apply(*node.return_type);
|
|
||||||
|
|
||||||
auto & data = scopes.back().functions[node.name];
|
|
||||||
for (auto const & argument : node.arguments)
|
|
||||||
data.arguments.push_back(get_type(*argument.type));
|
|
||||||
data.result_type = get_type(*node.return_type);
|
|
||||||
|
|
||||||
scopes.emplace_back().is_function_scope = true;
|
scopes.emplace_back().is_function_scope = true;
|
||||||
scopes.back().expected_return_type = get_type(*node.return_type);
|
scopes.back().expected_return_type = get_type(*node.return_type);
|
||||||
|
|
@ -658,19 +745,19 @@ namespace pslang::ast
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void apply(field_definition const & node)
|
void apply(field_definition const &)
|
||||||
{
|
{}
|
||||||
apply(*node.type);
|
|
||||||
}
|
|
||||||
|
|
||||||
void apply(struct_definition const & node)
|
void apply(struct_definition const & node)
|
||||||
{
|
{
|
||||||
for (auto const & field : node.fields)
|
// Already added to scope by populate_globals_visitor
|
||||||
apply(field);
|
}
|
||||||
|
|
||||||
auto & data = scopes.back().structs[node.name];
|
void apply(statement_list & node)
|
||||||
for (auto const & field : node.fields)
|
{
|
||||||
data.fields.push_back({.name = field.name, .type = get_type(*field.type)});
|
populate_globals(scopes, node);
|
||||||
|
for (auto const & statement : node.statements)
|
||||||
|
apply(*statement);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -678,8 +765,9 @@ namespace pslang::ast
|
||||||
|
|
||||||
void check_and_infer_types(statement_list_ptr & statements)
|
void check_and_infer_types(statement_list_ptr & statements)
|
||||||
{
|
{
|
||||||
check_visitor visitor;
|
std::vector<scope> scopes;
|
||||||
visitor.scopes.emplace_back().is_global_scope = true;
|
scopes.emplace_back().is_global_scope = true;
|
||||||
|
check_visitor visitor{{}, {}, scopes};
|
||||||
visitor.apply(*statements);
|
visitor.apply(*statements);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue