Add scope levels to identifier AST nodes & implement identifier resolution

This commit is contained in:
Nikita Lisitsa 2025-12-19 17:42:40 +03:00
parent 8661ab6ace
commit 437123f6f4
12 changed files with 426 additions and 130 deletions

View file

@ -2,6 +2,7 @@
#include <pslang/parser/error.hpp>
#include <pslang/interpreter/exec.hpp>
#include <pslang/ast/statement.hpp>
#include <pslang/ast/preprocess.hpp>
#include <pslang/ast/print.hpp>
#include <filesystem>
@ -91,7 +92,9 @@ int main(int argc, char ** argv)
try
{
filenames.push_back(argv[arg]);
parsed.push_back(parser::parse(filenames.back()));
auto ast = parser::parse(filenames.back());
ast::resolve_identifiers(ast);
parsed.push_back(std::move(ast));
}
catch (pslang::parser::parse_error const & error)
{

3
backlog.txt Normal file
View file

@ -0,0 +1,3 @@
* Mutually recursive functions
* Refactor all tree visitors to prevent infinite recursion (maybe add dedicated tree traversal functions & visitors?)
* Type identifier location + move types to ast library / split type values and type ast nodes

View file

@ -13,20 +13,17 @@ namespace pslang::ast
// Statement node for an `if <condition>:` header line.
// The identifier resolver rejects this node with invalid_ast_error ("if blocks cannot be
// present in the final AST"), so it is expected to be replaced before resolution runs.
struct if_block
{
expression_ptr condition;
statement_list_ptr statements;
// Source position of the statement, used when reporting errors about it.
ast::location location;
};
// Statement node for an `else:` header line.
// The identifier resolver rejects this node with invalid_ast_error ("else blocks cannot be
// present in the final AST"), so it is expected to be replaced before resolution runs.
struct else_block
{
statement_list_ptr statements;
// Source position of the statement, used when reporting errors about it.
ast::location location;
};
// Statement node for an `else if <condition>:` header line.
// The identifier resolver rejects this node with invalid_ast_error ("else if blocks cannot be
// present in the final AST"), so it is expected to be replaced before resolution runs.
struct else_if_block
{
expression_ptr condition;
statement_list_ptr statements;
// Source position of the statement, used when reporting errors about it.
ast::location location;
};

View file

@ -0,0 +1,66 @@
#pragma once
#include <pslang/ast/location.hpp>
#include <exception>
#include <string>
namespace pslang::ast
{
struct parse_error
: std::exception
{
parse_error(std::string message, location location)
: message_(std::move(message))
, filename_(location.filename)
, location_(location)
{
location_.filename = filename_;
}
char const * what() const noexcept
{
return message_.c_str();
}
ast::location location() const noexcept
{
return location_;
}
private:
std::string message_;
std::string filename_;
ast::location location_;
};
struct invalid_ast_error
: std::exception
{
invalid_ast_error(std::string message, location location)
: message_(std::move(message))
, filename_(location.filename)
, location_(location)
{
location_.filename = filename_;
}
char const * what() const noexcept
{
return message_.c_str();
}
ast::location location() const noexcept
{
return location_;
}
private:
std::string message_;
std::string filename_;
ast::location location_;
};
}

View file

@ -11,6 +11,7 @@ namespace pslang::ast
{
std::string name;
ast::location location;
std::size_t level = 0;
};
}

View file

@ -0,0 +1,10 @@
#pragma once
#include <pslang/ast/statement_fwd.hpp>
namespace pslang::ast
{
// Walks the statement list and stores, in every identifier node (expression and
// type identifiers alike), the index of the scope in which its name was found
// (`level`; the outermost scope is 0).
// Throws parse_error when a name cannot be found or is declared twice in the same
// scope, and invalid_ast_error when a bare if / else if / else block is encountered.
void resolve_identifiers(statement_list_ptr & statements);
}

View file

