Properly implement inner functions in aarch64 compiler & support module entry point

This commit is contained in:
Nikita Lisitsa 2026-03-15 14:15:36 +03:00
parent 4b555c2ad4
commit b25400ad65
5 changed files with 195 additions and 39 deletions

View file

@ -195,13 +195,10 @@ int main(int argc, char ** argv)
auto executable = jit::make_host_executable(pcontext.code); auto executable = jit::make_host_executable(pcontext.code);
{ // TODO: multiple input files => multiple entry points?
// TODO: remove, testing-only code; should execute entry point instead // Should probably compile them as separate modules.
auto offset = pcontext.symbols.at("test"); auto entry_point = (void(*)())(executable.data.get() + pcontext.entry_point);
auto fptr = (int(*)(int))(executable.data.get() + offset); entry_point();
auto x = fptr(0);
std::cout << "Result: " << std::boolalpha << x << std::endl;
}
} }
else else
{ {

View file

@ -15,7 +15,7 @@ foreign func putchar(c: i32) -> i32
func print(c: u8): func print(c: u8):
putchar(c as i32) putchar(c as i32)
func test(): func test1():
print('H') print('H')
print('e') print('e')
print('l') print('l')
@ -31,6 +31,21 @@ func test():
print('!') print('!')
print('\n') print('\n')
func b() -> i32:
return 300
func test() -> i32:
if false:
func b() -> i32:
return 200
return b()
else:
return b()
print('O')
print('K')
print('\n')
//func test1(): //func test1():
// let str = ['H', 'e', 'l', 'l', 'o', ',', ' ', 'w', 'o', 'r', 'l', 'd', '!', '\n'] // let str = ['H', 'e', 'l', 'l', 'o', ',', ' ', 'w', 'o', 'r', 'l', 'd', '!', '\n']
// mut i = 0 // mut i = 0

View file

@ -23,6 +23,7 @@ namespace pslang::jit
jit::abi abi; jit::abi abi;
std::vector<std::uint8_t> code = {}; std::vector<std::uint8_t> code = {};
std::unordered_map<std::string, std::int32_t> symbols = {}; std::unordered_map<std::string, std::int32_t> symbols = {};
std::int32_t entry_point = 0;
std::vector<foreign_resolve_info> foreign_resolve = {}; std::vector<foreign_resolve_info> foreign_resolve = {};
}; };

View file

