Refactor: split the parser output into a pre-AST (contains if/else if/else blocks and field definitions, no hierarchy) and the actual AST obtained after resolving scoping & indentation (contains if chains, no field definitions)

This commit is contained in:
Nikita Lisitsa 2026-03-22 14:06:34 +03:00
parent 56c63d50ac
commit 292d6eeabf
15 changed files with 82 additions and 274 deletions

View file

@ -7,9 +7,6 @@
namespace pslang::ast namespace pslang::ast
{ {
// N.B.: if_block, else_block, and else_if_block are temporary parsing elements
// and are not present in the final AST
struct if_block struct if_block
{ {
expression_ptr condition; expression_ptr condition;

View file

@ -30,14 +30,13 @@ namespace pslang::ast
ast::location location; ast::location location;
}; };
using statement_impl = std::variant< using pre_statement_impl = std::variant<
expression_ptr, expression_ptr,
assignment, assignment,
variable_declaration, variable_declaration,
if_block, if_block,
else_block, else_block,
else_if_block, else_if_block,
if_chain,
while_block, while_block,
function_definition, function_definition,
foreign_function_declaration, foreign_function_declaration,
@ -46,12 +45,31 @@ namespace pslang::ast
struct_definition struct_definition
>; >;
// Pre-AST statement node: a thin wrapper over the pre_statement_impl variant.
// The using-declaration inherits the variant's constructors, so a value of any
// alternative type can be used to construct a pre_statement directly.
struct pre_statement
: pre_statement_impl
{
using pre_statement_impl::pre_statement_impl;
};
// Alternatives a final-AST statement can hold. This list contains if_chain and
// has no if_block / else_block / else_if_block or field_definition entries.
using statement_impl = std::variant<
expression_ptr,
assignment,
variable_declaration,
if_chain,
while_block,
function_definition,
foreign_function_declaration,
return_statement,
struct_definition
>;
struct statement struct statement
: statement_impl : statement_impl
{ {
using statement_impl::statement_impl; using statement_impl::statement_impl;
}; };
location get_location(pre_statement const & statement);
location get_location(statement const & statement); location get_location(statement const & statement);
} }

View file

@ -6,15 +6,23 @@
namespace pslang::ast namespace pslang::ast
{ {
struct pre_statement;
struct statement; struct statement;
using pre_statement_ptr = std::shared_ptr<pre_statement>;
using statement_ptr = std::shared_ptr<statement>; using statement_ptr = std::shared_ptr<statement>;
// Sequence of pre_statement nodes, each held through a pre_statement_ptr
// (a std::shared_ptr<pre_statement>).
struct pre_statement_list
{
std::vector<pre_statement_ptr> statements;
};
struct statement_list struct statement_list
{ {
std::vector<statement_ptr> statements; std::vector<statement_ptr> statements;
}; };
using pre_statement_list_ptr = std::shared_ptr<pre_statement_list>;
using statement_list_ptr = std::shared_ptr<statement_list>; using statement_list_ptr = std::shared_ptr<statement_list>;
} }

View file

@ -386,26 +386,6 @@ namespace pslang::ast
child(node.initializer); child(node.initializer);
} }
void apply(if_block const & node)
{
put_indent(out, options);
out << "if\n";
child(node.condition);
}
void apply(else_block const & node)
{
put_indent(out, options);
out << "else\n";
}
void apply(else_if_block const & node)
{
put_indent(out, options);
out << "else if\n";
child(node.condition);
}
void apply(if_chain const & node) void apply(if_chain const & node)
{ {
put_indent(out, options); put_indent(out, options);
@ -483,20 +463,17 @@ namespace pslang::ast
child(node.value); child(node.value);
} }
void apply(field_definition const & node)
{
put_indent(out, options);
out << "field { name = \"" << node.name << "\", type = ";
print(out, *node.type);
out << ", offset = " << node.layout.offset << " }\n";
}
// Prints a struct definition: one header line with the struct's name, size and
// alignment, followed by one line per field with the field's name, type and
// offset. Field lines are indented using as_child(options).
void apply(struct_definition const & node)
{
    put_indent(out, options);
    out << "struct { name = \"" << node.name << "\", size = " << node.layout.size << ", align = " << node.layout.alignment << " }\n";
    for (auto const & field : node.fields)
    {
        put_indent(out, as_child(options));
        // Use field.name here: `node` is the enclosing struct, so node.name
        // would repeat the struct's name on every field line.
        out << "field { name = \"" << field.name << "\", type = ";
        print(out, *field.type);
        out << ", offset = " << field.layout.offset << " }\n";
    }
}
}; };