@ -185,14 +185,12 @@ namespace pslang::ast
put_indent(out, options);
out << "if\n";
print(out, node.condition, child(options));
print(out, node.statements, child(options));
}
// Prints an else_block: the current indentation followed by the `else` keyword
// on its own line.
void print_impl(std::ostream & out, else_block const & node, print_options const & options)
{
put_indent(out, options);
out << "else\n";
print(out, node.statements, child(options));
}
void print_impl(std::ostream & out, else_if_block const & node, print_options const & options)
@ -200,7 +198,6 @@ namespace pslang::ast
put_indent(out, options);
out << "else if\n";
print(out, node.condition, child(options));
print(out, node.statements, child(options));
}
void print_impl(std::ostream & out, if_chain const & node, print_options const & options)

View file

@ -0,0 +1,295 @@
#include <pslang/ast/preprocess.hpp>
#include <pslang/ast/statement.hpp>
#include <pslang/ast/error.hpp>
#include <unordered_set>
#include <vector>
namespace pslang::ast
{
namespace
{
// Names declared directly in one lexical scope, grouped by the kind of entity.
struct scope
{
    std::unordered_set<std::string> functions;
    std::unordered_set<std::string> structs;
    std::unordered_set<std::string> variables;

    // True when `name` is a function or a struct of this scope. These are the
    // kinds that remain visible when the lookup has crossed a function boundary;
    // variables are deliberately excluded.
    bool contains_transitive(std::string const & name) const
    {
        if (functions.find(name) != functions.end())
            return true;
        return structs.find(name) != structs.end();
    }

    // True when `name` denotes a type (currently: a struct) of this scope.
    bool contains_type(std::string const & name) const
    {
        return structs.find(name) != structs.end();
    }

    // True when `name` is declared in this scope as any kind of entity.
    bool contains(std::string const & name) const
    {
        if (contains_transitive(name))
            return true;
        return variables.find(name) != variables.end();
    }

    // Lookup that ignores variables once a function boundary has been crossed.
    bool contains(std::string const & name, bool crossed_function_scope) const
    {
        return crossed_function_scope ? contains_transitive(name) : contains(name);
    }

