Refactor type checker: remove useless scoping, resolve type layouts all in one go

This commit is contained in:
Nikita Lisitsa 2026-03-22 23:49:29 +03:00
parent c4d1252462
commit ebc19fad20

View file

@ -22,16 +22,12 @@ namespace pslang::ast
bool layout_ready = false; bool layout_ready = false;
}; };
struct scope struct local_context
{ {
std::unordered_map<ast::struct_definition *, struct_data> structs; std::unordered_map<ast::struct_definition *, struct_data> structs;
bool is_function_scope = false;
types::type_ptr expected_return_type = nullptr;
}; };
void compute_layout(ast::struct_definition * node, struct_data & data, std::vector<scope> & scopes); void compute_layout(local_context & lcontext, ast::struct_definition * node, struct_data & data);
struct size_and_alignment struct size_and_alignment
{ {
@ -42,10 +38,10 @@ namespace pslang::ast
struct field_layout_visitor struct field_layout_visitor
: types::const_visitor<field_layout_visitor> : types::const_visitor<field_layout_visitor>
{ {
std::vector<scope> & scopes;
using const_visitor::apply; using const_visitor::apply;
local_context & lcontext;
size_and_alignment apply(types::unit_type const &) size_and_alignment apply(types::unit_type const &)
{ {
return {.size = 0, .alignment = 1}; return {.size = 0, .alignment = 1};
@ -70,24 +66,23 @@ namespace pslang::ast
size_and_alignment apply(types::struct_type const & type) size_and_alignment apply(types::struct_type const & type)
{ {
for (auto it = scopes.rbegin(); it != scopes.rend(); ++it) if (auto jt = lcontext.structs.find(type.node); jt != lcontext.structs.end())
if (auto jt = it->structs.find(type.node); jt != it->structs.end()) {
{ // TODO: better error message (including the resursive inclusion path)
// TODO: better error message (including the resursive inclusion path) if (!jt->second.layout_ready && jt->second.layout_being_computed)
if (!jt->second.layout_ready && jt->second.layout_being_computed) throw validation_error("Recursive structs are not allowed", jt->first->location);
throw validation_error("Recursive structs are not allowed", jt->first->location);
compute_layout(jt->first, jt->second, scopes); compute_layout(lcontext, jt->first, jt->second);
auto & layout = jt->first->layout; auto & layout = jt->first->layout;
return {.size = layout.size, .alignment = layout.alignment}; return {.size = layout.size, .alignment = layout.alignment};
} }
throw std::runtime_error("Unknown type \"" + type.node->name + "\""); throw std::runtime_error("Unknown type \"" + type.node->name + "\"");
} }
}; };
void compute_layout(ast::struct_definition * node, struct_data & data, std::vector<scope> & scopes) void compute_layout(local_context & lcontext, ast::struct_definition * node, struct_data & data)
{ {
if (data.layout_ready) if (data.layout_ready)
return; return;
@ -96,7 +91,7 @@ namespace pslang::ast
auto & layout = node->layout; auto & layout = node->layout;
for (std::size_t i = 0; i < node->fields.size(); ++i) for (std::size_t i = 0; i < node->fields.size(); ++i)
{ {
auto field_layout = field_layout_visitor{{}, scopes}.apply(*node->fields[i].inferred_type); auto field_layout = field_layout_visitor{{}, lcontext}.apply(*node->fields[i].inferred_type);
layout.alignment = std::max(layout.alignment, field_layout.alignment); layout.alignment = std::max(layout.alignment, field_layout.alignment);
layout.size = ((layout.size + field_layout.alignment - 1) / field_layout.alignment) * field_layout.alignment; layout.size = ((layout.size + field_layout.alignment - 1) / field_layout.alignment) * field_layout.alignment;
node->fields[i].layout.offset = layout.size; node->fields[i].layout.offset = layout.size;
@ -109,8 +104,6 @@ namespace pslang::ast
struct resolve_types_visitor struct resolve_types_visitor
: type_visitor<resolve_types_visitor> : type_visitor<resolve_types_visitor>
{ {
std::vector<scope> & scopes;
using type_visitor::apply; using type_visitor::apply;
void apply(types::unit_type const &) void apply(types::unit_type const &)
@ -149,16 +142,15 @@ namespace pslang::ast
} }
}; };
void resolve_types(std::vector<scope> & scopes, ast::type & type) void resolve_types(ast::type & type)
{ {
resolve_types_visitor visitor{{}, scopes}; resolve_types_visitor{}.apply(type);
visitor.apply(type);
} }
struct populate_globals_visitor struct populate_globals_visitor
: statement_visitor<populate_globals_visitor> : statement_visitor<populate_globals_visitor>
{ {
std::vector<scope> & scopes; local_context & lcontext;
using statement_visitor::apply; using statement_visitor::apply;
@ -180,33 +172,27 @@ namespace pslang::ast
void apply(function_definition & node) void apply(function_definition & node)
{ {
types::function_type function_type; types::function_type function_type;
resolve_types(scopes, *node.return_type); resolve_types(*node.return_type);
node.inferred_result_type = get_type(*node.return_type); node.inferred_result_type = get_type(*node.return_type);
function_type.result = node.inferred_result_type; function_type.result = node.inferred_result_type;
for (auto & argument : node.arguments) for (auto & argument : node.arguments)
{ {
resolve_types(scopes, *argument.type); resolve_types(*argument.type);
argument.inferred_type = get_type(*argument.type); argument.inferred_type = get_type(*argument.type);
function_type.arguments.push_back(argument.inferred_type); function_type.arguments.push_back(argument.inferred_type);
} }
node.inferred_function_type = std::make_unique<types::type>(std::move(function_type)); node.inferred_function_type = std::make_unique<types::type>(std::move(function_type));
scopes.emplace_back().is_function_scope = true;
scopes.back().expected_return_type = node.inferred_result_type;
apply(*node.statements);
scopes.pop_back();
} }
void apply(foreign_function_declaration & node) void apply(foreign_function_declaration & node)
{ {
types::function_type function_type; types::function_type function_type;
resolve_types(scopes, *node.return_type); resolve_types(*node.return_type);
node.inferred_result_type = get_type(*node.return_type); node.inferred_result_type = get_type(*node.return_type);
function_type.result = node.inferred_result_type; function_type.result = node.inferred_result_type;
for (auto & argument : node.arguments) for (auto & argument : node.arguments)
{ {
resolve_types(scopes, *argument.type); resolve_types(*argument.type);
argument.inferred_type = get_type(*argument.type); argument.inferred_type = get_type(*argument.type);
function_type.arguments.push_back(argument.inferred_type); function_type.arguments.push_back(argument.inferred_type);
} }
@ -220,27 +206,26 @@ namespace pslang::ast
{ {
for (auto & field : node.fields) for (auto & field : node.fields)
{ {
resolve_types(scopes, *field.type); resolve_types(*field.type);
field.inferred_type = get_type(*field.type); field.inferred_type = get_type(*field.type);
} }
node.inferred_type = std::make_unique<types::type>(types::struct_type{&node}); node.inferred_type = std::make_unique<types::type>(types::struct_type{&node});
scopes.back().structs[&node] = {}; lcontext.structs[&node] = {};
} }
}; };
void populate_globals(std::vector<scope> & scopes, statement_list & statement_list) void populate_globals(local_context & lcontext, statement_list & statement_list)
{ {
populate_globals_visitor visitor{{}, scopes}; populate_globals_visitor{{}, lcontext}.apply(statement_list);
visitor.apply(statement_list);
} }
struct check_visitor struct check_visitor
: expression_visitor<check_visitor> : expression_visitor<check_visitor>
, statement_visitor<check_visitor> , statement_visitor<check_visitor>
{ {
std::vector<scope> & scopes; local_context & lcontext;
using expression_visitor::apply; using expression_visitor::apply;
using statement_visitor::apply; using statement_visitor::apply;
@ -432,7 +417,7 @@ namespace pslang::ast
void apply(cast_operation & node) void apply(cast_operation & node)
{ {
apply(*node.expression); apply(*node.expression);
resolve_types(scopes, *node.type); resolve_types(*node.type);
auto source_type = get_type(*node.expression); auto source_type = get_type(*node.expression);
auto target_type = get_type(*node.type); auto target_type = get_type(*node.type);
@ -459,7 +444,7 @@ namespace pslang::ast
if (node.function) if (node.function)
apply(*node.function); apply(*node.function);
if (node.type) if (node.type)
resolve_types(scopes, *node.type); resolve_types(*node.type);
for (auto const & argument : node.arguments) for (auto const & argument : node.arguments)
apply(*argument); apply(*argument);
@ -688,7 +673,7 @@ namespace pslang::ast
node.inferred_type = get_type(*node.initializer); node.inferred_type = get_type(*node.initializer);
if (node.type) if (node.type)
{ {
resolve_types(scopes, *node.type); resolve_types(*node.type);
auto expected_type = get_type(*node.type); auto expected_type = get_type(*node.type);
if (!types::equal(*expected_type, *node.inferred_type)) if (!types::equal(*expected_type, *node.inferred_type))
{ {
@ -719,9 +704,7 @@ namespace pslang::ast
} }
} }
scopes.emplace_back();
apply(*block.statements); apply(*block.statements);
scopes.pop_back();
} }
} }
@ -737,26 +720,16 @@ namespace pslang::ast
throw type_error(os.str(), get_location(*node.condition)); throw type_error(os.str(), get_location(*node.condition));
} }
scopes.emplace_back();
apply(*node.statements); apply(*node.statements);
scopes.pop_back();
} }
void apply(function_definition & node) void apply(function_definition & node)
{ {
// Already added to scope by populate_globals_visitor
scopes.emplace_back().is_function_scope = true;
scopes.back().expected_return_type = node.inferred_result_type;
apply(*node.statements); apply(*node.statements);
scopes.pop_back();
} }
void apply(foreign_function_declaration & node) void apply(foreign_function_declaration & node)
{ {}
}
void apply(return_statement const & node) void apply(return_statement const & node)
{ {
@ -785,18 +758,13 @@ namespace pslang::ast
} }
void apply(struct_definition const & node) void apply(struct_definition const & node)
{ {}
// Already added to scope by populate_globals_visitor
}
void apply(statement_list & node) void apply(statement_list & node)
{ {
populate_globals(scopes, node); populate_globals(lcontext, node);
for (auto & statement : node.statements) for (auto & statement : node.statements)
apply(*statement); apply(*statement);
for (auto & struct_data : scopes.back().structs)
compute_layout(struct_data.first, struct_data.second, scopes);
} }
private: private:
@ -828,10 +796,12 @@ namespace pslang::ast
void check_and_infer_types(statement_list_ptr & statements) void check_and_infer_types(statement_list_ptr & statements)
{ {
std::vector<scope> scopes; local_context lcontext;
scopes.emplace_back(); check_visitor visitor{{}, {}, lcontext};
check_visitor visitor{{}, {}, scopes};
visitor.apply(*statements); visitor.apply(*statements);
for (auto & struct_data : lcontext.structs)
compute_layout(lcontext, struct_data.first, struct_data.second);
} }
} }