View file

@ -71,15 +71,6 @@ namespace pslang::ast
void apply(variable_declaration const &) void apply(variable_declaration const &)
{} {}
void apply(if_block const &)
{}
void apply(else_if_block const &)
{}
void apply(else_block const &)
{}
void apply(if_chain const &) void apply(if_chain const &)
{} {}
@ -105,9 +96,6 @@ namespace pslang::ast
void apply(return_statement const &) void apply(return_statement const &)
{} {}
void apply(field_definition const &)
{}
void apply(struct_definition const & struct_definition) void apply(struct_definition const & struct_definition)
{ {
if (scopes.back().contains(struct_definition.name)) if (scopes.back().contains(struct_definition.name))
@ -270,21 +258,6 @@ namespace pslang::ast
scopes.back().variables.insert(variable_declaration.name); scopes.back().variables.insert(variable_declaration.name);
} }
void apply(if_block const & if_block)
{
throw invalid_ast_error("if blocks cannot be present in the final AST", if_block.location);
}
void apply(else_if_block const & else_if_block)
{
throw invalid_ast_error("else if blocks cannot be present in the final AST", else_if_block.location);
}
void apply(else_block const & else_block)
{
throw invalid_ast_error("else blocks cannot be present in the final AST", else_block.location);
}
void apply(if_chain const & if_chain) void apply(if_chain const & if_chain)
{ {
for (auto const & block : if_chain.blocks) for (auto const & block : if_chain.blocks)
@ -351,17 +324,12 @@ namespace pslang::ast
apply(*return_statement.value); apply(*return_statement.value);
} }
void apply(field_definition const & field_definition)
{
apply(*field_definition.type);
}
void apply(struct_definition const & struct_definition) void apply(struct_definition const & struct_definition)
{ {
// Already added to scope by populate_globals_visitor // Already added to scope by populate_globals_visitor
for (auto const & field : struct_definition.fields) for (auto const & field : struct_definition.fields)
apply(field); apply(*field.type);
scopes.back().structs.insert(struct_definition.name); scopes.back().structs.insert(struct_definition.name);
} }

View file