    // Set for the scope that holds a function's arguments and body.
    bool is_function_scope = false;
    // Set for the outermost scope.
    bool is_global_scope = false;
};
// Mutable state of one resolution pass.
struct context
{
// Stack of currently open scopes: index 0 is the outermost scope, back() is the
// innermost one. An identifier's `level` is an index into this vector.
std::vector<scope> scopes;
};
// Forward declarations of the dispatchers, needed because the per-node overloads
// below recurse into them before their definitions appear.
void resolve_identifiers(context & context, type::type & type);
void resolve_identifiers(context & context, expression & expression);
void resolve_identifiers(context & context, statement_list const & statements);
// Unit types contain no identifiers: nothing to resolve.
void resolve_identifiers_impl(context &, type::unit_type const &)
{}
// Primitive types contain no identifiers: nothing to resolve.
void resolve_identifiers_impl(context &, type::primitive_type const &)
{}
// Array type: only the element type can mention identifiers.
void resolve_identifiers_impl(context & context, type::array_type const & array_type)
{
    auto & element = *array_type.element_type;
    resolve_identifiers(context, element);
}
// Function type: resolve every argument type in order, then the result type.
void resolve_identifiers_impl(context & context, type::function_type const & function_type)
{
    for (auto const & argument_type : function_type.arguments)
    {
        resolve_identifiers(context, *argument_type);
    }

    auto & result_type = *function_type.result;
    resolve_identifiers(context, result_type);
}
// Type identifier: search the open scopes from innermost to outermost for a type
// with this name and record the index of the first scope that declares it.
// Throws parse_error (with an empty location) when no scope declares the type.
void resolve_identifiers_impl(context & context, type::identifier & identifier)
{
    for (std::size_t index = context.scopes.size(); index-- > 0;)
    {
        if (!context.scopes[index].contains_type(identifier.name))
            continue;

        identifier.level = index;
        return;
    }

    // TODO location
    throw parse_error("Identifier \"" + identifier.name + "\" not found", {});
}
// Dispatches on the concrete alternative held by the type variant.
void resolve_identifiers(context & context, type::type & type)
{
    auto const dispatch = [&context](auto & alternative)
    {
        resolve_identifiers_impl(context, alternative);
    };
    std::visit(dispatch, type);
}
// Literals contain no identifiers: nothing to resolve.
void resolve_identifiers_impl(context &, literal const &)
{}
// Expression identifier: search the open scopes from innermost to outermost and
// record the index of the first scope that declares the name.
// After the search has left a function scope, variables of the remaining scopes
// are ignored (only functions and structs match) - except in the global scope,
// whose variables stay visible.
// Throws parse_error at the identifier's location when the name is not found.
void resolve_identifiers_impl(context & context, identifier & identifier)
{
    bool left_function = false;
    for (std::size_t index = context.scopes.size(); index-- > 0;)
    {
        scope const & candidate = context.scopes[index];
        bool const transitive_only = left_function && !candidate.is_global_scope;

        if (candidate.contains(identifier.name, transitive_only))
        {
            identifier.level = index;
            return;
        }

        // The function scope itself is searched in full; the restriction only
        // applies to the scopes enclosing it.
        if (candidate.is_function_scope)
            left_function = true;
    }

    throw parse_error("Identifier \"" + identifier.name + "\" not found", identifier.location);
}
// Unary operation: resolve its single operand.
void resolve_identifiers_impl(context & context, unary_operation const & unary_operation)
{
    auto & operand = *unary_operation.arg1;
    resolve_identifiers(context, operand);
}
// Binary operation: resolve the first operand, then the second.
void resolve_identifiers_impl(context & context, binary_operation const & binary_operation)
{
    auto & first = *binary_operation.arg1;
    resolve_identifiers(context, first);

    auto & second = *binary_operation.arg2;
    resolve_identifiers(context, second);
}
// Cast: resolve the expression being cast, then the target type.
void resolve_identifiers_impl(context & context, cast_operation const & cast_operation)
{
    auto & source = *cast_operation.expression;
    resolve_identifiers(context, source);

    auto & target = *cast_operation.type;
    resolve_identifiers(context, target);
}
// Function call: resolve the called expression first, then each argument in order.
void resolve_identifiers_impl(context & context, function_call const & function_call)
{
    auto & callee = *function_call.function;
    resolve_identifiers(context, callee);

    for (auto const & argument_ptr : function_call.arguments)
    {
        resolve_identifiers(context, *argument_ptr);
    }
}
// Array expression: resolve every element in order.
void resolve_identifiers_impl(context & context, array const & array)
{
    for (auto const & element_ptr : array.elements)
    {
        resolve_identifiers(context, *element_ptr);
    }
}
// Array access: resolve the array expression, then the index expression.
void resolve_identifiers_impl(context & context, array_access const & array_access)
{
    auto & indexed = *array_access.array;
    resolve_identifiers(context, indexed);

    auto & position = *array_access.index;
    resolve_identifiers(context, position);
}
// Field access: only the object expression is resolved here; nothing else of the
// node is looked up in the scopes.
void resolve_identifiers_impl(context & context, field_access const & field_access)
{
    auto & object = *field_access.object;
    resolve_identifiers(context, object);
}
// Dispatches on the concrete alternative held by the expression variant.
void resolve_identifiers(context & context, expression & expression)
{
    auto const dispatch = [&context](auto & alternative)
    {
        resolve_identifiers_impl(context, alternative);
    };
    std::visit(dispatch, expression);
}
// Expression statement: resolve the expression it points to.
void resolve_identifiers_impl(context & context, expression_ptr const & expression_ptr)
{
    auto & pointee = *expression_ptr;
    resolve_identifiers(context, pointee);
}
// Assignment: resolve the left-hand side, then the right-hand side.
void resolve_identifiers_impl(context & context, assignment const & assignment)
{
    auto & target = *assignment.lhs;
    resolve_identifiers(context, target);

    auto & value = *assignment.rhs;
    resolve_identifiers(context, value);
}
// Variable declaration: reject a name already declared in the innermost scope,
// resolve the optional type annotation and the initializer, then declare the
// variable. The name is registered last, so the initializer cannot refer to the
// variable being declared.
void resolve_identifiers_impl(context & context, variable_declaration const & variable_declaration)
{
    auto const & name = variable_declaration.name;

    if (context.scopes.back().contains(name))
    {
        throw parse_error("Identifier \"" + name + "\" is already defined at this scope", variable_declaration.location);
    }

    if (variable_declaration.type)
    {
        resolve_identifiers(context, *variable_declaration.type);
    }

    resolve_identifiers(context, *variable_declaration.initializer);

    context.scopes.back().variables.insert(name);
}
// Bare if / else if / else nodes are not accepted by the resolver (conditionals are
// handled through if_chain below); encountering one is reported as invalid_ast_error
// at the node's location.
void resolve_identifiers_impl(context &, if_block const & if_block)
{
throw invalid_ast_error("if blocks cannot be present in the final AST", if_block.location);
}
// See the note on if_block above: same rejection for `else if`.
void resolve_identifiers_impl(context &, else_if_block const & else_if_block)
{
throw invalid_ast_error("else if blocks cannot be present in the final AST", else_if_block.location);
}
// See the note on if_block above: same rejection for `else`.
void resolve_identifiers_impl(context &, else_block const & else_block)
{
throw invalid_ast_error("else blocks cannot be present in the final AST", else_block.location);
}
// If chain: for every branch, resolve its condition (if it has one) in the
// enclosing scope, then resolve the branch body inside a fresh scope of its own.
void resolve_identifiers_impl(context & context, if_chain const & if_chain)
{
    for (auto const & branch : if_chain.blocks)
    {
        if (branch.condition)
        {
            auto & condition = *branch.condition;
            resolve_identifiers(context, condition);
        }

        context.scopes.push_back(scope{});
        resolve_identifiers(context, *branch.statements);
        context.scopes.pop_back();
    }
}
// While loop: the condition is resolved in the enclosing scope, the body inside a
// fresh scope of its own.
void resolve_identifiers_impl(context & context, while_block const & while_block)
{
    auto & condition = *while_block.condition;
    resolve_identifiers(context, condition);

    context.scopes.push_back(scope{});
    resolve_identifiers(context, *while_block.statements);
    context.scopes.pop_back();
}
// Function definition.
// 1. Rejects a name already declared in the innermost scope.
// 2. Checks argument names for duplicates and resolves argument types and the
//    return type in the enclosing scope.
// 3. Declares the function in the enclosing scope *before* resolving the body, so
//    the body can refer to the function itself.
// 4. Resolves the body in a new function scope whose variables are the arguments.
void resolve_identifiers_impl(context & context, function_definition const & function_definition)
{
    auto const & function_name = function_definition.name;

    if (context.scopes.back().contains(function_name))
    {
        throw parse_error("Identifier \"" + function_name + "\" is already defined at this scope", function_definition.location);
    }

    std::unordered_set<std::string> seen_arguments;
    for (auto const & argument : function_definition.arguments)
    {
        bool const is_new = seen_arguments.insert(argument.name).second;
        if (!is_new)
        {
            throw parse_error("Duplicate argument name \"" + argument.name + "\" in function \"" + function_name + "\"", argument.location);
        }
        resolve_identifiers(context, *argument.type);
    }

    resolve_identifiers(context, *function_definition.return_type);

    context.scopes.back().functions.insert(function_name);

    // Build the body scope as a local and move it in, so no reference into the
    // scope vector is held while the vector may grow during the recursion.
    scope body_scope;
    body_scope.is_function_scope = true;
    body_scope.variables = std::move(seen_arguments);
    context.scopes.push_back(std::move(body_scope));

    resolve_identifiers(context, *function_definition.statements);

    context.scopes.pop_back();
}
// Return statement: resolve the returned expression when there is one.
void resolve_identifiers_impl(context & context, return_statement const & return_statement)
{
    if (!return_statement.value)
        return;

    resolve_identifiers(context, *return_statement.value);
}
// Struct field: only its type can mention identifiers.
void resolve_identifiers_impl(context & context, field_definition const & field_definition)
{
    auto & field_type = *field_definition.type;
    resolve_identifiers(context, field_type);
}
// Struct definition: reject a name already declared in the innermost scope,
// resolve every field, then declare the struct. The name is registered after the
// fields are processed, so field types cannot refer to the struct being defined.
void resolve_identifiers_impl(context & context, struct_definition const & struct_definition)
{
    auto const & struct_name = struct_definition.name;

    if (context.scopes.back().contains(struct_name))
    {
        throw parse_error("Identifier \"" + struct_name + "\" is already defined at this scope", struct_definition.location);
    }

    for (auto const & member : struct_definition.fields)
    {
        resolve_identifiers_impl(context, member);
    }

    context.scopes.back().structs.insert(struct_name);
}
// Dispatches on the concrete alternative held by the statement variant.
void resolve_identifiers(context & context, statement const & statement)
{
    auto const dispatch = [&context](auto const & alternative)
    {
        resolve_identifiers_impl(context, alternative);
    };
    std::visit(dispatch, statement);
}
// Resolves the statements of a list one after another, in source order, so that
// declarations made by earlier statements are visible to later ones.
void resolve_identifiers(context & context, statement_list const & statements)
{
    for (auto const & statement_ptr : statements.statements)
        resolve_identifiers(context, *statement_ptr);
}
}
// Entry point of identifier resolution.
// Opens a single global scope (level 0) and resolves every statement of the list,
// storing the resolved scope level in each identifier node.
// A null list contains nothing to resolve and is accepted as a no-op instead of
// being dereferenced.
// Throws parse_error for unknown or duplicate names and invalid_ast_error for
// bare if / else if / else blocks.
void resolve_identifiers(statement_list_ptr & statements)
{
    if (!statements)
        return;

    context context;
    context.scopes.emplace_back().is_global_scope = true;
    resolve_identifiers(context, *statements);
}
}

