diff --git a/libs/ast/include/pslang/ast/struct.hpp b/libs/ast/include/pslang/ast/struct.hpp index cdc130a..e1da312 100644 --- a/libs/ast/include/pslang/ast/struct.hpp +++ b/libs/ast/include/pslang/ast/struct.hpp @@ -22,6 +22,7 @@ namespace pslang::ast ast::type_ptr type; ast::location location; + types::type_ptr inferred_type = nullptr; field_layout layout = {}; }; diff --git a/libs/ast/source/type_check.cpp b/libs/ast/source/type_check.cpp index e54cfca..4e99efa 100644 --- a/libs/ast/source/type_check.cpp +++ b/libs/ast/source/type_check.cpp @@ -30,17 +30,9 @@ namespace pslang::ast struct struct_data { - struct field_data - { - std::string name; - types::type_ptr type; - }; - ast::struct_definition * node; bool layout_being_computed = false; bool layout_ready = false; - - std::vector fields; }; struct scope @@ -115,11 +107,13 @@ namespace pslang::ast if (data.layout_ready) return; + auto & struct_node = *data.node; + data.layout_being_computed = true; auto & layout = data.node->layout; - for (std::size_t i = 0; i < data.fields.size(); ++i) + for (std::size_t i = 0; i < struct_node.fields.size(); ++i) { - auto field_layout = field_layout_visitor{{}, scopes}.apply(*data.fields[i].type); + auto field_layout = field_layout_visitor{{}, scopes}.apply(*struct_node.fields[i].inferred_type); layout.alignment = std::max(layout.alignment, field_layout.alignment); layout.size = ((layout.size + field_layout.alignment - 1) / field_layout.alignment) * field_layout.alignment; data.node->fields[i].layout.offset = layout.size; @@ -253,20 +247,18 @@ namespace pslang::ast void apply(return_statement const &) {} - void apply(field_definition const & node) + void apply(field_definition & node) { resolve_types(scopes, *node.type); + node.inferred_type = get_type(*node.type); } void apply(struct_definition & node) { - for (auto const & field : node.fields) + for (auto & field : node.fields) apply(field); - auto & data = scopes.back().structs[node.name]; - data.node = &node; - for (auto const & field : node.fields) - data.fields.push_back({.name = field.name, .type = get_type(*field.type)}); + scopes.back().structs[node.name].node = &node; } }; @@ -574,25 +566,25 @@ namespace pslang::ast else if (auto named_type = std::get_if(type.get())) { auto const & scope = scopes.at(named_type->level); - auto const & data = scope.structs.at(named_type->name); + auto const & struct_node = *scope.structs.at(named_type->name).node; if (!node.arguments.empty()) { - if (node.arguments.size() != data.fields.size()) + if (node.arguments.size() != struct_node.fields.size()) { std::ostringstream os; - os << "Cannot create struct " << named_type->name << ": expected " << data.fields.size() << " arguments, but got " << node.arguments.size(); + os << "Cannot create struct " << named_type->name << ": expected " << struct_node.fields.size() << " arguments, but got " << node.arguments.size(); throw type_error(os.str(), node.location); } for (std::size_t i = 0; i < node.arguments.size(); ++i) { auto arg_type = get_type(*node.arguments[i]); - if (!types::equal(*arg_type, *data.fields[i].type)) + if (!types::equal(*arg_type, *struct_node.fields[i].inferred_type)) { std::ostringstream os; os << "Cannot create struct " << named_type->name << ": argument #" << i << " expected to have type "; - types::print(os, *data.fields[i].type); + types::print(os, *struct_node.fields[i].inferred_type); os << " but got type "; types::print(os, *arg_type); throw type_error(os.str(), node.location); @@ -685,13 +677,13 @@ namespace pslang::ast throw type_error(os.str(), get_location(*node.object)); } - auto const & struct_data = scopes.at(named_type->level).structs.at(named_type->name); + auto const & struct_node = *scopes.at(named_type->level).structs.at(named_type->name).node; - for (auto const & field : struct_data.fields) + for (auto const & field : struct_node.fields) { if (field.name == node.field_name) { - node.inferred_type = field.type; + node.inferred_type = field.inferred_type; return; } }