From b25400ad65f2d06cc11b15c92bf21d2c3ce6c4dd Mon Sep 17 00:00:00 2001 From: lisyarus Date: Sun, 15 Mar 2026 14:15:36 +0300 Subject: [PATCH] Properly implement inner functions in aarch64 compiler & support module entry point --- apps/interpreter/source/main.cpp | 11 +- examples/jit_test.psl | 17 +- .../include/pslang/jit/program_context.hpp | 1 + libs/jit/source/arch/aarch64/compiler.cpp | 199 +++++++++++++++--- plans.txt | 6 +- 5 files changed, 195 insertions(+), 39 deletions(-) diff --git a/apps/interpreter/source/main.cpp b/apps/interpreter/source/main.cpp index 6531c80..0f1d0ca 100644 --- a/apps/interpreter/source/main.cpp +++ b/apps/interpreter/source/main.cpp @@ -195,13 +195,10 @@ int main(int argc, char ** argv) auto executable = jit::make_host_executable(pcontext.code); - { - // TODO: remove, testing-only code; should execute entry point instead - auto offset = pcontext.symbols.at("test"); - auto fptr = (int(*)(int))(executable.data.get() + offset); - auto x = fptr(0); - std::cout << "Result: " << std::boolalpha << x << std::endl; - } + // TODO: multiple input files => multiple entry points? + // Should probably compile them as separate modules. + auto entry_point = (void(*)())(executable.data.get() + pcontext.entry_point); + entry_point(); } else { diff --git a/examples/jit_test.psl b/examples/jit_test.psl index 199f236..bd949af 100644 --- a/examples/jit_test.psl +++ b/examples/jit_test.psl @@ -15,7 +15,7 @@ foreign func putchar(c: i32) -> i32 func print(c: u8): putchar(c as i32) -func test(): +func test1(): print('H') print('e') print('l') @@ -31,6 +31,21 @@ func test(): print('!') print('\n') +func b() -> i32: + return 300 + +func test() -> i32: + if false: + func b() -> i32: + return 200 + return b() + else: + return b() + +print('O') +print('K') +print('\n') + //func test1(): // let str = ['H', 'e', 'l', 'l', 'o', ',', ' ', 'w', 'o', 'r', 'l', 'd', '!', '\n'] // mut i = 0 diff --git a/libs/jit/include/pslang/jit/program_context.hpp b/libs/jit/include/pslang/jit/program_context.hpp index b86c03b..90a9516 100644 --- a/libs/jit/include/pslang/jit/program_context.hpp +++ b/libs/jit/include/pslang/jit/program_context.hpp @@ -23,6 +23,7 @@ namespace pslang::jit jit::abi abi; std::vector code = {}; std::unordered_map symbols = {}; + std::int32_t entry_point = 0; std::vector foreign_resolve = {}; }; diff --git a/libs/jit/source/arch/aarch64/compiler.cpp b/libs/jit/source/arch/aarch64/compiler.cpp index 313c23e..661f0e3 100644 --- a/libs/jit/source/arch/aarch64/compiler.cpp +++ b/libs/jit/source/arch/aarch64/compiler.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -23,16 +24,52 @@ namespace pslang::jit::aarch64 std::unordered_map f32_constants; std::unordered_map f64_constants; + std::unordered_map foreign_address; + + std::unordered_map functions; + + struct scope + { + std::unordered_set foreign_functions; + std::unordered_map functions; + }; + + std::vector scopes; + struct resolve_info { - std::string name; + ast::function_definition const * node; // Must be 'adr' instruction std::int32_t instruction_offset; }; std::vector resolve; - std::unordered_map foreign_address; + bool is_foreign(std::string const & name) + { + for (auto it = scopes.rbegin(); it != scopes.rend(); ++it) + { + if (it->foreign_functions.contains(name)) + return true; + + if (it->functions.contains(name)) + return false; + } + return false; + } + + ast::function_definition const * is_function(std::string const & name) + { + for (auto it = scopes.rbegin(); it != scopes.rend(); ++it) + { + if (it->foreign_functions.contains(name)) + return nullptr; + + if (auto jt = it->functions.find(name); jt != it->functions.end()) + return jt->second; + } + return nullptr; + } }; std::uint8_t fp_mode_for(types::type const & type) @@ -56,6 +93,8 @@ namespace pslang::jit::aarch64 } } + // Add all f16, f32 and f64 constants as read-only data entries + // Add extern pointers for all foreign functions as read-only data entries struct populate_constants_visitor : ast::const_expression_visitor , ast::const_statement_visitor @@ -219,6 +258,48 @@ namespace pslang::jit::aarch64 } }; + // Iterate over a single scope (i.e. not visiting subscopes recursively) + // and add all defined functions & foreign functions to the current scope + struct populate_symbols_visitor + : ast::const_statement_visitor + { + local_context & lcontext; + + using const_statement_visitor::apply; + + void apply(ast::expression_ptr const &) {} + + void apply(ast::assignment const &) {} + + void apply(ast::variable_declaration const &) {} + + void apply(ast::if_block const &) {} + + void apply(ast::else_block const &) {} + + void apply(ast::else_if_block const &) {} + + void apply(ast::if_chain const &) {} + + void apply(ast::while_block const &) {} + + void apply(ast::function_definition const & node) + { + lcontext.scopes.back().functions[node.name] = &node; + } + + void apply(ast::foreign_function_declaration const & node) + { + lcontext.scopes.back().foreign_functions.insert(node.name); + } + + void apply(ast::return_statement const &) {} + + void apply(ast::field_definition const &) {} + + void apply(ast::struct_definition const &) {} + }; + struct reg_extend_visitor : types::const_visitor { @@ -262,6 +343,8 @@ namespace pslang::jit::aarch64 } }; + // Compile a single function and store the entry point offset + // in local_context struct compile_function_visitor : ast::const_statement_visitor , ast::const_expression_visitor @@ -375,15 +458,18 @@ namespace pslang::jit::aarch64 } } - // Not a variable - must be a function! - if (lcontext.foreign_address.contains(node.name)) + if (lcontext.is_foreign(node.name)) { builder.ldr_pc(0, (lcontext.foreign_address.at(node.name) - (std::int32_t)pcontext.code.size()) / 4); } + else if (auto function_node = lcontext.is_function(node.name)) + { + lcontext.resolve.push_back({function_node, (std::int32_t)pcontext.code.size()}); + builder.adr(0, 0); + } else { - lcontext.resolve.push_back({node.name, (std::int32_t)pcontext.code.size()}); - builder.adr(0, 0); + throw std::runtime_error("unknown identifier \"" + node.name + "\""); } } @@ -830,22 +916,27 @@ namespace pslang::jit::aarch64 void apply(ast::function_definition const &) { - // Must be handled prior to that + // Must be handled prior to that in populate_symbols_visitor } void apply(ast::foreign_function_declaration const &) { - // Must be handled prior to that + // Must be handled prior to that in populate_symbols_visitor } void apply(ast::statement_list const & node) { + lcontext.scopes.emplace_back(); + populate_symbols_visitor{{}, lcontext}.apply(node); for (auto const & statement : node.statements) apply(*statement); + lcontext.scopes.pop_back(); } void do_apply(ast::function_definition const & node) { + lcontext.functions[&node] = pcontext.code.size(); + // TODO: struct arguments scopes.emplace_back(); @@ -876,7 +967,12 @@ namespace pslang::jit::aarch64 } scopes.back().variables[argument.name] = {.frame_offset = stack_offset}; } + apply(*node.statements); + if (node.statements->statements.empty() || !std::get_if(node.statements->statements.back().get())) + if (types::equal(*ast::get_type(*node.return_type), types::unit_type{})) + do_return(); + scopes.pop_back(); } @@ -942,34 +1038,63 @@ namespace pslang::jit::aarch64 } }; + // Main compilation visitor struct compile_visitor : ast::const_statement_visitor { - using const_statement_visitor::apply; - program_context & pcontext; local_context & lcontext; - instruction_builder builder{pcontext.code}; - - template - void apply(Statement const &) + instruction_builder & builder; + + using const_statement_visitor::apply; + + void apply(ast::expression_ptr const &) {} + + void apply(ast::assignment const &) {} + + void apply(ast::variable_declaration const &) {} + + void apply(ast::if_block const &) {} + + void apply(ast::else_block const &) {} + + void apply(ast::else_if_block const &) {} + + void apply(ast::if_chain const & node) { - throw std::runtime_error(std::string("compile_visitor is not implemented for ") + typeid(Statement).name()); + for (auto const & block : node.blocks) + apply(*block.statements); + } + + void apply(ast::while_block const & node) + { + apply(*node.statements); } void apply(ast::function_definition const & node) { - pcontext.symbols[node.name] = pcontext.code.size(); - compile_function_visitor visitor{{}, {}, pcontext, lcontext}; - visitor.do_apply(node); - if (node.statements->statements.empty() || !std::get_if(node.statements->statements.back().get())) - if (types::equal(*ast::get_type(*node.return_type), types::unit_type{})) - visitor.do_return(); + compile_function_visitor{{}, {}, pcontext, lcontext}.do_apply(node); + + apply(*node.statements); } - void apply(ast::foreign_function_declaration const &) + void apply(ast::foreign_function_declaration const &) {} + + void apply(ast::return_statement const &) {} + + void apply(ast::field_definition const &) {} + + void apply(ast::struct_definition const &) {} + + void apply(ast::statement_list const & node) { - // Already handled by populate_globals + lcontext.scopes.emplace_back(); + populate_symbols_visitor{{}, lcontext}.apply(node); + for (auto const & statement : node.statements) + apply(*statement); + // Don't pop_back entry point scope + if (lcontext.scopes.size() > 1) + lcontext.scopes.pop_back(); } }; @@ -977,18 +1102,36 @@ namespace pslang::jit::aarch64 void compile(program_context & pcontext, ast::statement_list_ptr const & statements) { + // Add a fake AST node for the entry point + auto root = std::make_shared(ast::function_definition{ + { + {}, + {}, + std::make_shared(types::unit_type{}), + {}, + }, + statements, + {}, + }); + local_context lcontext; - populate_constants_visitor{{}, {}, pcontext, lcontext}.apply(*statements); - compile_visitor{{}, pcontext, lcontext}.apply(*statements); instruction_builder builder{pcontext.code}; + + populate_constants_visitor{{}, {}, pcontext, lcontext}.apply(*statements); + + compile_visitor{{}, pcontext, lcontext, builder}.apply(*root); + for (auto const & resolve : lcontext.resolve) - { - builder.adr_inject(pcontext.code.data() + resolve.instruction_offset, pcontext.symbols.at(resolve.name) - resolve.instruction_offset); - } + builder.adr_inject(pcontext.code.data() + resolve.instruction_offset, lcontext.functions.at(resolve.node) - resolve.instruction_offset); for (auto const & foreign : lcontext.foreign_address) pcontext.foreign_resolve.push_back({foreign.first, foreign.second}); + + for (auto const & function : lcontext.scopes.front().functions) + pcontext.symbols[function.first] = lcontext.functions.at(function.second); + + pcontext.entry_point = lcontext.functions.at(std::get_if(root.get())); } } \ No newline at end of file diff --git a/plans.txt b/plans.txt index 18d9bd1..eb4a6e2 100644 --- a/plans.txt +++ b/plans.txt @@ -1,4 +1,5 @@ Future plans: +* Globals (requires a separate mmaped segment in JIT compiler) * Pointers: pointer types, address-of operator (&), dereferencing, scope-based lifetime tracking in interpreter * Function overloading: separate functions from values (again) in interpreter, allow casting to specific function type to take function value * Const propagation: annotate expression AST nodes that are computable in compile-time @@ -14,11 +15,10 @@ Interpreter backlog: * C FFI (foreign functions) Aarch64 compiler backlog: -* Inner functions -* Struct values +* Structs +* Arrays General backlog: -* Separate single-line location & full statement list location for function definitions, if chains, while blocks, structs, etc * Mutually recursive structs (relevant only with pointers) * Empty array expression * Calling functions as methods