View file

@ -1,5 +1,6 @@
#pragma once
#include <pslang/ast/error.hpp>
#include <pslang/ast/location.hpp>
#include <exception>
@ -8,32 +9,7 @@
namespace pslang::parser
{
struct parse_error
: std::exception
{
parse_error(std::string message, ast::location location)
: message_(std::move(message))
, filename_(location.filename)
, location_(location)
{
location_.filename = filename_;
}
char const * what() const noexcept
{
return message_.c_str();
}
ast::location location() const noexcept
{
return location_;
}
private:
std::string message_;
std::string filename_;
ast::location location_;
};
using parse_error = ast::parse_error;
struct internal_error
: std::exception

View file

@ -148,6 +148,7 @@ template <typename T>
%type <ast::statement> statement
%type <std::vector<ast::function_definition::argument>> function_definition_argument_list
%type <std::vector<ast::function_definition::argument>> nonempty_function_definition_argument_list
%type <ast::function_definition::argument> function_definition_single_argument
%type <type::type_ptr> function_return_type
%type <ast::variable_declaration> variable_declaration
%type <ast::value_category> variable_keyword
@ -183,9 +184,9 @@ statement
: expression { $$ = std::make_unique<ast::expression>($1); }
| expression assignment expression { $$ = ast::assignment{ std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3), @$ }; }
| variable_declaration { $$ = $1; }
| if expression colon { $$ = ast::if_block{std::make_unique<ast::expression>($2), {}, @$}; }
| else colon { $$ = ast::else_block{{}, @$}; }
| else if expression colon { $$ = ast::else_if_block{std::make_unique<ast::expression>($3), {}, @$}; }
| if expression colon { $$ = ast::if_block{std::make_unique<ast::expression>($2), @$}; }
| else colon { $$ = ast::else_block{@$}; }
| else if expression colon { $$ = ast::else_if_block{std::make_unique<ast::expression>($3), @$}; }
| while expression colon { $$ = ast::while_block{std::make_unique<ast::expression>($2), {}, @$}; }
| func name lparen function_definition_argument_list rparen function_return_type colon { $$ = ast::function_definition{$2, $4, $6, {}, @$}; }
| return expression { $$ = ast::return_statement{std::make_unique<ast::expression>($2), @$}; }
@ -200,8 +201,12 @@ function_definition_argument_list
;
nonempty_function_definition_argument_list
: name colon type_expression { std::vector<ast::function_definition::argument> tmp; tmp.push_back({.name = $1, .type = std::make_unique<type::type>($3), @$}); $$ = std::move(tmp); }
| nonempty_function_definition_argument_list comma name colon type_expression { auto tmp = $1; tmp.push_back({.name = $3, .type = std::make_unique<type::type>($5)}); $$ = std::move(tmp); }
: function_definition_single_argument { std::vector<ast::function_definition::argument> tmp; tmp.push_back($1); $$ = std::move(tmp); }
| nonempty_function_definition_argument_list comma function_definition_single_argument { auto tmp = $1; tmp.push_back($3); $$ = std::move(tmp); }
;
/* One `name : type` function parameter. Having a rule of its own gives every
   argument its own source location (@$), whichever position it has in the list. */
