Refactor AST: store direct references to AST nodes in indentifiers

This commit is contained in:
Nikita Lisitsa 2026-03-22 20:02:40 +03:00
parent 292d6eeabf
commit 5d00f1ddb7
13 changed files with 191 additions and 182 deletions

View file

@ -4,6 +4,7 @@
#include <pslang/ast/expression_fwd.hpp>
#include <pslang/ast/location.hpp>
#include <pslang/ast/type_fwd.hpp>
#include <pslang/ast/variable.hpp>
#include <pslang/types/type_fwd.hpp>
#include <string>
@ -14,17 +15,14 @@ namespace pslang::ast
struct function_declaration
{
struct argument
{
std::string name;
type_ptr type;
ast::location location;
};
using argument = variable_base;
std::string name;
std::vector<argument> arguments;
type_ptr return_type;
ast::location prelude_location;
types::type_ptr inferred_result_type = nullptr;
types::type_ptr inferred_function_type = nullptr;
};
struct function_definition

View file

@ -8,11 +8,17 @@
namespace pslang::ast
{
struct variable_base;
struct function_definition;
struct foreign_function_declaration;
struct identifier
{
std::string name;
ast::location location;
std::size_t level = 0;
variable_base * variable_node = nullptr;
function_definition * function_node = nullptr;
foreign_function_declaration * foreign_function_node = nullptr;
types::type_ptr inferred_type = nullptr;
};

View file

@ -2,27 +2,17 @@
#include <pslang/ast/location.hpp>
#include <pslang/ast/expression.hpp>
#include <pslang/ast/value_category.hpp>
#include <pslang/ast/variable.hpp>
#include <pslang/ast/control.hpp>
#include <pslang/ast/function.hpp>
#include <pslang/ast/struct.hpp>
#include <pslang/ast/statement_fwd.hpp>
#include <pslang/ast/type.hpp>
#include <variant>
namespace pslang::ast
{
struct variable_declaration
{
value_category category;
std::string name;
type_ptr type;
expression_ptr initializer;
ast::location location;
};
struct assignment
{
expression_ptr lhs;

View file

@ -39,6 +39,7 @@ namespace pslang::ast
ast::location prelude_location;
ast::location location;
types::type_ptr inferred_type = nullptr;
struct_layout layout = {};
};

View file

@ -13,6 +13,8 @@
namespace pslang::ast
{
struct struct_definition;
struct array_type
{
type_ptr element_type;
@ -31,7 +33,7 @@ namespace pslang::ast
{
std::string name;
ast::location location;
std::size_t level = 0;
struct_definition * node = nullptr;
types::type_ptr inferred_type = nullptr;
};

View file

@ -0,0 +1,27 @@
#pragma once
#include <pslang/ast/value_category.hpp>
#include <pslang/ast/type.hpp>
#include <pslang/ast/expression_fwd.hpp>
#include <string>
namespace pslang::ast
{
struct variable_base
{
value_category category;
std::string name;
type_ptr type;
ast::location location;
types::type_ptr inferred_type = nullptr;
};
struct variable_declaration
: variable_base
{
expression_ptr initializer;
};
}

View file

@ -7,6 +7,7 @@
#include <pslang/types/type.hpp>
#include <unordered_set>
#include <unordered_map>
#include <vector>
namespace pslang::ast
@ -17,14 +18,16 @@ namespace pslang::ast
struct scope
{
std::unordered_set<std::string> functions;
std::unordered_set<std::string> structs;
std::unordered_set<std::string> variables;
std::unordered_map<std::string, function_definition *> functions;
std::unordered_map<std::string, foreign_function_declaration *> foreign_functions;
std::unordered_map<std::string, struct_definition *> structs;
std::unordered_map<std::string, variable_base *> variables;
bool contains_transitive(std::string const & name) const
{
return false
|| (functions.count(name) > 0)
|| (foreign_functions.count(name) > 0)
|| (structs.count(name) > 0)
;
}
@ -40,6 +43,7 @@ namespace pslang::ast
{
return false
|| (functions.count(name) > 0)
|| (foreign_functions.count(name) > 0)
|| (structs.count(name) > 0)
|| (variables.count(name) > 0)
;
@ -77,31 +81,31 @@ namespace pslang::ast
void apply(while_block const &)
{}
void apply(function_definition const & function_definition)
void apply(function_definition & function_definition)
{
if (scopes.back().contains(function_definition.name))
throw parse_error("Identifier \"" + function_definition.name + "\" is already defined at this scope", function_definition.location);
scopes.back().functions.insert(function_definition.name);
scopes.back().functions[function_definition.name] = &function_definition;
}
void apply(foreign_function_declaration const foreign_function_declaration)
void apply(foreign_function_declaration & foreign_function_declaration)
{
if (scopes.back().contains(foreign_function_declaration.name))
throw parse_error("Identifier \"" + foreign_function_declaration.name + "\" is already defined at this scope", foreign_function_declaration.location);
scopes.back().functions.insert(foreign_function_declaration.name);
scopes.back().foreign_functions[foreign_function_declaration.name] = &foreign_function_declaration;
}
void apply(return_statement const &)
{}
void apply(struct_definition const & struct_definition)
void apply(struct_definition & struct_definition)
{
if (scopes.back().contains(struct_definition.name))
throw parse_error("Identifier \"" + struct_definition.name + "\" is already defined at this scope", struct_definition.location);
scopes.back().structs.insert(struct_definition.name);
scopes.back().structs[struct_definition.name] = &struct_definition;
}
};
@ -138,9 +142,9 @@ namespace pslang::ast
{
for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
{
if (it->contains_type(identifier.name))
if (auto jt = it->structs.find(identifier.name); jt != it->structs.end())
{
identifier.level = it.base() - scopes.begin() - 1;
identifier.node = jt->second;
return;
}
}
@ -153,18 +157,33 @@ namespace pslang::ast
void apply(identifier & identifier)
{
if (types::builtin_type(identifier.name))
return;
// NB: cannot be a type
// The case of type constructors is resolved earlier in function_call node
bool crossed_function_scope = false;
for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
{
if (it->contains(identifier.name, crossed_function_scope))
if (auto jt = it->functions.find(identifier.name); jt != it->functions.end())
{
identifier.level = it.base() - scopes.begin() - 1;
identifier.function_node = jt->second;
return;
}
if (auto jt = it->foreign_functions.find(identifier.name); jt != it->foreign_functions.end())
{
identifier.foreign_function_node = jt->second;
return;
}
if (!crossed_function_scope)
{
if (auto jt = it->variables.find(identifier.name); jt != it->variables.end())
{
identifier.variable_node = jt->second;
return;
}
}
crossed_function_scope |= it->is_function_scope;
}
@ -190,13 +209,6 @@ namespace pslang::ast
void apply(function_call & function_call)
{
if (function_call.function)
apply(*function_call.function);
if (function_call.type)
apply(*function_call.type);
for (auto const & argument : function_call.arguments)
apply(*argument);
if (auto id = std::get_if<identifier>(function_call.function.get()))
{
if (auto type = types::builtin_type(id->name))
@ -210,12 +222,26 @@ namespace pslang::ast
function_call.function = nullptr;
}
else if (scopes.at(id->level).structs.contains(id->name))
else
{
function_call.type = std::make_unique<ast::type>(type_identifier{.name = id->name, .location = id->location, .level = id->level});
function_call.function = nullptr;
for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
{
if (auto jt = it->structs.find(id->name); jt != it->structs.end())
{
function_call.type = std::make_unique<ast::type>(type_identifier{.name = id->name, .location = id->location, .node = jt->second});
function_call.function = nullptr;
break;
}
}
}
}
if (function_call.function)
apply(*function_call.function);
if (function_call.type)
apply(*function_call.type);
for (auto const & argument : function_call.arguments)
apply(*argument);
}
void apply(array const & array)
@ -246,7 +272,7 @@ namespace pslang::ast
apply(assignment.rhs);
}
void apply(variable_declaration const & variable_declaration)
void apply(variable_declaration & variable_declaration)
{
if (scopes.back().contains(variable_declaration.name))
throw parse_error("Identifier \"" + variable_declaration.name + "\" is already defined at this scope", variable_declaration.location);
@ -255,7 +281,7 @@ namespace pslang::ast
apply(*variable_declaration.type);
apply(*variable_declaration.initializer);
scopes.back().variables.insert(variable_declaration.name);
scopes.back().variables[variable_declaration.name] = &variable_declaration;
}
void apply(if_chain const & if_chain)
@ -278,26 +304,24 @@ namespace pslang::ast
scopes.pop_back();
}
void apply(function_definition const & function_definition)
void apply(function_definition & function_definition)
{
// Already added to scope by populate_globals_visitor
std::unordered_set<std::string> argument_names;
for (auto const & argument : function_definition.arguments)
std::unordered_map<std::string, variable_base *> arguments;
for (auto & argument : function_definition.arguments)
{
if (argument_names.count(argument.name) > 0)
if (arguments.count(argument.name) > 0)
throw parse_error("Duplicate argument name \"" + argument.name + "\" in function \"" + function_definition.name + "\"", argument.location);
argument_names.insert(argument.name);
arguments[argument.name] = &argument;
apply(*argument.type);
}
apply(*function_definition.return_type);
scopes.back().functions.insert(function_definition.name);
auto & scope = scopes.emplace_back();
scope.is_function_scope = true;
scope.variables = std::move(argument_names);
scope.variables = std::move(arguments);
apply(*function_definition.statements);
scopes.pop_back();
}
@ -324,14 +348,14 @@ namespace pslang::ast
apply(*return_statement.value);
}
void apply(struct_definition const & struct_definition)
void apply(struct_definition & struct_definition)
{
// Already added to scope by populate_globals_visitor
for (auto const & field : struct_definition.fields)
apply(*field.type);
scopes.back().structs.insert(struct_definition.name);
scopes.back().structs[struct_definition.name] = &struct_definition;
}
void apply(statement_list & statement_list)

View file

@ -16,37 +16,22 @@ namespace pslang::ast
namespace
{
struct variable_data
{
value_category category;
types::type_ptr type;
};
struct function_data
{
std::vector<types::type_ptr> arguments;
types::type_ptr result_type;
};
struct struct_data
{
ast::struct_definition * node;
bool layout_being_computed = false;
bool layout_ready = false;
};
struct scope
{
std::unordered_map<std::string, variable_data> variables;
std::unordered_map<std::string, function_data> functions;
std::unordered_map<std::string, struct_data> structs;
std::unordered_map<ast::struct_definition *, struct_data> structs;
bool is_function_scope = false;
types::type_ptr expected_return_type = nullptr;
};
void compute_layout(struct_data & data, std::vector<scope> & scopes);
void compute_layout(ast::struct_definition * node, struct_data & data, std::vector<scope> & scopes);
struct size_and_alignment
{
@ -86,15 +71,15 @@ namespace pslang::ast
size_and_alignment apply(types::struct_type const & type)
{
for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
if (auto jt = it->structs.find(type.node->name); jt != it->structs.end())
if (auto jt = it->structs.find(type.node); jt != it->structs.end())
{
// TODO: better error message (including the resursive inclusion path)
if (!jt->second.layout_ready && jt->second.layout_being_computed)
throw validation_error("Recursive structs are not allowed", jt->second.node->location);
throw validation_error("Recursive structs are not allowed", jt->first->location);
compute_layout(jt->second, scopes);
compute_layout(jt->first, jt->second, scopes);
auto & layout = jt->second.node->layout;
auto & layout = jt->first->layout;
return {.size = layout.size, .alignment = layout.alignment};
}
@ -102,21 +87,19 @@ namespace pslang::ast
}
};
void compute_layout(struct_data & data, std::vector<scope> & scopes)
void compute_layout(ast::struct_definition * node, struct_data & data, std::vector<scope> & scopes)
{
if (data.layout_ready)
return;
auto & struct_node = *data.node;
data.layout_being_computed = true;
auto & layout = data.node->layout;
for (std::size_t i = 0; i < struct_node.fields.size(); ++i)
auto & layout = node->layout;
for (std::size_t i = 0; i < node->fields.size(); ++i)
{
auto field_layout = field_layout_visitor{{}, scopes}.apply(*struct_node.fields[i].inferred_type);
auto field_layout = field_layout_visitor{{}, scopes}.apply(*node->fields[i].inferred_type);
layout.alignment = std::max(layout.alignment, field_layout.alignment);
layout.size = ((layout.size + field_layout.alignment - 1) / field_layout.alignment) * field_layout.alignment;
data.node->fields[i].layout.offset = layout.size;
node->fields[i].layout.offset = layout.size;
layout.size += field_layout.size;
}
layout.size = ((layout.size + layout.alignment - 1) / layout.alignment) * layout.alignment;
@ -162,16 +145,7 @@ namespace pslang::ast
void apply(ast::type_identifier & node)
{
for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
{
if (auto jt = it->structs.find(node.name); jt != it->structs.end())
{
node.inferred_type = std::make_unique<types::type>(types::struct_type{jt->second.node});
return;
}
}
throw std::runtime_error(std::format("Unknown type \"{}\"", node.name));
node.inferred_type = std::make_unique<types::type>(types::struct_type{node.node});
}
};
@ -203,42 +177,40 @@ namespace pslang::ast
void apply(while_block const &)
{}
void apply(function_definition const & node)
void apply(function_definition & node)
{
for (auto const & argument : node.arguments)
resolve_types(scopes, *argument.type);
types::function_type function_type;
resolve_types(scopes, *node.return_type);
auto & data = scopes.back().functions[node.name];
for (auto const & argument : node.arguments)
data.arguments.push_back(get_type(*argument.type));
data.result_type = get_type(*node.return_type);
node.inferred_result_type = get_type(*node.return_type);
function_type.result = node.inferred_result_type;
for (auto & argument : node.arguments)
{
resolve_types(scopes, *argument.type);
argument.inferred_type = get_type(*argument.type);
function_type.arguments.push_back(argument.inferred_type);
}
node.inferred_function_type = std::make_unique<types::type>(std::move(function_type));
scopes.emplace_back().is_function_scope = true;
scopes.back().expected_return_type = get_type(*node.return_type);
for (auto const & argument : node.arguments)
{
scopes.back().variables[argument.name] = {
.category = value_category::constant,
.type = get_type(*argument.type),
};
}
scopes.back().expected_return_type = node.inferred_result_type;
apply(*node.statements);
scopes.pop_back();
}
void apply(foreign_function_declaration const & node)
void apply(foreign_function_declaration & node)
{
for (auto const & argument : node.arguments)
resolve_types(scopes, *argument.type);
types::function_type function_type;
resolve_types(scopes, *node.return_type);
auto & data = scopes.back().functions[node.name];
for (auto const & argument : node.arguments)
data.arguments.push_back(get_type(*argument.type));
data.result_type = get_type(*node.return_type);
node.inferred_result_type = get_type(*node.return_type);
function_type.result = node.inferred_result_type;
for (auto & argument : node.arguments)
{
resolve_types(scopes, *argument.type);
argument.inferred_type = get_type(*argument.type);
function_type.arguments.push_back(argument.inferred_type);
}
node.inferred_function_type = std::make_unique<types::type>(std::move(function_type));
}
void apply(return_statement const &)
@ -252,7 +224,9 @@ namespace pslang::ast
field.inferred_type = get_type(*field.type);
}
scopes.back().structs[node.name].node = &node;
node.inferred_type = std::make_unique<types::type>(types::struct_type{&node});
scopes.back().structs[&node] = {};
}
};
@ -264,37 +238,32 @@ namespace pslang::ast
struct check_visitor
: expression_visitor<check_visitor>
, const_statement_visitor<check_visitor>
, statement_visitor<check_visitor>
{
std::vector<scope> & scopes;
using expression_visitor::apply;
using const_statement_visitor::apply;
using statement_visitor::apply;
void apply(literal &)
{}
void apply(identifier & node)
{
if (auto type = types::builtin_type(node.name))
if (node.variable_node)
{
node.inferred_type = type;
return;
node.inferred_type = node.variable_node->inferred_type;
}
auto & scope = scopes.at(node.level);
if (auto it = scope.variables.find(node.name); it != scope.variables.end())
else if (node.function_node)
{
node.inferred_type = it->second.type;
node.inferred_type = node.function_node->inferred_function_type;
}
else if (auto it = scope.functions.find(node.name); it != scope.functions.end())
else if (node.foreign_function_node)
{
auto type = types::function_type{};
for (auto const & argument : it->second.arguments)
type.arguments.push_back(argument);
type.result = it->second.result_type;
node.inferred_type = std::make_unique<types::type>(std::move(type));
node.inferred_type = node.foreign_function_node->inferred_function_type;
}
else
throw invalid_ast_error("Identifier node without a variable/function/foreign function reference", node.location);
}
void apply(unary_operation & node)
@ -702,7 +671,6 @@ namespace pslang::ast
auto ltype = get_type(*node.lhs);
auto rtype = get_type(*node.rhs);
// TODO: check lvalue
if (!types::equal(*ltype, *rtype))
{
std::ostringstream os;
@ -714,29 +682,24 @@ namespace pslang::ast
};
}
void apply(variable_declaration const & node)
void apply(variable_declaration & node)
{
apply(node.initializer);
auto actual_type = get_type(*node.initializer);
node.inferred_type = get_type(*node.initializer);
if (node.type)
{
resolve_types(scopes, *node.type);
auto expected_type = get_type(*node.type);
if (!types::equal(*expected_type, *actual_type))
if (!types::equal(*expected_type, *node.inferred_type))
{
std::ostringstream os;
os << "Cannot initialize a variable of type ";
print(os, *expected_type);
os << " with an expression of type ";
print(os, *actual_type);
print(os, *node.inferred_type);
throw type_error(os.str(), node.location);
}
}
scopes.back().variables[node.name] = {
.category = node.category,
.type = actual_type,
};
}
void apply(if_chain const & node)
@ -779,27 +742,21 @@ namespace pslang::ast
scopes.pop_back();
}
void apply(function_definition const & node)
void apply(function_definition & node)
{
// Already added to scope by populate_globals_visitor
scopes.emplace_back().is_function_scope = true;
scopes.back().expected_return_type = get_type(*node.return_type);
for (auto const & argument : node.arguments)
{
scopes.back().variables[argument.name] = {
.category = value_category::constant,
.type = get_type(*argument.type),
};
}
scopes.back().expected_return_type = node.inferred_result_type;
apply(*node.statements);
scopes.pop_back();
}
void apply(foreign_function_declaration const &)
{}
void apply(foreign_function_declaration & node)
{
}
void apply(return_statement const & node)
{
@ -837,11 +794,11 @@ namespace pslang::ast
void apply(statement_list & node)
{
populate_globals(scopes, node);
for (auto const & statement : node.statements)
for (auto & statement : node.statements)
apply(*statement);
for (auto & struct_data : scopes.back().structs)
compute_layout(struct_data.second, scopes);
compute_layout(struct_data.first, struct_data.second, scopes);
}
private:
@ -849,12 +806,12 @@ namespace pslang::ast
{
if (auto identifier = std::get_if<ast::identifier>(node.get()))
{
auto const & scope = scopes[identifier->level];
if (scope.functions.contains(identifier->name))
if (identifier->variable_node)
return identifier->variable_node->category;
else if (identifier->function_node || identifier->foreign_function_node)
return ast::value_category::constant;
if (auto it = scope.variables.find(identifier->name); it != scope.variables.end())
return it->second.category;
return std::nullopt;
else
return std::nullopt;
}
else if (auto field_access = std::get_if<ast::field_access>(node.get()))
{

View file

@ -691,16 +691,17 @@ namespace pslang::interpreter
value * eval_ref_impl(context & context, ast::identifier const & identifier)
{
if (identifier.level >= context.frame_stack.size())
throw internal_error("Bad identifier level", identifier.location);
throw std::runtime_error("Not implemented");
// if (identifier.level >= context.frame_stack.size())
// throw internal_error("Bad identifier level", identifier.location);
auto & scope = context.frame_stack[identifier.level];
// auto & scope = context.frame_stack[identifier.level];
auto it = scope.variables.find(identifier.name);
if (it == scope.variables.end())
throw internal_error("Identifier \"" + identifier.name + "\" is not defined", identifier.location);
// auto it = scope.variables.find(identifier.name);
// if (it == scope.variables.end())
// throw internal_error("Identifier \"" + identifier.name + "\" is not defined", identifier.location);
return &it->second.value;
// return &it->second.value;
}
value * eval_ref_impl(context & context, ast::unary_operation const & unary_operation)

View file

@ -444,7 +444,7 @@ namespace pslang::jit::aarch64
struct scope
{
std::unordered_map<std::string, variable_data> variables = {};
std::unordered_map<ast::variable_base const *, variable_data> variables = {};
// Difference between initial virtual stack pointer at scope enter
// and current virtual stack pointer value
@ -521,8 +521,9 @@ namespace pslang::jit::aarch64
{
for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
{
if (auto jt = it->variables.find(node.name); jt != it->variables.end())
if (node.variable_node && it->variables.contains(node.variable_node))
{
auto jt = it->variables.find(node.variable_node);
if (auto struct_type = std::get_if<types::struct_type>(node.inferred_type.get()))
{
std::size_t stack_size = ((struct_type->node->layout.size + 15) / 16) * 16;
@ -1086,7 +1087,7 @@ namespace pslang::jit::aarch64
}
else
push(0);
scopes.back().variables[node.name] = {.frame_offset = stack_offset};
scopes.back().variables[&node] = {.frame_offset = stack_offset};
}
void apply(ast::if_chain const & node)
@ -1214,7 +1215,7 @@ namespace pslang::jit::aarch64
push_fp(fp_reg, mode);
++fp_reg;
}
scopes.back().variables[argument.name] = {.frame_offset = stack_offset};
scopes.back().variables[&argument] = {.frame_offset = stack_offset};
}
apply(*node.statements);
@ -1291,9 +1292,10 @@ namespace pslang::jit::aarch64
{
if (auto identifier = std::get_if<ast::identifier>(node.get()))
{
auto const & scope = scopes[identifier->level];
if (auto it = scope.variables.find(identifier->name); it != scope.variables.end())
return it->second.frame_offset;
if (identifier->variable_node)
for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
if (auto jt = it->variables.find(identifier->variable_node); jt != it->variables.end())
return jt->second.frame_offset;
throw std::runtime_error("Non-lvalue identifier: \"" + identifier->name + "\"");
}
else if (auto field_access = std::get_if<ast::field_access>(node.get()))

View file

@ -241,7 +241,7 @@ nonempty_function_declaration_argument_list
;
function_declaration_single_argument
: name colon type_expression { $$ = ast::function_declaration::argument{$1, std::make_unique<ast::type>($3), @$}; }
: name colon type_expression { $$ = ast::function_declaration::argument{ast::value_category::constant, $1, std::make_unique<ast::type>($3), @$}; }
;
function_return_type
@ -250,8 +250,8 @@ function_return_type
;
variable_declaration
: variable_keyword name assignment expression { $$ = ast::variable_declaration{$1, $2, nullptr, std::make_unique<ast::expression>($4), @$}; }
| variable_keyword name colon type_expression assignment expression { $$ = ast::variable_declaration{$1, $2, std::make_unique<ast::type>($4), std::make_unique<ast::expression>($6), @$}; }
: variable_keyword name assignment expression { $$ = ast::variable_declaration{{$1, $2, nullptr, @$}, std::make_unique<ast::expression>($4)}; }
| variable_keyword name colon type_expression assignment expression { $$ = ast::variable_declaration{{$1, $2, std::make_unique<ast::type>($4), @$}, std::make_unique<ast::expression>($6)}; }
;
variable_keyword

View file

@ -14,7 +14,7 @@ namespace pslang::types
struct struct_type
{
ast::struct_definition * node;
ast::struct_definition * node = nullptr;
friend bool operator == (struct_type const &, struct_type const &) = default;
};

View file

@ -16,6 +16,7 @@ Interpreter backlog:
* C FFI (foreign functions)
Aarch64 compiler backlog:
* Rewrite using IR (compiler_v2.cpp first, then swap with the old one)
* Struct fields in structs (initialization & field access)
* Struct function arguments & return values
* Arrays