diff --git a/apps/interpreter/source/main.cpp b/apps/interpreter/source/main.cpp index 279db92..70cb251 100644 --- a/apps/interpreter/source/main.cpp +++ b/apps/interpreter/source/main.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include @@ -91,7 +92,9 @@ int main(int argc, char ** argv) try { filenames.push_back(argv[arg]); - parsed.push_back(parser::parse(filenames.back())); + auto ast = parser::parse(filenames.back()); + ast::resolve_identifiers(ast); + parsed.push_back(std::move(ast)); } catch (pslang::parser::parse_error const & error) { diff --git a/backlog.txt b/backlog.txt new file mode 100644 index 0000000..0d25ed4 --- /dev/null +++ b/backlog.txt @@ -0,0 +1,3 @@ +* Mutually recursive functions +* Refactor all tree visitors to prevent infinite recursion (maybe add dedicated tree traversal functions & visitors?) +* Type identifier location + move types to ast library / split type values and type ast nodes diff --git a/libs/ast/include/pslang/ast/control.hpp b/libs/ast/include/pslang/ast/control.hpp index a4e5af5..0416878 100644 --- a/libs/ast/include/pslang/ast/control.hpp +++ b/libs/ast/include/pslang/ast/control.hpp @@ -13,20 +13,17 @@ namespace pslang::ast struct if_block { expression_ptr condition; - statement_list_ptr statements; ast::location location; }; struct else_block { - statement_list_ptr statements; ast::location location; }; struct else_if_block { expression_ptr condition; - statement_list_ptr statements; ast::location location; }; diff --git a/libs/ast/include/pslang/ast/error.hpp b/libs/ast/include/pslang/ast/error.hpp new file mode 100644 index 0000000..e264d67 --- /dev/null +++ b/libs/ast/include/pslang/ast/error.hpp @@ -0,0 +1,66 @@ +#pragma once + +#include + +#include +#include + +namespace pslang::ast +{ + + struct parse_error + : std::exception + { + parse_error(std::string message, location location) + : message_(std::move(message)) + , filename_(location.filename) + , location_(location) + { + 
location_.filename = filename_; + } + + char const * what() const noexcept + { + return message_.c_str(); + } + + ast::location location() const noexcept + { + return location_; + } + + private: + std::string message_; + std::string filename_; + ast::location location_; + }; + + struct invalid_ast_error + : std::exception + { + invalid_ast_error(std::string message, location location) + : message_(std::move(message)) + , filename_(location.filename) + , location_(location) + { + location_.filename = filename_; + } + + char const * what() const noexcept + { + return message_.c_str(); + } + + ast::location location() const noexcept + { + return location_; + } + + private: + std::string message_; + std::string filename_; + ast::location location_; + }; + + +} diff --git a/libs/ast/include/pslang/ast/identifier.hpp b/libs/ast/include/pslang/ast/identifier.hpp index 5753854..d0b9961 100644 --- a/libs/ast/include/pslang/ast/identifier.hpp +++ b/libs/ast/include/pslang/ast/identifier.hpp @@ -11,6 +11,7 @@ namespace pslang::ast { std::string name; ast::location location; + std::size_t level = 0; }; } diff --git a/libs/ast/include/pslang/ast/preprocess.hpp b/libs/ast/include/pslang/ast/preprocess.hpp new file mode 100644 index 0000000..2f29fb8 --- /dev/null +++ b/libs/ast/include/pslang/ast/preprocess.hpp @@ -0,0 +1,10 @@ +#pragma once + +#include + +namespace pslang::ast +{ + + void resolve_identifiers(statement_list_ptr & statements); + +} diff --git a/libs/ast/source/print.cpp b/libs/ast/source/print.cpp index c5f5033..b3a9b66 100644 --- a/libs/ast/source/print.cpp +++ b/libs/ast/source/print.cpp @@ -185,14 +185,12 @@ namespace pslang::ast put_indent(out, options); out << "if\n"; print(out, node.condition, child(options)); - print(out, node.statements, child(options)); } void print_impl(std::ostream & out, else_block const & node, print_options const & options) { put_indent(out, options); out << "else\n"; - print(out, node.statements, child(options)); } void 
print_impl(std::ostream & out, else_if_block const & node, print_options const & options) @@ -200,7 +198,6 @@ namespace pslang::ast put_indent(out, options); out << "else if\n"; print(out, node.condition, child(options)); - print(out, node.statements, child(options)); } void print_impl(std::ostream & out, if_chain const & node, print_options const & options) diff --git a/libs/ast/source/resolve_identifiers.cpp b/libs/ast/source/resolve_identifiers.cpp new file mode 100644 index 0000000..48226c1 --- /dev/null +++ b/libs/ast/source/resolve_identifiers.cpp @@ -0,0 +1,295 @@ +#include +#include +#include + +#include +#include + +namespace pslang::ast +{ + + namespace + { + + struct scope + { + std::unordered_set functions; + std::unordered_set structs; + std::unordered_set variables; + + bool contains_transitive(std::string const & name) const + { + return false + || (functions.count(name) > 0) + || (structs.count(name) > 0) + ; + } + + bool contains_type(std::string const & name) const + { + return false + || (structs.count(name) > 0) + ; + } + + bool contains(std::string const & name) const + { + return false + || (functions.count(name) > 0) + || (structs.count(name) > 0) + || (variables.count(name) > 0) + ; + } + + bool contains(std::string const & name, bool crossed_function_scope) const + { + if (crossed_function_scope) + return contains_transitive(name); + return contains(name); + } + + bool is_function_scope = false; + bool is_global_scope = false; + }; + + struct context + { + std::vector scopes; + }; + + void resolve_identifiers(context & context, type::type & type); + void resolve_identifiers(context & context, expression & expression); + void resolve_identifiers(context & context, statement_list const & statements); + + void resolve_identifiers_impl(context &, type::unit_type const &) + {} + + void resolve_identifiers_impl(context &, type::primitive_type const &) + {} + + void resolve_identifiers_impl(context & context, type::array_type const & array_type) + 
{ + resolve_identifiers(context, *array_type.element_type); + } + + void resolve_identifiers_impl(context & context, type::function_type const & function_type) + { + for (auto const & argument : function_type.arguments) + resolve_identifiers(context, *argument); + resolve_identifiers(context, *function_type.result); + } + + void resolve_identifiers_impl(context & context, type::identifier & identifier) + { + for (auto it = context.scopes.rbegin(); it != context.scopes.rend(); ++it) + { + if (it->contains_type(identifier.name)) + { + identifier.level = it.base() - context.scopes.begin() - 1; + return; + } + } + + // TODO location + throw parse_error("Identifier \"" + identifier.name + "\" not found", {}); + } + + void resolve_identifiers(context & context, type::type & type) + { + return std::visit([&](auto & type){ return resolve_identifiers_impl(context, type); }, type); + } + + void resolve_identifiers_impl(context &, literal const &) + {} + + void resolve_identifiers_impl(context & context, identifier & identifier) + { + bool crossed_function_scope = false; + for (auto it = context.scopes.rbegin(); it != context.scopes.rend(); ++it) + { + if (it->contains(identifier.name, crossed_function_scope && !it->is_global_scope)) + { + identifier.level = it.base() - context.scopes.begin() - 1; + return; + } + + crossed_function_scope |= it->is_function_scope; + } + + throw parse_error("Identifier \"" + identifier.name + "\" not found", identifier.location); + } + + void resolve_identifiers_impl(context & context, unary_operation const & unary_operation) + { + resolve_identifiers(context, *unary_operation.arg1); + } + + void resolve_identifiers_impl(context & context, binary_operation const & binary_operation) + { + resolve_identifiers(context, *binary_operation.arg1); + resolve_identifiers(context, *binary_operation.arg2); + } + + void resolve_identifiers_impl(context & context, cast_operation const & cast_operation) + { + resolve_identifiers(context, 
*cast_operation.expression); + resolve_identifiers(context, *cast_operation.type); + } + + void resolve_identifiers_impl(context & context, function_call const & function_call) + { + resolve_identifiers(context, *function_call.function); + for (auto const & argument : function_call.arguments) + resolve_identifiers(context, *argument); + } + + void resolve_identifiers_impl(context & context, array const & array) + { + for (auto const & element : array.elements) + resolve_identifiers(context, *element); + } + + void resolve_identifiers_impl(context & context, array_access const & array_access) + { + resolve_identifiers(context, *array_access.array); + resolve_identifiers(context, *array_access.index); + } + + void resolve_identifiers_impl(context & context, field_access const & field_access) + { + resolve_identifiers(context, *field_access.object); + } + + void resolve_identifiers(context & context, expression & expression) + { + return std::visit([&](auto & expression){ return resolve_identifiers_impl(context, expression); }, expression); + } + + void resolve_identifiers_impl(context & context, expression_ptr const & expression_ptr) + { + resolve_identifiers(context, *expression_ptr); + } + + void resolve_identifiers_impl(context & context, assignment const & assignment) + { + resolve_identifiers(context, *assignment.lhs); + resolve_identifiers(context, *assignment.rhs); + } + + void resolve_identifiers_impl(context & context, variable_declaration const & variable_declaration) + { + if (context.scopes.back().contains(variable_declaration.name)) + throw parse_error("Identifier \"" + variable_declaration.name + "\" is already defined at this scope", variable_declaration.location); + + if (variable_declaration.type) + resolve_identifiers(context, *variable_declaration.type); + resolve_identifiers(context, *variable_declaration.initializer); + + context.scopes.back().variables.insert(variable_declaration.name); + } + + void resolve_identifiers_impl(context &, if_block 
const & if_block) + { + throw invalid_ast_error("if blocks cannot be present in the final AST", if_block.location); + } + + void resolve_identifiers_impl(context &, else_if_block const & else_if_block) + { + throw invalid_ast_error("else if blocks cannot be present in the final AST", else_if_block.location); + } + + void resolve_identifiers_impl(context &, else_block const & else_block) + { + throw invalid_ast_error("else blocks cannot be present in the final AST", else_block.location); + } + + void resolve_identifiers_impl(context & context, if_chain const & if_chain) + { + for (auto const & block : if_chain.blocks) + { + if (block.condition) + resolve_identifiers(context, *block.condition); + context.scopes.emplace_back(); + resolve_identifiers(context, *block.statements); + context.scopes.pop_back(); + } + } + + void resolve_identifiers_impl(context & context, while_block const & while_block) + { + resolve_identifiers(context, *while_block.condition); + context.scopes.emplace_back(); + resolve_identifiers(context, *while_block.statements); + context.scopes.pop_back(); + } + + void resolve_identifiers_impl(context & context, function_definition const & function_definition) + { + if (context.scopes.back().contains(function_definition.name)) + throw parse_error("Identifier \"" + function_definition.name + "\" is already defined at this scope", function_definition.location); + + std::unordered_set argument_names; + for (auto const & argument : function_definition.arguments) + { + if (argument_names.count(argument.name) > 0) + throw parse_error("Duplicate argument name \"" + argument.name + "\" in function \"" + function_definition.name + "\"", argument.location); + argument_names.insert(argument.name); + resolve_identifiers(context, *argument.type); + } + + resolve_identifiers(context, *function_definition.return_type); + + context.scopes.back().functions.insert(function_definition.name); + + auto & scope = context.scopes.emplace_back(); + scope.is_function_scope = 
true; + scope.variables = std::move(argument_names); + resolve_identifiers(context, *function_definition.statements); + context.scopes.pop_back(); + } + + void resolve_identifiers_impl(context & context, return_statement const & return_statement) + { + if (return_statement.value) + resolve_identifiers(context, *return_statement.value); + } + + void resolve_identifiers_impl(context & context, field_definition const & field_definition) + { + resolve_identifiers(context, *field_definition.type); + } + + void resolve_identifiers_impl(context & context, struct_definition const & struct_definition) + { + if (context.scopes.back().contains(struct_definition.name)) + throw parse_error("Identifier \"" + struct_definition.name + "\" is already defined at this scope", struct_definition.location); + + for (auto const & field : struct_definition.fields) + resolve_identifiers_impl(context, field); + + context.scopes.back().structs.insert(struct_definition.name); + } + + void resolve_identifiers(context & context, statement const & statement) + { + return std::visit([&](auto const & statement){ return resolve_identifiers_impl(context, statement); }, statement); + } + + void resolve_identifiers(context & context, statement_list const & statements) + { + for (auto const & statement : statements.statements) + { + resolve_identifiers(context, *statement); + } + } + + } + + void resolve_identifiers(statement_list_ptr & statements) + { + context context; + context.scopes.emplace_back().is_global_scope = true; + resolve_identifiers(context, *statements); + } + +} diff --git a/libs/parser/include/pslang/parser/error.hpp b/libs/parser/include/pslang/parser/error.hpp index 7738b88..24e198a 100644 --- a/libs/parser/include/pslang/parser/error.hpp +++ b/libs/parser/include/pslang/parser/error.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -8,32 +9,7 @@ namespace pslang::parser { - struct parse_error - : std::exception - { - parse_error(std::string message, ast::location 
location) - : message_(std::move(message)) - , filename_(location.filename) - , location_(location) - { - location_.filename = filename_; - } - - char const * what() const noexcept - { - return message_.c_str(); - } - - ast::location location() const noexcept - { - return location_; - } - - private: - std::string message_; - std::string filename_; - ast::location location_; - }; + using parse_error = ast::parse_error; struct internal_error : std::exception diff --git a/libs/parser/rules/pslang.y b/libs/parser/rules/pslang.y index ab211f7..5203c80 100644 --- a/libs/parser/rules/pslang.y +++ b/libs/parser/rules/pslang.y @@ -148,6 +148,7 @@ template %type statement %type > function_definition_argument_list %type > nonempty_function_definition_argument_list +%type function_definition_single_argument %type function_return_type %type variable_declaration %type variable_keyword @@ -183,9 +184,9 @@ statement : expression { $$ = std::make_unique($1); } | expression assignment expression { $$ = ast::assignment{ std::make_unique($1), std::make_unique($3), @$ }; } | variable_declaration { $$ = $1; } -| if expression colon { $$ = ast::if_block{std::make_unique($2), {}, @$}; } -| else colon { $$ = ast::else_block{{}, @$}; } -| else if expression colon { $$ = ast::else_if_block{std::make_unique($3), {}, @$}; } +| if expression colon { $$ = ast::if_block{std::make_unique($2), @$}; } +| else colon { $$ = ast::else_block{@$}; } +| else if expression colon { $$ = ast::else_if_block{std::make_unique($3), @$}; } | while expression colon { $$ = ast::while_block{std::make_unique($2), {}, @$}; } | func name lparen function_definition_argument_list rparen function_return_type colon { $$ = ast::function_definition{$2, $4, $6, {}, @$}; } | return expression { $$ = ast::return_statement{std::make_unique($2), @$}; } @@ -200,8 +201,12 @@ function_definition_argument_list ; nonempty_function_definition_argument_list -: name colon type_expression { std::vector tmp; tmp.push_back({.name = $1, 
.type = std::make_unique($3), @$}); $$ = std::move(tmp); } -| nonempty_function_definition_argument_list comma name colon type_expression { auto tmp = $1; tmp.push_back({.name = $3, .type = std::make_unique($5)}); $$ = std::move(tmp); } +: function_definition_single_argument { std::vector tmp; tmp.push_back($1); $$ = std::move(tmp); } +| nonempty_function_definition_argument_list comma function_definition_single_argument { auto tmp = $1; tmp.push_back($3); $$ = std::move(tmp); } +; + +function_definition_single_argument +: name colon type_expression { $$ = ast::function_definition::argument{$1, std::make_unique($3), @$}; } ; function_return_type diff --git a/libs/parser/source/finilize.cpp b/libs/parser/source/finilize.cpp index e2c7917..4eab83b 100644 --- a/libs/parser/source/finilize.cpp +++ b/libs/parser/source/finilize.cpp @@ -2,88 +2,11 @@ #include #include -#include #include namespace pslang::parser { - namespace - { - - ast::statement_list * get_statement_list(ast::expression_ptr &) - { - return nullptr; - } - - ast::statement_list * get_statement_list(ast::assignment &) - { - return nullptr; - } - - ast::statement_list * get_statement_list(ast::variable_declaration &) - { - return nullptr; - } - - ast::statement_list * get_statement_list(ast::if_block & node) - { - node.statements = std::make_unique(); - return node.statements.get(); - } - - ast::statement_list * get_statement_list(ast::else_block & node) - { - node.statements = std::make_unique(); - return node.statements.get(); - } - - ast::statement_list * get_statement_list(ast::else_if_block & node) - { - node.statements = std::make_unique(); - return node.statements.get(); - } - - // NB: if chain merging happens after retrieving statement list - ast::statement_list * get_statement_list(ast::if_chain & node) - { - return nullptr; - } - - ast::statement_list * get_statement_list(ast::while_block & node) - { - node.statements = std::make_unique(); - return node.statements.get(); - } - - 
ast::statement_list * get_statement_list(ast::function_definition & node) - { - node.statements = std::make_unique(); - return node.statements.get(); - } - - ast::statement_list * get_statement_list(ast::return_statement &) - { - return nullptr; - } - - ast::statement_list * get_statement_list(ast::field_definition &) - { - return nullptr; - } - - ast::statement_list * get_statement_list(ast::struct_definition &) - { - return nullptr; - } - - ast::statement_list * get_statement_list(ast::statement & statement) - { - return std::visit([](auto & value){ return get_statement_list(value); }, statement); - } - - } - ast::statement_list_ptr finilize(indented_statement_list statements) { ast::statement_list_ptr result = std::make_unique(); @@ -97,23 +20,23 @@ namespace pslang::parser auto current_statement_list = [&](ast::location const & location) -> ast::statement_list * { if (stack.empty()) - throw internal_error("empty finilization stack"); + throw internal_error("Empty finilization stack"); if (auto list = std::get_if(&stack.back())) return *list; - throw parse_error("unexpected statement inside struct definition", location); + throw parse_error("Unexpected statement inside struct definition", location); }; auto current_struct_definition = [&](ast::location const & location) -> ast::struct_definition * { if (stack.empty()) - throw std::runtime_error("Internal error: empty finilization stack"); + throw internal_error("Empty finilization stack"); if (auto list = std::get_if(&stack.back())) return *list; - throw parse_error("unexpected statement outside struct definition", location); + throw parse_error("Unexpected statement outside struct definition", location); }; for (auto & statement : statements.statements) @@ -121,7 +44,7 @@ namespace pslang::parser auto location = ast::get_location(*statement.statement); if (statement.indentation > current_indent) - throw parse_error("unexpected indent", location); + throw parse_error("Unexpected indent", location); while 
(statement.indentation < current_indent) { @@ -131,37 +54,56 @@ namespace pslang::parser // Now statement.indentation == current_indent - auto list = get_statement_list(*statement.statement); + ast::statement_list * list = nullptr; if (auto if_block = std::get_if(statement.statement.get())) { ast::if_chain chain; - chain.blocks.push_back({.condition = std::move(if_block->condition), .statements = std::move(if_block->statements)}); + chain.blocks.push_back({.condition = std::move(if_block->condition), .statements = std::make_unique()}); + list = chain.blocks.back().statements.get(); current_statement_list(location)->statements.push_back(std::make_unique(std::move(chain))); } else if (auto else_block = std::get_if(statement.statement.get())) { if (current_statement_list(location)->statements.empty()) - throw parse_error("unexpected else block", location); + throw parse_error("Unexpected else block", location); auto chain = std::get_if(current_statement_list(location)->statements.back().get()); if (!chain || chain->blocks.empty() || !chain->blocks.back().condition) - throw parse_error("unexpected else block", location); + throw parse_error("Unexpected else block", location); - chain->blocks.push_back({.condition = nullptr, .statements = std::move(else_block->statements)}); + chain->blocks.push_back({.condition = nullptr, .statements = std::make_unique()}); + list = chain->blocks.back().statements.get(); } else if (auto else_if_block = std::get_if(statement.statement.get())) { if (current_statement_list(location)->statements.empty()) - throw parse_error("unexpected else if block", location); + throw parse_error("Unexpected else if block", location); auto chain = std::get_if(current_statement_list(location)->statements.back().get()); if (!chain || chain->blocks.empty() || !chain->blocks.back().condition) - throw parse_error("unexpected else if block", location); + throw parse_error("Unexpected else if block", location); - chain->blocks.push_back({.condition = 
std::move(else_if_block->condition), .statements = std::move(else_if_block->statements)}); + chain->blocks.push_back({.condition = std::move(else_if_block->condition), .statements = std::make_unique()}); + list = chain->blocks.back().statements.get(); + } + else if (auto while_block = std::get_if(statement.statement.get())) + { + while_block->statements = std::make_unique(); + list = while_block->statements.get(); + current_statement_list(location)->statements.push_back(std::move(statement.statement)); + } + else if (auto function_definition = std::get_if(statement.statement.get())) + { + function_definition->statements = std::make_unique(); + list = function_definition->statements.get(); + current_statement_list(location)->statements.push_back(std::move(statement.statement)); } else if (auto field_definition = std::get_if(statement.statement.get())) { - current_struct_definition(location)->fields.push_back(*field_definition); + auto current = current_struct_definition(location); + for (auto const & field : current->fields) + if (field.name == field_definition->name) + throw parse_error("Duplicate field definition: \"" + field.name + "\"", field.location); + current->fields.push_back(*field_definition); } else if (std::get_if(statement.statement.get())) { diff --git a/libs/type/include/pslang/type/identifier.hpp b/libs/type/include/pslang/type/identifier.hpp index d02790d..74ec88d 100644 --- a/libs/type/include/pslang/type/identifier.hpp +++ b/libs/type/include/pslang/type/identifier.hpp @@ -8,11 +8,12 @@ namespace pslang::type struct identifier { std::string name; + std::size_t level = 0; }; inline bool operator == (identifier const & t1, identifier const & t2) { - return t1.name == t2.name; + return (t1.level == t2.level) && (t1.name == t2.name); } }