function_definition_single_argument
: name colon type_expression { $$ = ast::function_definition::argument{$1, std::make_unique<type::type>($3), @$}; }
;
function_return_type

View file

@ -2,88 +2,11 @@
#include <pslang/parser/error.hpp>
#include <pslang/ast/statement.hpp>
#include <stdexcept>
#include <vector>
namespace pslang::parser
{
namespace
{
// get_statement_list: returns the list that should receive the statements nested
// under `statement`, or nullptr when the statement cannot have a nested body.
// For statements that open a block, a fresh empty statement_list is created,
// stored in the node, and a non-owning pointer to it is returned.

// Expression statements have no nested body.
ast::statement_list * get_statement_list(ast::expression_ptr &)
{
return nullptr;
}
// Assignments have no nested body.
ast::statement_list * get_statement_list(ast::assignment &)
{
return nullptr;
}
// Variable declarations have no nested body.
ast::statement_list * get_statement_list(ast::variable_declaration &)
{
return nullptr;
}
// `if` opens a block: allocate its body and hand it out.
ast::statement_list * get_statement_list(ast::if_block & node)
{
node.statements = std::make_unique<ast::statement_list>();
return node.statements.get();
}
// `else` opens a block: allocate its body and hand it out.
ast::statement_list * get_statement_list(ast::else_block & node)
{
node.statements = std::make_unique<ast::statement_list>();
return node.statements.get();
}
// `else if` opens a block: allocate its body and hand it out.
ast::statement_list * get_statement_list(ast::else_if_block & node)
{
node.statements = std::make_unique<ast::statement_list>();
return node.statements.get();
}
// NB: if chain merging happens after retrieving statement list
ast::statement_list * get_statement_list(ast::if_chain & node)
{
return nullptr;
}
// `while` opens a block: allocate its body and hand it out.
ast::statement_list * get_statement_list(ast::while_block & node)
{
node.statements = std::make_unique<ast::statement_list>();
return node.statements.get();
}
// A function definition opens a block: allocate its body and hand it out.
ast::statement_list * get_statement_list(ast::function_definition & node)
{
node.statements = std::make_unique<ast::statement_list>();
return node.statements.get();
}
// Return statements have no nested body.
ast::statement_list * get_statement_list(ast::return_statement &)
{
return nullptr;
}
// Field definitions have no nested statement list.
ast::statement_list * get_statement_list(ast::field_definition &)
{
return nullptr;
}
// Struct definitions yield no statement list here.
ast::statement_list * get_statement_list(ast::struct_definition &)
{
return nullptr;
}
// Dispatches on the concrete alternative held by the statement variant.
ast::statement_list * get_statement_list(ast::statement & statement)
{
return std::visit([](auto & value){ return get_statement_list(value); }, statement);
}
}
ast::statement_list_ptr finilize(indented_statement_list statements)
{
ast::statement_list_ptr result = std::make_unique<ast::statement_list>();
@ -97,23 +20,23 @@ namespace pslang::parser
auto current_statement_list = [&](ast::location const & location) -> ast::statement_list *
{
if (stack.empty())
throw internal_error("empty finilization stack");
throw internal_error("Empty finilization stack");
if (auto list = std::get_if<ast::statement_list *>(&stack.back()))
return *list;
throw parse_error("unexpected statement inside struct definition", location);
throw parse_error("Unexpected statement inside struct definition", location);
};
auto current_struct_definition = [&](ast::location const & location) -> ast::struct_definition *
{
if (stack.empty())
throw std::runtime_error("Internal error: empty finilization stack");
throw internal_error("Empty finilization stack");
if (auto list = std::get_if<ast::struct_definition *>(&stack.back()))
return *list;
throw parse_error("unexpected statement outside struct definition", location);
throw parse_error("Unexpected statement outside struct definition", location);
};
for (auto & statement : statements.statements)
@ -121,7 +44,7 @@ namespace pslang::parser
auto location = ast::get_location(*statement.statement);
if (statement.indentation > current_indent)
throw parse_error("unexpected indent", location);
throw parse_error("Unexpected indent", location);
while (statement.indentation < current_indent)
{
@ -131,37 +54,56 @@ namespace pslang::parser
// Now statement.indentation == current_indent
auto list = get_statement_list(*statement.statement);
ast::statement_list * list = nullptr;
if (auto if_block = std::get_if<ast::if_block>(statement.statement.get()))
{
ast::if_chain chain;
chain.blocks.push_back({.condition = std::move(if_block->condition), .statements = std::move(if_block->statements)});
chain.blocks.push_back({.condition = std::move(if_block->condition), .statements = std::make_unique<ast::statement_list>()});
list = chain.blocks.back().statements.get();
current_statement_list(location)->statements.push_back(std::make_unique<ast::statement>(std::move(chain)));
}
else if (auto else_block = std::get_if<ast::else_block>(statement.statement.get()))
{
if (current_statement_list(location)->statements.empty())
throw parse_error("unexpected else block", location);
throw parse_error("Unexpected else block", location);
auto chain = std::get_if<ast::if_chain>(current_statement_list(location)->statements.back().get());
if (!chain || chain->blocks.empty() || !chain->blocks.back().condition)
throw parse_error("unexpected else block", location);
throw parse_error("Unexpected else block", location);
chain->blocks.push_back({.condition = nullptr, .statements = std::move(else_block->statements)});
chain->blocks.push_back({.condition = nullptr, .statements = std::make_unique<ast::statement_list>()});
list = chain->blocks.back().statements.get();
}
else if (auto else_if_block = std::get_if<ast::else_if_block>(statement.statement.get()))
{
if (current_statement_list(location)->statements.empty())
throw parse_error("unexpected else if block", location);
throw parse_error("Unexpected else if block", location);
auto chain = std::get_if<ast::if_chain>(current_statement_list(location)->statements.back().get());
if (!chain || chain->blocks.empty() || !chain->blocks.back().condition)
throw parse_error("unexpected else if block", location);
throw parse_error("Unexpected else if block", location);
chain->blocks.push_back({.condition = std::move(else_if_block->condition), .statements = std::move(else_if_block->statements)});
chain->blocks.push_back({.condition = std::move(else_if_block->condition), .statements = std::make_unique<ast::statement_list>()});
list = chain->blocks.back().statements.get();
}
else if (auto while_block = std::get_if<ast::while_block>(statement.statement.get()))
{
while_block->statements = std::make_unique<ast::statement_list>();
list = while_block->statements.get();
current_statement_list(location)->statements.push_back(std::move(statement.statement));
}
else if (auto function_definition = std::get_if<ast::function_definition>(statement.statement.get()))
{
function_definition->statements = std::make_unique<ast::statement_list>();
list = function_definition->statements.get();
current_statement_list(location)->statements.push_back(std::move(statement.statement));
}
else if (auto field_definition = std::get_if<ast::field_definition>(statement.statement.get()))
{
current_struct_definition(location)->fields.push_back(*field_definition);
auto current = current_struct_definition(location);
for (auto const & field : current->fields)
if (field.name == field_definition->name)
throw parse_error("Duplicate field definition: \"" + field.name + "\"", field.location);
current->fields.push_back(*field_definition);
}
else if (std::get_if<ast::struct_definition>(statement.statement.get()))
{

View file

@ -8,11 +8,12 @@ namespace pslang::type
// Reference to a named type.
struct identifier
{
    // Spelling of the type name as written in the source.
    std::string name;
    // Index of the scope in which the name was found; stays 0 until it is
    // assigned during identifier resolution.
    std::size_t level = 0;
};

// Two identifiers are equal only when they denote the same entity: the same name
// found at the same scope level. Comparing names alone would treat same-named
// types declared in different scopes as identical.
inline bool operator == (identifier const & t1, identifier const & t2)
{
    return (t1.level == t2.level) && (t1.name == t2.name);
}
}