Refactor type checker: remove useless scoping, resolve type layouts all in one go

This commit is contained in:
Nikita Lisitsa 2026-03-22 23:49:29 +03:00
parent c4d1252462
commit ebc19fad20

View file

@ -22,16 +22,12 @@ namespace pslang::ast
bool layout_ready = false;
};
struct scope
struct local_context
{
std::unordered_map<ast::struct_definition *, struct_data> structs;
bool is_function_scope = false;
types::type_ptr expected_return_type = nullptr;
};
void compute_layout(ast::struct_definition * node, struct_data & data, std::vector<scope> & scopes);
void compute_layout(local_context & lcontext, ast::struct_definition * node, struct_data & data);
struct size_and_alignment
{
@ -42,10 +38,10 @@ namespace pslang::ast
struct field_layout_visitor
: types::const_visitor<field_layout_visitor>
{
std::vector<scope> & scopes;
using const_visitor::apply;
local_context & lcontext;
size_and_alignment apply(types::unit_type const &)
{
return {.size = 0, .alignment = 1};
@ -70,24 +66,23 @@ namespace pslang::ast
size_and_alignment apply(types::struct_type const & type)
{
for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
if (auto jt = it->structs.find(type.node); jt != it->structs.end())
{
// TODO: better error message (including the resursive inclusion path)
if (!jt->second.layout_ready && jt->second.layout_being_computed)
throw validation_error("Recursive structs are not allowed", jt->first->location);
if (auto jt = lcontext.structs.find(type.node); jt != lcontext.structs.end())
{
// TODO: better error message (including the resursive inclusion path)
if (!jt->second.layout_ready && jt->second.layout_being_computed)
throw validation_error("Recursive structs are not allowed", jt->first->location);
compute_layout(jt->first, jt->second, scopes);
compute_layout(lcontext, jt->first, jt->second);
auto & layout = jt->first->layout;
return {.size = layout.size, .alignment = layout.alignment};
}
auto & layout = jt->first->layout;
return {.size = layout.size, .alignment = layout.alignment};
}
throw std::runtime_error("Unknown type \"" + type.node->name + "\"");
}
};
void compute_layout(ast::struct_definition * node, struct_data & data, std::vector<scope> & scopes)
void compute_layout(local_context & lcontext, ast::struct_definition * node, struct_data & data)
{
if (data.layout_ready)
return;
@ -96,7 +91,7 @@ namespace pslang::ast
auto & layout = node->layout;
for (std::size_t i = 0; i < node->fields.size(); ++i)
{
auto field_layout = field_layout_visitor{{}, scopes}.apply(*node->fields[i].inferred_type);
auto field_layout = field_layout_visitor{{}, lcontext}.apply(*node->fields[i].inferred_type);
layout.alignment = std::max(layout.alignment, field_layout.alignment);
layout.size = ((layout.size + field_layout.alignment - 1) / field_layout.alignment) * field_layout.alignment;
node->fields[i].layout.offset = layout.size;
@ -109,8 +104,6 @@ namespace pslang::ast
struct resolve_types_visitor
: type_visitor<resolve_types_visitor>
{
std::vector<scope> & scopes;
using type_visitor::apply;
void apply(types::unit_type const &)
@ -149,16 +142,15 @@ namespace pslang::ast
}
};
void resolve_types(std::vector<scope> & scopes, ast::type & type)
void resolve_types(ast::type & type)
{
resolve_types_visitor visitor{{}, scopes};
visitor.apply(type);
resolve_types_visitor{}.apply(type);
}
struct populate_globals_visitor
: statement_visitor<populate_globals_visitor>
{
std::vector<scope> & scopes;
local_context & lcontext;
using statement_visitor::apply;
@ -180,33 +172,27 @@ namespace pslang::ast
void apply(function_definition & node)
{
types::function_type function_type;
resolve_types(scopes, *node.return_type);
resolve_types(*node.return_type);
node.inferred_result_type = get_type(*node.return_type);
function_type.result = node.inferred_result_type;
for (auto & argument : node.arguments)
{
resolve_types(scopes, *argument.type);
resolve_types(*argument.type);
argument.inferred_type = get_type(*argument.type);
function_type.arguments.push_back(argument.inferred_type);
}
node.inferred_function_type = std::make_unique<types::type>(std::move(function_type));
scopes.emplace_back().is_function_scope = true;
scopes.back().expected_return_type = node.inferred_result_type;
apply(*node.statements);
scopes.pop_back();
}
void apply(foreign_function_declaration & node)
{
types::function_type function_type;
resolve_types(scopes, *node.return_type);
resolve_types(*node.return_type);
node.inferred_result_type = get_type(*node.return_type);
function_type.result = node.inferred_result_type;
for (auto & argument : node.arguments)
{
resolve_types(scopes, *argument.type);
resolve_types(*argument.type);
argument.inferred_type = get_type(*argument.type);
function_type.arguments.push_back(argument.inferred_type);
}
@ -220,27 +206,26 @@ namespace pslang::ast
{
for (auto & field : node.fields)
{
resolve_types(scopes, *field.type);
resolve_types(*field.type);
field.inferred_type = get_type(*field.type);
}
node.inferred_type = std::make_unique<types::type>(types::struct_type{&node});
scopes.back().structs[&node] = {};
lcontext.structs[&node] = {};
}
};
void populate_globals(std::vector<scope> & scopes, statement_list & statement_list)
void populate_globals(local_context & lcontext, statement_list & statement_list)
{
populate_globals_visitor visitor{{}, scopes};
visitor.apply(statement_list);
populate_globals_visitor{{}, lcontext}.apply(statement_list);
}
struct check_visitor
: expression_visitor<check_visitor>
, statement_visitor<check_visitor>
{
std::vector<scope> & scopes;
local_context & lcontext;
using expression_visitor::apply;
using statement_visitor::apply;
@ -432,7 +417,7 @@ namespace pslang::ast
void apply(cast_operation & node)
{
apply(*node.expression);
resolve_types(scopes, *node.type);
resolve_types(*node.type);
auto source_type = get_type(*node.expression);
auto target_type = get_type(*node.type);
@ -459,7 +444,7 @@ namespace pslang::ast
if (node.function)
apply(*node.function);
if (node.type)
resolve_types(scopes, *node.type);
resolve_types(*node.type);
for (auto const & argument : node.arguments)
apply(*argument);
@ -688,7 +673,7 @@ namespace pslang::ast
node.inferred_type = get_type(*node.initializer);
if (node.type)
{
resolve_types(scopes, *node.type);
resolve_types(*node.type);
auto expected_type = get_type(*node.type);
if (!types::equal(*expected_type, *node.inferred_type))
{
@ -719,9 +704,7 @@ namespace pslang::ast
}
}
scopes.emplace_back();
apply(*block.statements);
scopes.pop_back();
}
}
@ -737,26 +720,16 @@ namespace pslang::ast
throw type_error(os.str(), get_location(*node.condition));
}
scopes.emplace_back();
apply(*node.statements);
scopes.pop_back();
}
void apply(function_definition & node)
{
// Already added to scope by populate_globals_visitor
scopes.emplace_back().is_function_scope = true;
scopes.back().expected_return_type = node.inferred_result_type;
apply(*node.statements);
scopes.pop_back();
}
void apply(foreign_function_declaration & node)
{
}
{}
void apply(return_statement const & node)
{
@ -785,18 +758,13 @@ namespace pslang::ast
}
void apply(struct_definition const & node)
{
// Already added to scope by populate_globals_visitor
}
{}
void apply(statement_list & node)
{
populate_globals(scopes, node);
populate_globals(lcontext, node);
for (auto & statement : node.statements)
apply(*statement);
for (auto & struct_data : scopes.back().structs)
compute_layout(struct_data.first, struct_data.second, scopes);
}
private:
@ -828,10 +796,12 @@ namespace pslang::ast
void check_and_infer_types(statement_list_ptr & statements)
{
std::vector<scope> scopes;
scopes.emplace_back();
check_visitor visitor{{}, {}, scopes};
local_context lcontext;
check_visitor visitor{{}, {}, lcontext};
visitor.apply(*statements);
for (auto & struct_data : lcontext.structs)
compute_layout(lcontext, struct_data.first, struct_data.second);
}
}