Support mutually recursive functions in identifier resolution & type checker
This commit is contained in:
parent
aab10621cc
commit
7a3c7cca5d
4 changed files with 194 additions and 43 deletions
|
|
@ -173,9 +173,9 @@ int main(int argc, char ** argv)
|
|||
|
||||
{
|
||||
// TODO: remove, testing-only code; should execute entry point instead
|
||||
auto offset = pcontext.symbols.at("sqr");
|
||||
auto fptr = (unsigned(*)(unsigned))(executable.data.get() + offset);
|
||||
auto x = fptr(30);
|
||||
auto offset = pcontext.symbols.at("foo");
|
||||
auto fptr = (uint32_t(*)(uint32_t))(executable.data.get() + offset);
|
||||
auto x = fptr(10u);
|
||||
std::cout << "Result: " << std::boolalpha << x << std::endl;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
func sqr(x : f32) -> f32:
|
||||
return x * x
|
||||
func foo(x : u32) -> u32:
|
||||
if x == 0u:
|
||||
return 42u
|
||||
return 1u + bar(x - 1u)
|
||||
|
||||
func smoothstep(x : f32) -> f32:
|
||||
return sqr(x) * (3.0 - 2.0 * x)
|
||||
func bar(x : u32) -> u32:
|
||||
return foo(x)
|
||||
|
|
|
|||
|
|
@ -56,12 +56,66 @@ namespace pslang::ast
|
|||
bool is_global_scope = false;
|
||||
};
|
||||
|
||||
struct populate_globals_visitor
|
||||
: statement_visitor<populate_globals_visitor>
|
||||
{
|
||||
std::vector<scope> & scopes;
|
||||
|
||||
using statement_visitor::apply;
|
||||
|
||||
void apply(expression_ptr const &)
|
||||
{}
|
||||
|
||||
void apply(assignment const &)
|
||||
{}
|
||||
|
||||
void apply(variable_declaration const &)
|
||||
{}
|
||||
|
||||
void apply(if_block const &)
|
||||
{}
|
||||
|
||||
void apply(else_if_block const &)
|
||||
{}
|
||||
|
||||
void apply(else_block const &)
|
||||
{}
|
||||
|
||||
void apply(if_chain const &)
|
||||
{}
|
||||
|
||||
void apply(while_block const &)
|
||||
{}
|
||||
|
||||
void apply(function_definition const & function_definition)
|
||||
{
|
||||
if (scopes.back().contains(function_definition.name))
|
||||
throw parse_error("Identifier \"" + function_definition.name + "\" is already defined at this scope", function_definition.location);
|
||||
|
||||
scopes.back().functions.insert(function_definition.name);
|
||||
}
|
||||
|
||||
void apply(return_statement const &)
|
||||
{}
|
||||
|
||||
void apply(field_definition const &)
|
||||
{}
|
||||
|
||||
void apply(struct_definition const & struct_definition)
|
||||
{
|
||||
if (scopes.back().contains(struct_definition.name))
|
||||
throw parse_error("Identifier \"" + struct_definition.name + "\" is already defined at this scope", struct_definition.location);
|
||||
|
||||
scopes.back().structs.insert(struct_definition.name);
|
||||
}
|
||||
};
|
||||
|
||||
struct resolve_identifiers_visitor
|
||||
: type_visitor<resolve_identifiers_visitor>
|
||||
, expression_visitor<resolve_identifiers_visitor>
|
||||
, statement_visitor<resolve_identifiers_visitor>
|
||||
{
|
||||
std::vector<scope> scopes;
|
||||
std::vector<scope> & scopes;
|
||||
|
||||
using type_visitor::apply;
|
||||
using expression_visitor::apply;
|
||||
|
|
@ -246,8 +300,7 @@ namespace pslang::ast
|
|||
|
||||
void apply(function_definition const & function_definition)
|
||||
{
|
||||
if (scopes.back().contains(function_definition.name))
|
||||
throw parse_error("Identifier \"" + function_definition.name + "\" is already defined at this scope", function_definition.location);
|
||||
// Already added to scope by populate_globals_visitor
|
||||
|
||||
std::unordered_set<std::string> argument_names;
|
||||
for (auto const & argument : function_definition.arguments)
|
||||
|
|
@ -282,8 +335,7 @@ namespace pslang::ast
|
|||
|
||||
void apply(struct_definition const & struct_definition)
|
||||
{
|
||||
if (scopes.back().contains(struct_definition.name))
|
||||
throw parse_error("Identifier \"" + struct_definition.name + "\" is already defined at this scope", struct_definition.location);
|
||||
// Already added to scope by populate_globals_visitor
|
||||
|
||||
for (auto const & field : struct_definition.fields)
|
||||
apply(field);
|
||||
|
|
@ -291,14 +343,23 @@ namespace pslang::ast
|
|||
scopes.back().structs.insert(struct_definition.name);
|
||||
}
|
||||
|
||||
void apply(statement_list & statement_list)
|
||||
{
|
||||
populate_globals_visitor populate_globals_visitor{{}, scopes};
|
||||
populate_globals_visitor.apply(statement_list);
|
||||
|
||||
for (auto const & statement : statement_list.statements)
|
||||
apply(*statement);
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
void resolve_identifiers(statement_list_ptr & statements)
|
||||
{
|
||||
resolve_identifiers_visitor visitor;
|
||||
visitor.scopes.emplace_back().is_global_scope = true;
|
||||
std::vector<scope> scopes;
|
||||
scopes.emplace_back().is_global_scope = true;
|
||||
resolve_identifiers_visitor visitor{{}, {}, {}, scopes};
|
||||
visitor.apply(*statements);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -50,16 +50,12 @@ namespace pslang::ast
|
|||
types::type_ptr expected_return_type = nullptr;
|
||||
};
|
||||
|
||||
struct check_visitor
|
||||
: type_visitor<check_visitor>
|
||||
, expression_visitor<check_visitor>
|
||||
, const_statement_visitor<check_visitor>
|
||||
struct resolve_types_visitor
|
||||
: type_visitor<resolve_types_visitor>
|
||||
{
|
||||
std::vector<scope> scopes;
|
||||
std::vector<scope> & scopes;
|
||||
|
||||
using type_visitor::apply;
|
||||
using expression_visitor::apply;
|
||||
using const_statement_visitor::apply;
|
||||
|
||||
void apply(types::unit_type const &)
|
||||
{}
|
||||
|
|
@ -98,6 +94,104 @@ namespace pslang::ast
|
|||
type.level = node.level;
|
||||
node.inferred_type = std::make_unique<types::type>(std::move(type));
|
||||
}
|
||||
};
|
||||
|
||||
void resolve_types(std::vector<scope> & scopes, ast::type & type)
|
||||
{
|
||||
resolve_types_visitor visitor{{}, scopes};
|
||||
visitor.apply(type);
|
||||
}
|
||||
|
||||
struct populate_globals_visitor
|
||||
: statement_visitor<populate_globals_visitor>
|
||||
{
|
||||
std::vector<scope> & scopes;
|
||||
|
||||
using statement_visitor::apply;
|
||||
|
||||
void apply(expression_ptr const &)
|
||||
{}
|
||||
|
||||
void apply(assignment const &)
|
||||
{}
|
||||
|
||||
void apply(variable_declaration const &)
|
||||
{}
|
||||
|
||||
void apply(if_block const &)
|
||||
{}
|
||||
|
||||
void apply(else_if_block const &)
|
||||
{}
|
||||
|
||||
void apply(else_block const &)
|
||||
{}
|
||||
|
||||
void apply(if_chain const &)
|
||||
{}
|
||||
|
||||
void apply(while_block const &)
|
||||
{}
|
||||
|
||||
void apply(function_definition const & node)
|
||||
{
|
||||
for (auto const & argument : node.arguments)
|
||||
resolve_types(scopes, *argument.type);
|
||||
resolve_types(scopes, *node.return_type);
|
||||
|
||||
auto & data = scopes.back().functions[node.name];
|
||||
for (auto const & argument : node.arguments)
|
||||
data.arguments.push_back(get_type(*argument.type));
|
||||
data.result_type = get_type(*node.return_type);
|
||||
|
||||
scopes.emplace_back().is_function_scope = true;
|
||||
scopes.back().expected_return_type = get_type(*node.return_type);
|
||||
|
||||
for (auto const & argument : node.arguments)
|
||||
{
|
||||
scopes.back().variables[argument.name] = {
|
||||
.category = value_category::constant,
|
||||
.type = get_type(*argument.type),
|
||||
};
|
||||
}
|
||||
|
||||
apply(*node.statements);
|
||||
scopes.pop_back();
|
||||
}
|
||||
|
||||
void apply(return_statement const &)
|
||||
{}
|
||||
|
||||
void apply(field_definition const & node)
|
||||
{
|
||||
resolve_types(scopes, *node.type);
|
||||
}
|
||||
|
||||
void apply(struct_definition const & node)
|
||||
{
|
||||
for (auto const & field : node.fields)
|
||||
apply(field);
|
||||
|
||||
auto & data = scopes.back().structs[node.name];
|
||||
for (auto const & field : node.fields)
|
||||
data.fields.push_back({.name = field.name, .type = get_type(*field.type)});
|
||||
}
|
||||
};
|
||||
|
||||
void populate_globals(std::vector<scope> & scopes, statement_list & statement_list)
|
||||
{
|
||||
populate_globals_visitor visitor{{}, scopes};
|
||||
visitor.apply(statement_list);
|
||||
}
|
||||
|
||||
struct check_visitor
|
||||
: expression_visitor<check_visitor>
|
||||
, const_statement_visitor<check_visitor>
|
||||
{
|
||||
std::vector<scope> & scopes;
|
||||
|
||||
using expression_visitor::apply;
|
||||
using const_statement_visitor::apply;
|
||||
|
||||
void apply(literal &)
|
||||
{}
|
||||
|
|
@ -277,7 +371,7 @@ namespace pslang::ast
|
|||
void apply(cast_operation & node)
|
||||
{
|
||||
apply(*node.expression);
|
||||
apply(*node.type);
|
||||
resolve_types(scopes, *node.type);
|
||||
|
||||
auto source_type = get_type(*node.expression);
|
||||
auto target_type = get_type(*node.type);
|
||||
|
|
@ -304,7 +398,7 @@ namespace pslang::ast
|
|||
if (node.function)
|
||||
apply(*node.function);
|
||||
if (node.type)
|
||||
apply(*node.type);
|
||||
resolve_types(scopes, *node.type);
|
||||
for (auto const & argument : node.arguments)
|
||||
apply(*argument);
|
||||
|
||||
|
|
@ -530,7 +624,7 @@ namespace pslang::ast
|
|||
auto actual_type = get_type(*node.initializer);
|
||||
if (node.type)
|
||||
{
|
||||
apply(*node.type);
|
||||
resolve_types(scopes, *node.type);
|
||||
auto expected_type = get_type(*node.type);
|
||||
if (!types::equal(*expected_type, *actual_type))
|
||||
{
|
||||
|
|
@ -606,14 +700,7 @@ namespace pslang::ast
|
|||
|
||||
void apply(function_definition const & node)
|
||||
{
|
||||
for (auto const & argument : node.arguments)
|
||||
apply(*argument.type);
|
||||
apply(*node.return_type);
|
||||
|
||||
auto & data = scopes.back().functions[node.name];
|
||||
for (auto const & argument : node.arguments)
|
||||
data.arguments.push_back(get_type(*argument.type));
|
||||
data.result_type = get_type(*node.return_type);
|
||||
// Already added to scope by populate_globals_visitor
|
||||
|
||||
scopes.emplace_back().is_function_scope = true;
|
||||
scopes.back().expected_return_type = get_type(*node.return_type);
|
||||
|
|
@ -658,19 +745,19 @@ namespace pslang::ast
|
|||
}
|
||||
}
|
||||
|
||||
void apply(field_definition const & node)
|
||||
{
|
||||
apply(*node.type);
|
||||
}
|
||||
void apply(field_definition const &)
|
||||
{}
|
||||
|
||||
void apply(struct_definition const & node)
|
||||
{
|
||||
for (auto const & field : node.fields)
|
||||
apply(field);
|
||||
// Already added to scope by populate_globals_visitor
|
||||
}
|
||||
|
||||
auto & data = scopes.back().structs[node.name];
|
||||
for (auto const & field : node.fields)
|
||||
data.fields.push_back({.name = field.name, .type = get_type(*field.type)});
|
||||
void apply(statement_list & node)
|
||||
{
|
||||
populate_globals(scopes, node);
|
||||
for (auto const & statement : node.statements)
|
||||
apply(*statement);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -678,8 +765,9 @@ namespace pslang::ast
|
|||
|
||||
void check_and_infer_types(statement_list_ptr & statements)
|
||||
{
|
||||
check_visitor visitor;
|
||||
visitor.scopes.emplace_back().is_global_scope = true;
|
||||
std::vector<scope> scopes;
|
||||
scopes.emplace_back().is_global_scope = true;
|
||||
check_visitor visitor{{}, {}, scopes};
|
||||
visitor.apply(*statements);
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue