diff --git a/apps/interpreter/source/main.cpp b/apps/interpreter/source/main.cpp index 2dfcec9..5177b97 100644 --- a/apps/interpreter/source/main.cpp +++ b/apps/interpreter/source/main.cpp @@ -173,9 +173,9 @@ int main(int argc, char ** argv) { // TODO: remove, testing-only code; should execute entry point instead - auto offset = pcontext.symbols.at("sqr"); - auto fptr = (unsigned(*)(unsigned))(executable.data.get() + offset); - auto x = fptr(30); + auto offset = pcontext.symbols.at("foo"); + auto fptr = (uint32_t(*)(uint32_t))(executable.data.get() + offset); + auto x = fptr(10u); std::cout << "Result: " << std::boolalpha << x << std::endl; } } diff --git a/examples/jit_test.psl b/examples/jit_test.psl index a644ef6..2f9f62d 100644 --- a/examples/jit_test.psl +++ b/examples/jit_test.psl @@ -1,5 +1,7 @@ -func sqr(x : f32) -> f32: - return x * x +func foo(x : u32) -> u32: + if x == 0u: + return 42u + return 1u + bar(x - 1u) -func smoothstep(x : f32) -> f32: - return sqr(x) * (3.0 - 2.0 * x) +func bar(x : u32) -> u32: + return foo(x) diff --git a/libs/ast/source/resolve_identifiers.cpp b/libs/ast/source/resolve_identifiers.cpp index ea93e9b..ea8d4ee 100644 --- a/libs/ast/source/resolve_identifiers.cpp +++ b/libs/ast/source/resolve_identifiers.cpp @@ -56,12 +56,66 @@ namespace pslang::ast bool is_global_scope = false; }; + struct populate_globals_visitor + : statement_visitor + { + std::vector & scopes; + + using statement_visitor::apply; + + void apply(expression_ptr const &) + {} + + void apply(assignment const &) + {} + + void apply(variable_declaration const &) + {} + + void apply(if_block const &) + {} + + void apply(else_if_block const &) + {} + + void apply(else_block const &) + {} + + void apply(if_chain const &) + {} + + void apply(while_block const &) + {} + + void apply(function_definition const & function_definition) + { + if (scopes.back().contains(function_definition.name)) + throw parse_error("Identifier \"" + function_definition.name + "\" is already defined at this scope", function_definition.location); + + scopes.back().functions.insert(function_definition.name); + } + + void apply(return_statement const &) + {} + + void apply(field_definition const &) + {} + + void apply(struct_definition const & struct_definition) + { + if (scopes.back().contains(struct_definition.name)) + throw parse_error("Identifier \"" + struct_definition.name + "\" is already defined at this scope", struct_definition.location); + + scopes.back().structs.insert(struct_definition.name); + } + }; + struct resolve_identifiers_visitor : type_visitor , expression_visitor , statement_visitor { - std::vector scopes; + std::vector & scopes; using type_visitor::apply; using expression_visitor::apply; @@ -246,8 +300,7 @@ namespace pslang::ast void apply(function_definition const & function_definition) { - if (scopes.back().contains(function_definition.name)) - throw parse_error("Identifier \"" + function_definition.name + "\" is already defined at this scope", function_definition.location); + // Already added to scope by populate_globals_visitor std::unordered_set argument_names; for (auto const & argument : function_definition.arguments) @@ -282,8 +335,7 @@ namespace pslang::ast void apply(struct_definition const & struct_definition) { - if (scopes.back().contains(struct_definition.name)) - throw parse_error("Identifier \"" + struct_definition.name + "\" is already defined at this scope", struct_definition.location); + // Already added to scope by populate_globals_visitor for (auto const & field : struct_definition.fields) apply(field); @@ -291,14 +343,23 @@ namespace pslang::ast scopes.back().structs.insert(struct_definition.name); } + void apply(statement_list & statement_list) + { + populate_globals_visitor populate_globals_visitor{{}, scopes}; + populate_globals_visitor.apply(statement_list); + + for (auto const & statement : statement_list.statements) + apply(*statement); + } }; } void resolve_identifiers(statement_list_ptr & statements) { - resolve_identifiers_visitor visitor; - visitor.scopes.emplace_back().is_global_scope = true; + std::vector scopes; + scopes.emplace_back().is_global_scope = true; + resolve_identifiers_visitor visitor{{}, {}, {}, scopes}; visitor.apply(*statements); } diff --git a/libs/ast/source/type_check.cpp b/libs/ast/source/type_check.cpp index 755e6bd..efb1592 100644 --- a/libs/ast/source/type_check.cpp +++ b/libs/ast/source/type_check.cpp @@ -50,16 +50,12 @@ namespace pslang::ast types::type_ptr expected_return_type = nullptr; }; - struct check_visitor - : type_visitor - , expression_visitor - , const_statement_visitor + struct resolve_types_visitor + : type_visitor { - std::vector scopes; + std::vector & scopes; using type_visitor::apply; - using expression_visitor::apply; - using const_statement_visitor::apply; void apply(types::unit_type const &) {} @@ -98,6 +94,104 @@ namespace pslang::ast type.level = node.level; node.inferred_type = std::make_unique(std::move(type)); } + }; + + void resolve_types(std::vector & scopes, ast::type & type) + { + resolve_types_visitor visitor{{}, scopes}; + visitor.apply(type); + } + + struct populate_globals_visitor + : statement_visitor + { + std::vector & scopes; + + using statement_visitor::apply; + + void apply(expression_ptr const &) + {} + + void apply(assignment const &) + {} + + void apply(variable_declaration const &) + {} + + void apply(if_block const &) + {} + + void apply(else_if_block const &) + {} + + void apply(else_block const &) + {} + + void apply(if_chain const &) + {} + + void apply(while_block const &) + {} + + void apply(function_definition const & node) + { + for (auto const & argument : node.arguments) + resolve_types(scopes, *argument.type); + resolve_types(scopes, *node.return_type); + + auto & data = scopes.back().functions[node.name]; + for (auto const & argument : node.arguments) + data.arguments.push_back(get_type(*argument.type)); + data.result_type = get_type(*node.return_type); + + scopes.emplace_back().is_function_scope = true; + scopes.back().expected_return_type = get_type(*node.return_type); + + for (auto const & argument : node.arguments) + { + scopes.back().variables[argument.name] = { + .category = value_category::constant, + .type = get_type(*argument.type), + }; + } + + apply(*node.statements); + scopes.pop_back(); + } + + void apply(return_statement const &) + {} + + void apply(field_definition const & node) + { + resolve_types(scopes, *node.type); + } + + void apply(struct_definition const & node) + { + for (auto const & field : node.fields) + apply(field); + + auto & data = scopes.back().structs[node.name]; + for (auto const & field : node.fields) + data.fields.push_back({.name = field.name, .type = get_type(*field.type)}); + } + }; + + void populate_globals(std::vector & scopes, statement_list & statement_list) + { + populate_globals_visitor visitor{{}, scopes}; + visitor.apply(statement_list); + } + + struct check_visitor + : expression_visitor + , const_statement_visitor + { + std::vector & scopes; + + using expression_visitor::apply; + using const_statement_visitor::apply; void apply(literal &) {} @@ -277,7 +371,7 @@ namespace pslang::ast void apply(cast_operation & node) { apply(*node.expression); - apply(*node.type); + resolve_types(scopes, *node.type); auto source_type = get_type(*node.expression); auto target_type = get_type(*node.type); @@ -304,7 +398,7 @@ namespace pslang::ast if (node.function) apply(*node.function); if (node.type) - apply(*node.type); + resolve_types(scopes, *node.type); for (auto const & argument : node.arguments) apply(*argument); @@ -530,7 +624,7 @@ namespace pslang::ast auto actual_type = get_type(*node.initializer); if (node.type) { - apply(*node.type); + resolve_types(scopes, *node.type); auto expected_type = get_type(*node.type); if (!types::equal(*expected_type, *actual_type)) { @@ -606,14 +700,7 @@ namespace pslang::ast void apply(function_definition const & node) { - for (auto const & argument : node.arguments) - apply(*argument.type); - apply(*node.return_type); - - auto & data = scopes.back().functions[node.name]; - for (auto const & argument : node.arguments) - data.arguments.push_back(get_type(*argument.type)); - data.result_type = get_type(*node.return_type); + // Already added to scope by populate_globals_visitor scopes.emplace_back().is_function_scope = true; scopes.back().expected_return_type = get_type(*node.return_type); @@ -658,19 +745,19 @@ namespace pslang::ast } } - void apply(field_definition const & node) - { - apply(*node.type); - } + void apply(field_definition const &) + {} void apply(struct_definition const & node) { - for (auto const & field : node.fields) - apply(field); + // Already added to scope by populate_globals_visitor + } - auto & data = scopes.back().structs[node.name]; - for (auto const & field : node.fields) - data.fields.push_back({.name = field.name, .type = get_type(*field.type)}); + void apply(statement_list & node) + { + populate_globals(scopes, node); + for (auto const & statement : node.statements) + apply(*statement); } }; @@ -678,8 +765,9 @@ namespace pslang::ast void check_and_infer_types(statement_list_ptr & statements) { - check_visitor visitor; - visitor.scopes.emplace_back().is_global_scope = true; + std::vector scopes; + scopes.emplace_back().is_global_scope = true; + check_visitor visitor{{}, {}, scopes}; visitor.apply(*statements); }