// pslang/libs/parser/source/finalize.cpp (C++)

#include <pslang/parser/indented_statement.hpp>
#include <pslang/parser/error.hpp>
#include <pslang/ast/statement.hpp>
#include <pslang/ast/statement_visitor.hpp>
#include <vector>
namespace pslang::parser
{
namespace
{
/// Fills in the `location` of compound statements, bottom-up.
///
/// Simple statements (expressions, assignments, declarations, break/continue/return) already
/// carry a location and the visitor only reports it. For compound statements (if-chains,
/// while loops, function, foreign-function and struct definitions) it stores the covering
/// span in the node's `location` member and returns it, so the enclosing node can extend
/// its own span.
struct fill_location_visitor
    : ast::statement_visitor<fill_location_visitor>
{
    using statement_visitor::apply;

    ast::location apply(ast::expression_ptr const & node)
    {
        return ast::get_location(*node);
    }

    ast::location apply(ast::assignment const & node)
    {
        return node.location;
    }

    ast::location apply(ast::variable_declaration const & node)
    {
        return node.location;
    }

    /// An if-chain spans all of its branches; each branch spans its header and its body.
    ast::location apply(ast::if_chain & node)
    {
        bool first = true;
        for (auto & block : node.blocks)
        {
            // Include the `if` / `else if` / `else` header line, exactly as while loops and
            // function definitions include their prelude.
            block.location = header_and_body(block.prelude_location, *block.statements);
            if (first)
                node.location = block.location;
            else
                node.location = ast::merge(node.location, block.location);
            first = false;
        }
        return node.location;
    }

    ast::location apply(ast::while_block & node)
    {
        return node.location = header_and_body(node.prelude_location, *node.statements);
    }

    ast::location apply(ast::break_statement & node)
    {
        return node.location;
    }

    ast::location apply(ast::continue_statement & node)
    {
        return node.location;
    }

    ast::location apply(ast::function_definition & node)
    {
        return node.location = header_and_body(node.prelude_location, *node.statements);
    }

    /// A foreign function declaration has no body, so it spans just its prelude.
    ast::location apply(ast::foreign_function_declaration & node)
    {
        return node.location = node.prelude_location;
    }

    ast::location apply(ast::return_statement const & node)
    {
        return node.location;
    }

    /// A struct definition spans its header and all of its field definitions.
    ast::location apply(ast::struct_definition & node)
    {
        node.location = node.prelude_location;
        for (auto const & field : node.fields)
            node.location = ast::merge(node.location, field.location);
        return node.location;
    }

    /// Fills in locations for every statement of `list` and returns the span covering them.
    /// An empty list yields a value-initialized location (callers that have a header use
    /// `header_and_body` instead, which never merges with it).
    ast::location apply(ast::statement_list & list)
    {
        ast::location result{};
        bool first = true;
        for (auto & statement : list.statements)
        {
            auto statement_location = apply(*statement);
            if (first)
                result = statement_location;
            else
                result = ast::merge(result, statement_location);
            first = false;
        }
        return result;
    }

private:
    /// Span from a block header to the end of its body, filling in the body's nested
    /// locations on the way. An empty body (the finalize pass does not reject one) contributes
    /// nothing, so the span is just the header rather than a merge with a meaningless default.
    ast::location header_and_body(ast::location const & prelude, ast::statement_list & body)
    {
        if (body.statements.empty())
            return prelude;
        return ast::merge(prelude, apply(body));
    }
};
}
/// Assembles a flat, indentation-tagged list of pre-statements into a nested AST.
///
/// A block opener (`if`, `else if`, `else`, `while`, function or struct definition) opens a
/// new indentation level one deeper than its own; a smaller indentation closes the innermost
/// open level(s). While nesting, context-dependent rules are checked:
///  - `else` / `else if` must directly follow an if-chain whose last branch has a condition;
///  - `return` must be inside a function (and is bound to the innermost one);
///  - `break` / `continue` must be inside a loop belonging to the same function;
///  - field definitions may only appear directly inside a struct, without duplicates,
///    and a struct may contain nothing else.
/// Finally the locations of compound statements are computed (`fill_location_visitor`).
///
/// @param statements  pre-statements in source order, each with its indentation level
/// @return the root statement list
/// @throws parse_error on a structural error in the input
/// @throws ast::invalid_ast_error on an unrecognized pre-statement kind
/// @throws internal_error if the block stack is unexpectedly empty
ast::statement_list_ptr finalize(indented_statement_list statements)
{
    ast::statement_list_ptr result = std::make_unique<ast::statement_list>();

    // One entry per open indentation level; the innermost block is at the back.
    using stack_entry = std::variant<ast::statement_list *, ast::struct_definition *>;
    std::vector<stack_entry> stack;
    stack.push_back(result.get());
    std::size_t current_indent = 0;

    // Enclosing function definitions, innermost last (used to bind `return` statements).
    std::vector<ast::function_definition *> function_stack;
    // Enclosing loop bodies, innermost last. A nullptr entry marks a function boundary, so a
    // loop *around* a function definition does not make `break`/`continue` legal inside it.
    std::vector<ast::statement_list *> loop_stack;

    auto inside_loop = [&]
    {
        return !loop_stack.empty() && loop_stack.back() != nullptr;
    };

    auto current_statement_list = [&](ast::location const & location) -> ast::statement_list *
    {
        if (stack.empty())
            throw internal_error("Empty finalization stack");
        if (auto list = std::get_if<ast::statement_list *>(&stack.back()))
            return *list;
        throw parse_error("Unexpected statement inside struct definition", location);
    };

    auto current_struct_definition = [&](ast::location const & location) -> ast::struct_definition *
    {
        if (stack.empty())
            throw internal_error("Empty finalization stack");
        if (auto definition = std::get_if<ast::struct_definition *>(&stack.back()))
            return *definition;
        throw parse_error("Unexpected statement outside struct definition", location);
    };

    // The if-chain continued by an `else` / `else if` at `location`: it must be the last
    // statement of the current list, and its last branch must still have a condition.
    auto preceding_if_chain = [&](ast::location const & location, char const * message) -> ast::if_chain *
    {
        auto & siblings = current_statement_list(location)->statements;
        if (siblings.empty())
            throw parse_error(message, location);
        auto chain = std::get_if<ast::if_chain>(siblings.back().get());
        if (!chain || chain->blocks.empty() || !chain->blocks.back().condition)
            throw parse_error(message, location);
        return chain;
    };

    // Closes the innermost open level, leaving function / loop context if it was their body.
    auto close_block = [&]
    {
        if (auto list = std::get_if<ast::statement_list *>(&stack.back()))
        {
            if (!function_stack.empty() && function_stack.back()->statements.get() == *list)
            {
                function_stack.pop_back();
                // Loops nested in the function were closed already, so this is its marker.
                loop_stack.pop_back();
            }
            else if (!loop_stack.empty() && loop_stack.back() == *list)
                loop_stack.pop_back();
        }
        stack.pop_back();
        --current_indent;
    };

    for (auto & statement : statements.statements)
    {
        auto * pre_statement = statement.statement.get();
        auto location = ast::get_location(*pre_statement);
        if (statement.indentation > current_indent)
            throw parse_error("Unexpected indent", location);
        while (statement.indentation < current_indent)
        {
            if (stack.empty())
                throw ast::invalid_ast_error("Unexpected empty indent stack", location);
            close_block();
        }
        // Now statement.indentation == current_indent.

        // Body opened by this statement, if any; it becomes the next indentation level.
        ast::statement_list * body = nullptr;
        if (auto if_block = std::get_if<ast::if_block>(pre_statement))
        {
            ast::if_chain chain;
            chain.location = if_block->location;
            chain.blocks.push_back({.condition = std::move(if_block->condition), .statements = std::make_unique<ast::statement_list>(), .prelude_location = if_block->location});
            // The body list is heap-allocated, so the pointer survives moving the chain.
            body = chain.blocks.back().statements.get();
            current_statement_list(location)->statements.push_back(std::make_unique<ast::statement>(std::move(chain)));
        }
        else if (auto else_block = std::get_if<ast::else_block>(pre_statement))
        {
            auto chain = preceding_if_chain(location, "Unexpected else block");
            chain->blocks.push_back({.condition = nullptr, .statements = std::make_unique<ast::statement_list>(), .prelude_location = else_block->location});
            body = chain->blocks.back().statements.get();
        }
        else if (auto else_if_block = std::get_if<ast::else_if_block>(pre_statement))
        {
            auto chain = preceding_if_chain(location, "Unexpected else if block");
            chain->blocks.push_back({.condition = std::move(else_if_block->condition), .statements = std::make_unique<ast::statement_list>(), .prelude_location = else_if_block->location});
            body = chain->blocks.back().statements.get();
        }
        else if (auto while_block = std::get_if<ast::while_block>(pre_statement))
        {
            while_block->statements = std::make_unique<ast::statement_list>();
            body = while_block->statements.get();
            current_statement_list(location)->statements.push_back(std::make_unique<ast::statement>(std::move(*while_block)));
            loop_stack.push_back(body);
        }
        else if (auto function_definition = std::get_if<ast::function_definition>(pre_statement))
        {
            function_definition->statements = std::make_unique<ast::statement_list>();
            auto node = std::make_unique<ast::statement>(std::move(*function_definition));
            // Points into the heap-allocated statement, so it stays valid after the push_back.
            auto definition = std::get_if<ast::function_definition>(node.get());
            current_statement_list(location)->statements.push_back(std::move(node));
            body = definition->statements.get();
            function_stack.push_back(definition);
            // Function boundary: loops enclosing the definition do not cover its body.
            loop_stack.push_back(nullptr);
        }
        else if (auto field_definition = std::get_if<ast::field_definition>(pre_statement))
        {
            auto current = current_struct_definition(location);
            for (auto const & field : current->fields)
                if (field.name == field_definition->name)
                    throw parse_error("Duplicate field definition: \"" + field.name + "\"", field.location);
            current->fields.push_back(std::move(*field_definition));
        }
        else if (auto struct_definition = std::get_if<ast::struct_definition>(pre_statement))
        {
            auto & siblings = current_statement_list(location)->statements;
            siblings.push_back(std::make_unique<ast::statement>(std::move(*struct_definition)));
            // A struct level accepts only field definitions, so it is not a statement list.
            stack.push_back(std::get_if<ast::struct_definition>(siblings.back().get()));
            ++current_indent;
        }
        else if (auto return_statement = std::get_if<ast::return_statement>(pre_statement))
        {
            if (function_stack.empty())
                throw parse_error("Return statement outside of function scope", return_statement->location);
            return_statement->node = function_stack.back();
            current_statement_list(location)->statements.push_back(std::make_unique<ast::statement>(std::move(*return_statement)));
        }
        else if (auto expression_ptr = std::get_if<ast::expression_ptr>(pre_statement))
        {
            current_statement_list(location)->statements.push_back(std::make_unique<ast::statement>(std::move(*expression_ptr)));
        }
        else if (auto assignment = std::get_if<ast::assignment>(pre_statement))
        {
            current_statement_list(location)->statements.push_back(std::make_unique<ast::statement>(std::move(*assignment)));
        }
        else if (auto variable_declaration = std::get_if<ast::variable_declaration>(pre_statement))
        {
            current_statement_list(location)->statements.push_back(std::make_unique<ast::statement>(std::move(*variable_declaration)));
        }
        else if (auto foreign_function_declaration = std::get_if<ast::foreign_function_declaration>(pre_statement))
        {
            current_statement_list(location)->statements.push_back(std::make_unique<ast::statement>(std::move(*foreign_function_declaration)));
        }
        else if (auto break_statement = std::get_if<ast::break_statement>(pre_statement))
        {
            if (!inside_loop())
                throw parse_error("Break without an enclosing loop", break_statement->location);
            current_statement_list(location)->statements.push_back(std::make_unique<ast::statement>(std::move(*break_statement)));
        }
        else if (auto continue_statement = std::get_if<ast::continue_statement>(pre_statement))
        {
            if (!inside_loop())
                throw parse_error("Continue without an enclosing loop", continue_statement->location);
            current_statement_list(location)->statements.push_back(std::make_unique<ast::statement>(std::move(*continue_statement)));
        }
        else
        {
            throw ast::invalid_ast_error(std::format("Unknown pre-statement \"{}\"", std::visit([](auto const & value){ return typeid(value).name(); }, *pre_statement)), location);
        }

        if (body)
        {
            stack.push_back(body);
            ++current_indent;
        }
    }

    fill_location_visitor{}.apply(*result);
    return result;
}
}