diff --git a/libs/ast/source/type_check.cpp b/libs/ast/source/type_check.cpp index 49de951..bbb5d40 100644 --- a/libs/ast/source/type_check.cpp +++ b/libs/ast/source/type_check.cpp @@ -22,16 +22,12 @@ namespace pslang::ast bool layout_ready = false; }; - struct scope + struct local_context { std::unordered_map structs; - - bool is_function_scope = false; - - types::type_ptr expected_return_type = nullptr; }; - void compute_layout(ast::struct_definition * node, struct_data & data, std::vector & scopes); + void compute_layout(local_context & lcontext, ast::struct_definition * node, struct_data & data); struct size_and_alignment { @@ -42,10 +38,10 @@ namespace pslang::ast struct field_layout_visitor : types::const_visitor { - std::vector & scopes; - using const_visitor::apply; + local_context & lcontext; + size_and_alignment apply(types::unit_type const &) { return {.size = 0, .alignment = 1}; @@ -70,24 +66,23 @@ namespace pslang::ast size_and_alignment apply(types::struct_type const & type) { - for (auto it = scopes.rbegin(); it != scopes.rend(); ++it) - if (auto jt = it->structs.find(type.node); jt != it->structs.end()) - { - // TODO: better error message (including the resursive inclusion path) - if (!jt->second.layout_ready && jt->second.layout_being_computed) - throw validation_error("Recursive structs are not allowed", jt->first->location); + if (auto jt = lcontext.structs.find(type.node); jt != lcontext.structs.end()) + { + // TODO: better error message (including the resursive inclusion path) + if (!jt->second.layout_ready && jt->second.layout_being_computed) + throw validation_error("Recursive structs are not allowed", jt->first->location); - compute_layout(jt->first, jt->second, scopes); + compute_layout(lcontext, jt->first, jt->second); - auto & layout = jt->first->layout; - return {.size = layout.size, .alignment = layout.alignment}; - } + auto & layout = jt->first->layout; + return {.size = layout.size, .alignment = layout.alignment}; + } throw std::runtime_error("Unknown type \"" + type.node->name + "\""); } }; - void compute_layout(ast::struct_definition * node, struct_data & data, std::vector & scopes) + void compute_layout(local_context & lcontext, ast::struct_definition * node, struct_data & data) { if (data.layout_ready) return; @@ -96,7 +91,7 @@ namespace pslang::ast auto & layout = node->layout; for (std::size_t i = 0; i < node->fields.size(); ++i) { - auto field_layout = field_layout_visitor{{}, scopes}.apply(*node->fields[i].inferred_type); + auto field_layout = field_layout_visitor{{}, lcontext}.apply(*node->fields[i].inferred_type); layout.alignment = std::max(layout.alignment, field_layout.alignment); layout.size = ((layout.size + field_layout.alignment - 1) / field_layout.alignment) * field_layout.alignment; node->fields[i].layout.offset = layout.size; @@ -109,8 +104,6 @@ namespace pslang::ast struct resolve_types_visitor : type_visitor { - std::vector & scopes; - using type_visitor::apply; void apply(types::unit_type const &) @@ -149,16 +142,15 @@ namespace pslang::ast } }; - void resolve_types(std::vector & scopes, ast::type & type) + void resolve_types(ast::type & type) { - resolve_types_visitor visitor{{}, scopes}; - visitor.apply(type); + resolve_types_visitor{}.apply(type); } struct populate_globals_visitor : statement_visitor { - std::vector & scopes; + local_context & lcontext; using statement_visitor::apply; @@ -180,33 +172,27 @@ namespace pslang::ast void apply(function_definition & node) { types::function_type function_type; - resolve_types(scopes, *node.return_type); + resolve_types(*node.return_type); node.inferred_result_type = get_type(*node.return_type); function_type.result = node.inferred_result_type; for (auto & argument : node.arguments) { - resolve_types(scopes, *argument.type); + resolve_types(*argument.type); argument.inferred_type = get_type(*argument.type); function_type.arguments.push_back(argument.inferred_type); } node.inferred_function_type = std::make_unique(std::move(function_type)); - - scopes.emplace_back().is_function_scope = true; - scopes.back().expected_return_type = node.inferred_result_type; - - apply(*node.statements); - scopes.pop_back(); } void apply(foreign_function_declaration & node) { types::function_type function_type; - resolve_types(scopes, *node.return_type); + resolve_types(*node.return_type); node.inferred_result_type = get_type(*node.return_type); function_type.result = node.inferred_result_type; for (auto & argument : node.arguments) { - resolve_types(scopes, *argument.type); + resolve_types(*argument.type); argument.inferred_type = get_type(*argument.type); function_type.arguments.push_back(argument.inferred_type); } @@ -220,27 +206,26 @@ namespace pslang::ast { for (auto & field : node.fields) { - resolve_types(scopes, *field.type); + resolve_types(*field.type); field.inferred_type = get_type(*field.type); } node.inferred_type = std::make_unique(types::struct_type{&node}); - scopes.back().structs[&node] = {}; + lcontext.structs[&node] = {}; } }; - void populate_globals(std::vector & scopes, statement_list & statement_list) + void populate_globals(local_context & lcontext, statement_list & statement_list) { - populate_globals_visitor visitor{{}, scopes}; - visitor.apply(statement_list); + populate_globals_visitor{{}, lcontext}.apply(statement_list); } struct check_visitor : expression_visitor , statement_visitor { - std::vector & scopes; + local_context & lcontext; using expression_visitor::apply; using statement_visitor::apply; @@ -432,7 +417,7 @@ namespace pslang::ast void apply(cast_operation & node) { apply(*node.expression); - resolve_types(scopes, *node.type); + resolve_types(*node.type); auto source_type = get_type(*node.expression); auto target_type = get_type(*node.type); @@ -459,7 +444,7 @@ namespace pslang::ast if (node.function) apply(*node.function); if (node.type) - resolve_types(scopes, *node.type); + resolve_types(*node.type); for (auto const & argument : node.arguments) apply(*argument); @@ -688,7 +673,7 @@ namespace pslang::ast node.inferred_type = get_type(*node.initializer); if (node.type) { - resolve_types(scopes, *node.type); + resolve_types(*node.type); auto expected_type = get_type(*node.type); if (!types::equal(*expected_type, *node.inferred_type)) { @@ -719,9 +704,7 @@ namespace pslang::ast } } - scopes.emplace_back(); apply(*block.statements); - scopes.pop_back(); } } @@ -737,26 +720,16 @@ namespace pslang::ast throw type_error(os.str(), get_location(*node.condition)); } - scopes.emplace_back(); apply(*node.statements); - scopes.pop_back(); } void apply(function_definition & node) { - // Already added to scope by populate_globals_visitor - - scopes.emplace_back().is_function_scope = true; - scopes.back().expected_return_type = node.inferred_result_type; - apply(*node.statements); - scopes.pop_back(); } void apply(foreign_function_declaration & node) - { - - } + {} void apply(return_statement const & node) { @@ -785,18 +758,13 @@ namespace pslang::ast } void apply(struct_definition const & node) - { - // Already added to scope by populate_globals_visitor - } + {} void apply(statement_list & node) { - populate_globals(scopes, node); + populate_globals(lcontext, node); for (auto & statement : node.statements) apply(*statement); - - for (auto & struct_data : scopes.back().structs) - compute_layout(struct_data.first, struct_data.second, scopes); } private: @@ -828,10 +796,12 @@ namespace pslang::ast void check_and_infer_types(statement_list_ptr & statements) { - std::vector scopes; - scopes.emplace_back(); - check_visitor visitor{{}, {}, scopes}; + local_context lcontext; + check_visitor visitor{{}, {}, lcontext}; visitor.apply(*statements); + + for (auto & struct_data : lcontext.structs) + compute_layout(lcontext, struct_data.first, struct_data.second); } }