@ -1,5 +1,4 @@
#include <pslang/ast/statement.hpp> #include <pslang/ast/statement.hpp>
#include <pslang/ast/statement_visitor.hpp>
namespace pslang::ast namespace pslang::ast
{ {
@ -8,71 +7,14 @@ namespace pslang::ast
{ {
struct get_location_visitor struct get_location_visitor
: const_statement_visitor<get_location_visitor>
{ {
using const_statement_visitor::apply; location operator()(expression_ptr const & node)
location apply(expression_ptr const & node)
{ {
return get_location(*node); return get_location(*node);
} }
location apply(assignment const & node) template <typename Node>
{ location operator()(Node const & node)
return node.location;
}
location apply(variable_declaration const & node)
{
return node.location;
}
location apply(if_block const & node)
{
return node.location;
}
location apply(else_block const & node)
{
return node.location;
}
location apply(else_if_block const & node)
{
return node.location;
}
location apply(if_chain const & node)
{
return node.location;
}
location apply(while_block const & node)
{
return node.location;
}
location apply(function_definition const & node)
{
return node.location;
}
location apply(foreign_function_declaration const & node)
{
return node.location;
}
location apply(return_statement const & node)
{
return node.location;
}
location apply(field_definition const & node)
{
return node.location;
}
location apply(struct_definition const & node)
{ {
return node.location; return node.location;
} }
@ -80,9 +22,14 @@ namespace pslang::ast
} }
// Returns the source location of a pre-AST statement by dispatching
// get_location_visitor over whichever alternative the variant currently holds.
location get_location(pre_statement const & statement)
{
    get_location_visitor locator;
    return std::visit(locator, statement);
}
location get_location(statement const & statement) location get_location(statement const & statement)
{ {
return get_location_visitor{}.apply(statement); return std::visit(get_location_visitor{}, statement);
} }
} }

View file

@ -197,15 +197,6 @@ namespace pslang::ast
void apply(variable_declaration const &) void apply(variable_declaration const &)
{} {}
void apply(if_block const &)
{}
void apply(else_if_block const &)
{}
void apply(else_block const &)
{}
void apply(if_chain const &) void apply(if_chain const &)
{} {}
@ -253,16 +244,13 @@ namespace pslang::ast
void apply(return_statement const &) void apply(return_statement const &)
{} {}
void apply(field_definition & node)
{
resolve_types(scopes, *node.type);
node.inferred_type = get_type(*node.type);
}
void apply(struct_definition & node) void apply(struct_definition & node)
{ {
for (auto & field : node.fields) for (auto & field : node.fields)
apply(field); {
resolve_types(scopes, *field.type);
field.inferred_type = get_type(*field.type);
}
scopes.back().structs[node.name].node = &node; scopes.back().structs[node.name].node = &node;
} }
@ -751,21 +739,6 @@ namespace pslang::ast
}; };
} }
void apply(if_block const & node)
{
throw invalid_ast_error("if blocks cannot be present in the final AST", node.location);
}
void apply(else_if_block const & node)
{
throw invalid_ast_error("else if blocks cannot be present in the final AST", node.location);
}
void apply(else_block const & node)
{
throw invalid_ast_error("else blocks cannot be present in the final AST", node.location);
}
void apply(if_chain const & node) void apply(if_chain const & node)
{ {
for (auto const & block : node.blocks) for (auto const & block : node.blocks)
@ -856,9 +829,6 @@ namespace pslang::ast
} }
} }
void apply(field_definition const &)
{}
void apply(struct_definition const & node) void apply(struct_definition const & node)
{ {
// Already added to scope by populate_globals_visitor // Already added to scope by populate_globals_visitor

View file

@ -20,12 +20,6 @@ namespace pslang::ast
void apply(variable_declaration const &) {} void apply(variable_declaration const &) {}
void apply(if_block const &) {}
void apply(else_block const &) {}
void apply(else_if_block const &) {}
void apply(if_chain const & node) void apply(if_chain const & node)
{ {
for (auto const & block : node.blocks) for (auto const & block : node.blocks)
@ -53,8 +47,6 @@ namespace pslang::ast
void apply(return_statement const &) {} void apply(return_statement const &) {}
void apply(field_definition const &) {}
void apply(struct_definition const &) {} void apply(struct_definition const &) {}
}; };

View file

@ -81,21 +81,6 @@ namespace pslang::interpreter
context.frame_stack.back().variables[variable_declaration.name] = {.category = variable_declaration.category, .value = value}; context.frame_stack.back().variables[variable_declaration.name] = {.category = variable_declaration.category, .value = value};
} }
void exec_impl(context & context, ast::if_block const & if_block)
{
throw internal_error("if blocks cannot be present in the final AST", if_block.location);
}
void exec_impl(context & context, ast::else_block const & else_block)
{
throw internal_error("else blocks cannot be present in the final AST", else_block.location);
}
void exec_impl(context & context, ast::else_if_block const & else_if_block)
{
throw internal_error("else if blocks cannot be present in the final AST", else_if_block.location);
}
void exec_impl(context & context, ast::if_chain const & if_chain) void exec_impl(context & context, ast::if_chain const & if_chain)
{ {
for (auto const & block : if_chain.blocks) for (auto const & block : if_chain.blocks)
@ -225,11 +210,6 @@ namespace pslang::interpreter
throw internal_error("Cannot return outside of function scope", return_statement.location); throw internal_error("Cannot return outside of function scope", return_statement.location);
} }
void exec_impl(context & context, ast::field_definition const & field_definition)
{
throw internal_error("Field definitions cannot be present in the final AST outside of struct definitions", field_definition.location);
}
void exec_impl(context & context, ast::struct_definition const & struct_definition) void exec_impl(context & context, ast::struct_definition const & struct_definition)
{ {
auto & frame = context.frame_stack.back(); auto & frame = context.frame_stack.back();

View file

@ -58,12 +58,6 @@ namespace pslang::ir
void apply(ast::variable_declaration const &) {} void apply(ast::variable_declaration const &) {}
void apply(ast::if_block const &) {}
void apply(ast::else_block const &) {}
void apply(ast::else_if_block const &) {}
void apply(ast::if_chain const &) {} void apply(ast::if_chain const &) {}
void apply(ast::while_block const &) {} void apply(ast::while_block const &) {}
@ -80,8 +74,6 @@ namespace pslang::ir
void apply(ast::return_statement const &) {} void apply(ast::return_statement const &) {}
void apply(ast::field_definition const &) {}
void apply(ast::struct_definition const & node) void apply(ast::struct_definition const & node)
{ {
lcontext.scopes.back().structs[node.name] = &node; lcontext.scopes.back().structs[node.name] = &node;
@ -324,7 +316,7 @@ namespace pslang::ir
return last(); return last();
} }
// TODO: struct_definition, field_definition // TODO: struct_definition
// Statement list // Statement list
@ -379,12 +371,6 @@ namespace pslang::ir
void apply(ast::variable_declaration const &) {} void apply(ast::variable_declaration const &) {}
void apply(ast::if_block const &) {}
void apply(ast::else_block const &) {}
void apply(ast::else_if_block const &) {}
void apply(ast::if_chain const & node) void apply(ast::if_chain const & node)
{ {
for (auto const & block : node.blocks) for (auto const & block : node.blocks)
@ -420,8 +406,6 @@ namespace pslang::ir
void apply(ast::return_statement const &) {} void apply(ast::return_statement const &) {}
void apply(ast::field_definition const &) {}
void apply(ast::struct_definition const &) {} void apply(ast::struct_definition const &) {}
void apply(ast::statement_list const & list) void apply(ast::statement_list const & list)

View file

@ -235,15 +235,6 @@ namespace pslang::jit::aarch64
apply(*node.initializer); apply(*node.initializer);
} }
void apply(ast::if_block const &)
{}
void apply(ast::else_block const &)
{}
void apply(ast::else_if_block const &)
{}
void apply(ast::if_chain const & node) void apply(ast::if_chain const & node)
{ {
for (auto const & block : node.blocks) for (auto const & block : node.blocks)
@ -280,9 +271,6 @@ namespace pslang::jit::aarch64
apply(*node.value); apply(*node.value);
} }
void apply(ast::field_definition const &)
{}
void apply(ast::struct_definition const &) void apply(ast::struct_definition const &)
{} {}
@ -357,12 +345,6 @@ namespace pslang::jit::aarch64
void apply(ast::variable_declaration const &) {} void apply(ast::variable_declaration const &) {}
void apply(ast::if_block const &) {}
void apply(ast::else_block const &) {}
void apply(ast::else_if_block const &) {}
void apply(ast::if_chain const &) {} void apply(ast::if_chain const &) {}
void apply(ast::while_block const &) {} void apply(ast::while_block const &) {}
@ -379,8 +361,6 @@ namespace pslang::jit::aarch64
void apply(ast::return_statement const &) {} void apply(ast::return_statement const &) {}
void apply(ast::field_definition const &) {}
void apply(ast::struct_definition const & node) void apply(ast::struct_definition const & node)
{ {
lcontext.scopes.back().structs[node.name] = &node; lcontext.scopes.back().structs[node.name] = &node;
@ -1364,12 +1344,6 @@ namespace pslang::jit::aarch64
void apply(ast::variable_declaration const &) {} void apply(ast::variable_declaration const &) {}
void apply(ast::if_block const &) {}
void apply(ast::else_block const &) {}
void apply(ast::else_if_block const &) {}
void apply(ast::if_chain const & node) void apply(ast::if_chain const & node)
{ {
for (auto const & block : node.blocks) for (auto const & block : node.blocks)
@ -1392,8 +1366,6 @@ namespace pslang::jit::aarch64
void apply(ast::return_statement const &) {} void apply(ast::return_statement const &) {}
void apply(ast::field_definition const &) {}
void apply(ast::struct_definition const &) {} void apply(ast::struct_definition const &) {}
void apply(ast::statement_list const & node) void apply(ast::statement_list const & node)

View file

@ -10,7 +10,7 @@ namespace pslang::parser
struct indented_statement struct indented_statement
{ {
std::size_t indentation; std::size_t indentation;
ast::statement_ptr statement; ast::pre_statement_ptr statement;
}; };
struct indented_statement_list struct indented_statement_list

View file

@ -165,7 +165,7 @@ template <typename T>
%type <indented_statement_list> indented_statement_list %type <indented_statement_list> indented_statement_list
%type <indented_statement> statement_line %type <indented_statement> statement_line
%type <std::size_t> indentation %type <std::size_t> indentation
%type <ast::statement> statement %type <ast::pre_statement> statement
%type <std::vector<ast::function_declaration::argument>> function_declaration_argument_list %type <std::vector<ast::function_declaration::argument>> function_declaration_argument_list
%type <std::vector<ast::function_declaration::argument>> nonempty_function_declaration_argument_list %type <std::vector<ast::function_declaration::argument>> nonempty_function_declaration_argument_list
%type <ast::function_declaration::argument> function_declaration_single_argument %type <ast::function_declaration::argument> function_declaration_single_argument
@ -197,7 +197,7 @@ indented_statement_list
; ;
statement_line statement_line
: indentation statement optional_comment { $$ = indented_statement{$1, std::make_unique<ast::statement>($2)}; } : indentation statement optional_comment { $$ = indented_statement{$1, std::make_unique<ast::pre_statement>($2)}; }
; ;
empty_line empty_line

View file

@ -18,7 +18,7 @@ namespace pslang::parser
ast::location apply(ast::expression_ptr const & node) ast::location apply(ast::expression_ptr const & node)
{ {
return ast::get_location(node); return ast::get_location(*node);
} }
ast::location apply(ast::assignment const & node) ast::location apply(ast::assignment const & node)
@ -31,21 +31,6 @@ namespace pslang::parser
return node.location; return node.location;
} }
ast::location apply(ast::if_block const & node)
{
return node.location;
}
ast::location apply(ast::else_block const & node)
{
return node.location;
}
ast::location apply(ast::else_if_block const & node)
{
return node.location;
}
ast::location apply(ast::if_chain & node) ast::location apply(ast::if_chain & node)
{ {
bool first = true; bool first = true;
@ -81,16 +66,11 @@ namespace pslang::parser
return node.location; return node.location;
} }
ast::location apply(ast::field_definition const & node)
{
return node.location;
}
ast::location apply(ast::struct_definition & node) ast::location apply(ast::struct_definition & node)
{ {
node.location = node.prelude_location; node.location = node.prelude_location;
for (auto const & field : node.fields) for (auto const & field : node.fields)
node.location = ast::merge(node.location, apply(field)); node.location = ast::merge(node.location, field.location);
return node.location; return node.location;
} }
@ -200,13 +180,13 @@ namespace pslang::parser
{ {
while_block->statements = std::make_unique<ast::statement_list>(); while_block->statements = std::make_unique<ast::statement_list>();
list = while_block->statements.get(); list = while_block->statements.get();
current_statement_list(location)->statements.push_back(std::move(statement.statement)); current_statement_list(location)->statements.push_back(std::make_unique<ast::statement>(std::move(*while_block)));
} }
else if (auto function_definition = std::get_if<ast::function_definition>(statement.statement.get())) else if (auto function_definition = std::get_if<ast::function_definition>(statement.statement.get()))
{ {
function_definition->statements = std::make_unique<ast::statement_list>(); function_definition->statements = std::make_unique<ast::statement_list>();
list = function_definition->statements.get(); list = function_definition->statements.get();
current_statement_list(location)->statements.push_back(std::move(statement.statement)); current_statement_list(location)->statements.push_back(std::make_unique<ast::statement>(std::move(*function_definition)));
is_function_definition = true; is_function_definition = true;
} }
else if (auto field_definition = std::get_if<ast::field_definition>(statement.statement.get())) else if (auto field_definition = std::get_if<ast::field_definition>(statement.statement.get()))
@ -217,9 +197,9 @@ namespace pslang::parser
throw parse_error("Duplicate field definition: \"" + field.name + "\"", field.location); throw parse_error("Duplicate field definition: \"" + field.name + "\"", field.location);
current->fields.push_back(*field_definition); current->fields.push_back(*field_definition);
} }
else if (std::get_if<ast::struct_definition>(statement.statement.get())) else if (auto struct_definition = std::get_if<ast::struct_definition>(statement.statement.get()))
{ {
current_statement_list(location)->statements.push_back(std::move(statement.statement)); current_statement_list(location)->statements.push_back(std::make_unique<ast::statement>(std::move(*struct_definition)));
stack.push_back(std::get_if<ast::struct_definition>(current_statement_list(location)->statements.back().get())); stack.push_back(std::get_if<ast::struct_definition>(current_statement_list(location)->statements.back().get()));
++current_indent; ++current_indent;
if (in_function_scope > 0) if (in_function_scope > 0)
@ -230,11 +210,27 @@ namespace pslang::parser
if (in_function_scope == 0) if (in_function_scope == 0)
throw parse_error("Return statement outside of function scope", return_statement->location); throw parse_error("Return statement outside of function scope", return_statement->location);
return_statement->level = stack.size() - in_function_scope; return_statement->level = stack.size() - in_function_scope;
current_statement_list(location)->statements.push_back(std::move(statement.statement)); current_statement_list(location)->statements.push_back(std::make_unique<ast::statement>(std::move(*return_statement)));
}
else if (auto expression_ptr = std::get_if<ast::expression_ptr>(statement.statement.get()))
{
current_statement_list(location)->statements.push_back(std::make_unique<ast::statement>(std::move(*expression_ptr)));
}
else if (auto assignment = std::get_if<ast::assignment>(statement.statement.get()))
{
current_statement_list(location)->statements.push_back(std::make_unique<ast::statement>(std::move(*assignment)));
}
else if (auto variable_declaration = std::get_if<ast::variable_declaration>(statement.statement.get()))
{
current_statement_list(location)->statements.push_back(std::make_unique<ast::statement>(std::move(*variable_declaration)));
}
else if (auto foreign_function_declaration = std::get_if<ast::foreign_function_declaration>(statement.statement.get()))
{
current_statement_list(location)->statements.push_back(std::make_unique<ast::statement>(std::move(*foreign_function_declaration)));
} }
else else
{ {
current_statement_list(location)->statements.push_back(std::move(statement.statement)); throw ast::invalid_ast_error(std::format("Unknown pre-statement \"{}\"", std::visit([](auto const & statement){ return typeid(statement).name(); }, *statement.statement)), location);
} }
if (list) if (list)

View file

@ -43,4 +43,3 @@ General backlog:
* Replace std::runtime_error with appropriate custom exception types * Replace std::runtime_error with appropriate custom exception types
* Replace std::ostringstream with std::format (need support for std::format in type/ast printing) * Replace std::ostringstream with std::format (need support for std::format in type/ast printing)
* TEST COVERAGE!!! * TEST COVERAGE!!!
* Separate actual AST from pre-AST (the one before indentation & scoping is resolved)