@ -8,6 +8,7 @@
#include <stdexcept> #include <stdexcept>
#include <type_traits> #include <type_traits>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include <vector> #include <vector>
#include <sstream> #include <sstream>
@ -23,16 +24,52 @@ namespace pslang::jit::aarch64
std::unordered_map<float, std::int32_t> f32_constants; std::unordered_map<float, std::int32_t> f32_constants;
std::unordered_map<double, std::int32_t> f64_constants; std::unordered_map<double, std::int32_t> f64_constants;
std::unordered_map<std::string, std::int32_t> foreign_address;
std::unordered_map<ast::function_definition const *, std::int32_t> functions;
struct scope
{
std::unordered_set<std::string> foreign_functions;
std::unordered_map<std::string, ast::function_definition const *> functions;
};
std::vector<scope> scopes;
struct resolve_info struct resolve_info
{ {
std::string name; ast::function_definition const * node;
// Must be 'adr' instruction // Must be 'adr' instruction
std::int32_t instruction_offset; std::int32_t instruction_offset;
}; };
std::vector<resolve_info> resolve; std::vector<resolve_info> resolve;
std::unordered_map<std::string, std::int32_t> foreign_address; bool is_foreign(std::string const & name)
{
for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
{
if (it->foreign_functions.contains(name))
return true;
if (it->functions.contains(name))
return false;
}
return false;
}
ast::function_definition const * is_function(std::string const & name)
{
for (auto it = scopes.rbegin(); it != scopes.rend(); ++it)
{
if (it->foreign_functions.contains(name))
return nullptr;
if (auto jt = it->functions.find(name); jt != it->functions.end())
return jt->second;
}
return nullptr;
}
}; };
std::uint8_t fp_mode_for(types::type const & type) std::uint8_t fp_mode_for(types::type const & type)
@ -56,6 +93,8 @@ namespace pslang::jit::aarch64
} }
} }
// Add all f16, f32 and f64 constants as read-only data entries
// Add extern pointers for all foreign functions as read-only data entries
struct populate_constants_visitor struct populate_constants_visitor
: ast::const_expression_visitor<populate_constants_visitor> : ast::const_expression_visitor<populate_constants_visitor>
, ast::const_statement_visitor<populate_constants_visitor> , ast::const_statement_visitor<populate_constants_visitor>
@ -219,6 +258,48 @@ namespace pslang::jit::aarch64
} }
}; };
// Iterate over a single scope (i.e. not visiting subscopes recursively)
// and add all defined functions & foreign functions to the current scope
struct populate_symbols_visitor
: ast::const_statement_visitor<populate_symbols_visitor>
{
local_context & lcontext;
using const_statement_visitor::apply;
void apply(ast::expression_ptr const &) {}
void apply(ast::assignment const &) {}
void apply(ast::variable_declaration const &) {}
void apply(ast::if_block const &) {}
void apply(ast::else_block const &) {}
void apply(ast::else_if_block const &) {}
void apply(ast::if_chain const &) {}
void apply(ast::while_block const &) {}
void apply(ast::function_definition const & node)
{
lcontext.scopes.back().functions[node.name] = &node;
}
void apply(ast::foreign_function_declaration const & node)
{
lcontext.scopes.back().foreign_functions.insert(node.name);
}
void apply(ast::return_statement const &) {}
void apply(ast::field_definition const &) {}
void apply(ast::struct_definition const &) {}
};
struct reg_extend_visitor struct reg_extend_visitor
: types::const_visitor<reg_extend_visitor> : types::const_visitor<reg_extend_visitor>
{ {
@ -262,6 +343,8 @@ namespace pslang::jit::aarch64
} }
}; };
// Compile a single function and store the entry point offset
// in local_context
struct compile_function_visitor struct compile_function_visitor
: ast::const_statement_visitor<compile_function_visitor> : ast::const_statement_visitor<compile_function_visitor>
, ast::const_expression_visitor<compile_function_visitor> , ast::const_expression_visitor<compile_function_visitor>
@ -375,15 +458,18 @@ namespace pslang::jit::aarch64
} }
} }
// Not a variable - must be a function! if (lcontext.is_foreign(node.name))
if (lcontext.foreign_address.contains(node.name))
{ {
builder.ldr_pc(0, (lcontext.foreign_address.at(node.name) - (std::int32_t)pcontext.code.size()) / 4); builder.ldr_pc(0, (lcontext.foreign_address.at(node.name) - (std::int32_t)pcontext.code.size()) / 4);
} }
else if (auto function_node = lcontext.is_function(node.name))
{
lcontext.resolve.push_back({function_node, (std::int32_t)pcontext.code.size()});
builder.adr(0, 0);
}
else else
{ {
lcontext.resolve.push_back({node.name, (std::int32_t)pcontext.code.size()}); throw std::runtime_error("unknown identifier \"" + node.name + "\"");
builder.adr(0, 0);
} }
} }
@ -830,22 +916,27 @@ namespace pslang::jit::aarch64
void apply(ast::function_definition const &) void apply(ast::function_definition const &)
{ {
// Must be handled prior to that // Must be handled prior to that in populate_symbols_visitor
} }
void apply(ast::foreign_function_declaration const &) void apply(ast::foreign_function_declaration const &)
{ {
// Must be handled prior to that // Must be handled prior to that in populate_symbols_visitor
} }
void apply(ast::statement_list const & node) void apply(ast::statement_list const & node)
{ {
lcontext.scopes.emplace_back();
populate_symbols_visitor{{}, lcontext}.apply(node);
for (auto const & statement : node.statements) for (auto const & statement : node.statements)
apply(*statement); apply(*statement);
lcontext.scopes.pop_back();
} }
void do_apply(ast::function_definition const & node) void do_apply(ast::function_definition const & node)
{ {
lcontext.functions[&node] = pcontext.code.size();
// TODO: struct arguments // TODO: struct arguments
scopes.emplace_back(); scopes.emplace_back();
@ -876,7 +967,12 @@ namespace pslang::jit::aarch64
} }
scopes.back().variables[argument.name] = {.frame_offset = stack_offset}; scopes.back().variables[argument.name] = {.frame_offset = stack_offset};
} }
apply(*node.statements); apply(*node.statements);
if (node.statements->statements.empty() || !std::get_if<ast::return_statement>(node.statements->statements.back().get()))
if (types::equal(*ast::get_type(*node.return_type), types::unit_type{}))
do_return();
scopes.pop_back(); scopes.pop_back();
} }
@ -942,34 +1038,63 @@ namespace pslang::jit::aarch64
} }
}; };
// Main compilation visitor
struct compile_visitor struct compile_visitor
: ast::const_statement_visitor<compile_visitor> : ast::const_statement_visitor<compile_visitor>
{ {
using const_statement_visitor::apply;
program_context & pcontext; program_context & pcontext;
local_context & lcontext; local_context & lcontext;
instruction_builder builder{pcontext.code}; instruction_builder & builder;
template <typename Statement> using const_statement_visitor::apply;
void apply(Statement const &)
void apply(ast::expression_ptr const &) {}
void apply(ast::assignment const &) {}
void apply(ast::variable_declaration const &) {}
void apply(ast::if_block const &) {}
void apply(ast::else_block const &) {}
void apply(ast::else_if_block const &) {}
void apply(ast::if_chain const & node)
{ {
throw std::runtime_error(std::string("compile_visitor is not implemented for ") + typeid(Statement).name()); for (auto const & block : node.blocks)
apply(*block.statements);
}
void apply(ast::while_block const & node)
{
apply(*node.statements);
} }
void apply(ast::function_definition const & node) void apply(ast::function_definition const & node)
{ {
pcontext.symbols[node.name] = pcontext.code.size(); compile_function_visitor{{}, {}, pcontext, lcontext}.do_apply(node);
compile_function_visitor visitor{{}, {}, pcontext, lcontext};
visitor.do_apply(node); apply(*node.statements);
if (node.statements->statements.empty() || !std::get_if<ast::return_statement>(node.statements->statements.back().get()))
if (types::equal(*ast::get_type(*node.return_type), types::unit_type{}))
visitor.do_return();
} }
void apply(ast::foreign_function_declaration const &) void apply(ast::foreign_function_declaration const &) {}
void apply(ast::return_statement const &) {}
void apply(ast::field_definition const &) {}
void apply(ast::struct_definition const &) {}
void apply(ast::statement_list const & node)
{ {
// Already handled by populate_globals lcontext.scopes.emplace_back();
populate_symbols_visitor{{}, lcontext}.apply(node);
for (auto const & statement : node.statements)
apply(*statement);
// Don't pop_back entry point scope
if (lcontext.scopes.size() > 1)
lcontext.scopes.pop_back();
} }
}; };
@ -977,18 +1102,36 @@ namespace pslang::jit::aarch64
void compile(program_context & pcontext, ast::statement_list_ptr const & statements) void compile(program_context & pcontext, ast::statement_list_ptr const & statements)
{ {
// Add a fake AST node for the entry point
auto root = std::make_shared<ast::statement>(ast::function_definition{
{
{},
{},
std::make_shared<ast::type>(types::unit_type{}),
{},
},
statements,
{},
});
local_context lcontext; local_context lcontext;
populate_constants_visitor{{}, {}, pcontext, lcontext}.apply(*statements);
compile_visitor{{}, pcontext, lcontext}.apply(*statements);
instruction_builder builder{pcontext.code}; instruction_builder builder{pcontext.code};
populate_constants_visitor{{}, {}, pcontext, lcontext}.apply(*statements);
compile_visitor{{}, pcontext, lcontext, builder}.apply(*root);
for (auto const & resolve : lcontext.resolve) for (auto const & resolve : lcontext.resolve)
{ builder.adr_inject(pcontext.code.data() + resolve.instruction_offset, lcontext.functions.at(resolve.node) - resolve.instruction_offset);
builder.adr_inject(pcontext.code.data() + resolve.instruction_offset, pcontext.symbols.at(resolve.name) - resolve.instruction_offset);
}
for (auto const & foreign : lcontext.foreign_address) for (auto const & foreign : lcontext.foreign_address)
pcontext.foreign_resolve.push_back({foreign.first, foreign.second}); pcontext.foreign_resolve.push_back({foreign.first, foreign.second});
for (auto const & function : lcontext.scopes.front().functions)
pcontext.symbols[function.first] = lcontext.functions.at(function.second);
pcontext.entry_point = lcontext.functions.at(std::get_if<ast::function_definition>(root.get()));
} }
} }

View file

@ -1,4 +1,5 @@
Future plans: Future plans:
* Globals (requires a separate mmaped segment in JIT compiler)
* Pointers: pointer types, address-of operator (&), dereferencing, scope-based lifetime tracking in interpreter * Pointers: pointer types, address-of operator (&), dereferencing, scope-based lifetime tracking in interpreter
* Function overloading: separate functions from values (again) in interpreter, allow casting to specific function type to take function value * Function overloading: separate functions from values (again) in interpreter, allow casting to specific function type to take function value
* Const propagation: annotate expression AST nodes that are computable in compile-time * Const propagation: annotate expression AST nodes that are computable in compile-time
@ -14,11 +15,10 @@ Interpreter backlog:
* C FFI (foreign functions) * C FFI (foreign functions)
Aarch64 compiler backlog: Aarch64 compiler backlog:
* Inner functions * Structs
* Struct values * Arrays
General backlog: General backlog:
* Separate single-line location & full statement list location for function definitions, if chains, while blocks, structs, etc
* Mutually recursive structs (relevant only with pointers) * Mutually recursive structs (relevant only with pointers)
* Empty array expression * Empty array expression
* Calling functions as methods * Calling functions as methods