Properly implement inner functions in aarch64 compiler & support module entry point

This commit is contained in:
Nikita Lisitsa 2026-03-15 14:15:36 +03:00
parent 4b555c2ad4
commit b25400ad65
5 changed files with 195 additions and 39 deletions

View file

@ -195,13 +195,10 @@ int main(int argc, char ** argv)
auto executable = jit::make_host_executable(pcontext.code);
{
// TODO: remove, testing-only code; should execute entry point instead
auto offset = pcontext.symbols.at("test");
auto fptr = (int(*)(int))(executable.data.get() + offset);
auto x = fptr(0);
std::cout << "Result: " << std::boolalpha << x << std::endl;
}
// TODO: multiple input files => multiple entry points?
// Should probably compile them as separate modules.
auto entry_point = (void(*)())(executable.data.get() + pcontext.entry_point);
entry_point();
}
else
{

View file

@ -15,7 +15,7 @@ foreign func putchar(c: i32) -> i32
func print(c: u8):
putchar(c as i32)
func test():
func test1():
print('H')
print('e')
print('l')
@ -31,6 +31,21 @@ func test():
print('!')
print('\n')
func b() -> i32:
return 300
func test() -> i32:
if false:
func b() -> i32:
return 200
return b()
else:
return b()
print('O')
print('K')
print('\n')
//func test1():
// let str = ['H', 'e', 'l', 'l', 'o', ',', ' ', 'w', 'o', 'r', 'l', 'd', '!', '\n']
// mut i = 0

View file

@ -23,6 +23,7 @@ namespace pslang::jit
jit::abi abi;
std::vector<std::uint8_t> code = {};
std::unordered_map<std::string, std::int32_t> symbols = {};
std::int32_t entry_point = 0;
std::vector<foreign_resolve_info> foreign_resolve = {};
};

View file

@ -8,6 +8,7 @@
#include <stdexcept>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <sstream>
@ -23,16 +24,52 @@ namespace pslang::jit::aarch64
std::unordered_map<float, std::int32_t> f32_constants;
std::unordered_map<double, std::int32_t> f64_constants;
std::unordered_map<std::string, std::int32_t> foreign_address;
std::unordered_map<ast::function_definition const *, std::int32_t> functions;
// Function symbols visible in one lexical scope (one statement list).
struct scope
{
// Names of foreign functions declared directly in this scope
std::unordered_set<std::string> foreign_functions;
// Functions defined directly in this scope, keyed by name and
// pointing at their (non-owned) AST definition nodes
std::unordered_map<std::string, ast::function_definition const *> functions;
};
std::vector<scope> scopes;
struct resolve_info
{
std::string name;
ast::function_definition const * node;
// Must be 'adr' instruction
std::int32_t instruction_offset;
};
std::vector<resolve_info> resolve;
std::unordered_map<std::string, std::int32_t> foreign_address;
// Tells whether `name` currently resolves to a foreign function.
// Scopes are inspected from the innermost to the outermost one, and the
// first scope that declares `name` (as either kind of function) decides
// the answer, so an inner regular function shadows an outer foreign one.
// Returns false when no scope declares `name` at all.
bool is_foreign(std::string const & name)
{
    for (std::size_t depth = scopes.size(); depth-- > 0;)
    {
        auto const & current = scopes[depth];

        if (current.foreign_functions.count(name) != 0)
            return true;

        // A regular function with this name hides any outer foreign one
        if (current.functions.count(name) != 0)
            return false;
    }

    return false;
}
// Looks up the regular (non-foreign) function that `name` currently
// resolves to. Scopes are inspected from the innermost to the outermost
// one; the first scope declaring `name` decides the result.
// Returns the function's AST definition node, or nullptr when `name` is
// shadowed by a foreign function or is not declared in any scope.
ast::function_definition const * is_function(std::string const & name)
{
    for (std::size_t depth = scopes.size(); depth-- > 0;)
    {
        auto const & current = scopes[depth];

        // A foreign declaration with this name hides any outer function
        if (current.foreign_functions.count(name) != 0)
            return nullptr;

        auto const found = current.functions.find(name);
        if (found != current.functions.end())
            return found->second;
    }

    return nullptr;
}
};
std::uint8_t fp_mode_for(types::type const & type)
@ -56,6 +93,8 @@ namespace pslang::jit::aarch64
}
}
// Add all f16, f32 and f64 constants as read-only data entries
// Add extern pointers for all foreign functions as read-only data entries
struct populate_constants_visitor
: ast::const_expression_visitor<populate_constants_visitor>
, ast::const_statement_visitor<populate_constants_visitor>
@ -219,6 +258,48 @@ namespace pslang::jit::aarch64
}
};
// Collects the function symbols declared directly in a single scope.
// Only function definitions and foreign function declarations are
// recorded; every other statement kind is a no-op here, so subscopes
// (if/else/while bodies, etc.) are not visited recursively.
// Symbols are written into the innermost scope of lcontext, which the
// caller is expected to have pushed beforehand.
struct populate_symbols_visitor
    : ast::const_statement_visitor<populate_symbols_visitor>
{
    local_context & lcontext;

    using const_statement_visitor::apply;

    // --- Statements that introduce a function symbol ---

    void apply(ast::function_definition const & node)
    {
        auto & current = lcontext.scopes.back();
        current.functions[node.name] = &node;
    }

    void apply(ast::foreign_function_declaration const & node)
    {
        auto & current = lcontext.scopes.back();
        current.foreign_functions.insert(node.name);
    }

    // --- Statements that introduce nothing at this level ---

    void apply(ast::expression_ptr const &) {}
    void apply(ast::assignment const &) {}
    void apply(ast::variable_declaration const &) {}
    void apply(ast::return_statement const &) {}

    void apply(ast::if_block const &) {}
    void apply(ast::else_block const &) {}
    void apply(ast::else_if_block const &) {}
    void apply(ast::if_chain const &) {}
    void apply(ast::while_block const &) {}

    void apply(ast::field_definition const &) {}
    void apply(ast::struct_definition const &) {}
};
struct reg_extend_visitor
: types::const_visitor<reg_extend_visitor>
{
@ -262,6 +343,8 @@ namespace pslang::jit::aarch64
}
};
// Compile a single function and store the entry point offset
// in local_context
struct compile_function_visitor
: ast::const_statement_visitor<compile_function_visitor>
, ast::const_expression_visitor<compile_function_visitor>
@ -375,15 +458,18 @@ namespace pslang::jit::aarch64
}
}
// Not a variable - must be a function!
if (lcontext.foreign_address.contains(node.name))
if (lcontext.is_foreign(node.name))
{
builder.ldr_pc(0, (lcontext.foreign_address.at(node.name) - (std::int32_t)pcontext.code.size()) / 4);
}
else if (auto function_node = lcontext.is_function(node.name))
{
lcontext.resolve.push_back({function_node, (std::int32_t)pcontext.code.size()});
builder.adr(0, 0);
}
else
{
lcontext.resolve.push_back({node.name, (std::int32_t)pcontext.code.size()});
builder.adr(0, 0);
throw std::runtime_error("unknown identifier \"" + node.name + "\"");
}
}
@ -830,22 +916,27 @@ namespace pslang::jit::aarch64
void apply(ast::function_definition const &)
{
// Must be handled prior to that
// Must be handled prior to that in populate_symbols_visitor
}
void apply(ast::foreign_function_declaration const &)
{
// Must be handled prior to that
// Must be handled prior to that in populate_symbols_visitor
}
void apply(ast::statement_list const & node)
{
lcontext.scopes.emplace_back();
populate_symbols_visitor{{}, lcontext}.apply(node);
for (auto const & statement : node.statements)
apply(*statement);
lcontext.scopes.pop_back();
}
void do_apply(ast::function_definition const & node)
{
lcontext.functions[&node] = pcontext.code.size();
// TODO: struct arguments
scopes.emplace_back();
@ -876,7 +967,12 @@ namespace pslang::jit::aarch64
}
scopes.back().variables[argument.name] = {.frame_offset = stack_offset};
}
apply(*node.statements);
if (node.statements->statements.empty() || !std::get_if<ast::return_statement>(node.statements->statements.back().get()))
if (types::equal(*ast::get_type(*node.return_type), types::unit_type{}))
do_return();
scopes.pop_back();
}
@ -942,34 +1038,63 @@ namespace pslang::jit::aarch64
}
};
// Main compilation visitor
struct compile_visitor
: ast::const_statement_visitor<compile_visitor>
{
using const_statement_visitor::apply;
program_context & pcontext;
local_context & lcontext;
instruction_builder builder{pcontext.code};
template <typename Statement>
void apply(Statement const &)
instruction_builder & builder;
using const_statement_visitor::apply;
void apply(ast::expression_ptr const &) {}
void apply(ast::assignment const &) {}
void apply(ast::variable_declaration const &) {}
void apply(ast::if_block const &) {}
void apply(ast::else_block const &) {}
void apply(ast::else_if_block const &) {}
void apply(ast::if_chain const & node)
{
throw std::runtime_error(std::string("compile_visitor is not implemented for ") + typeid(Statement).name());
for (auto const & block : node.blocks)
apply(*block.statements);
}
void apply(ast::while_block const & node)
{
apply(*node.statements);
}
void apply(ast::function_definition const & node)
{
pcontext.symbols[node.name] = pcontext.code.size();
compile_function_visitor visitor{{}, {}, pcontext, lcontext};
visitor.do_apply(node);
if (node.statements->statements.empty() || !std::get_if<ast::return_statement>(node.statements->statements.back().get()))
if (types::equal(*ast::get_type(*node.return_type), types::unit_type{}))
visitor.do_return();
compile_function_visitor{{}, {}, pcontext, lcontext}.do_apply(node);
apply(*node.statements);
}
void apply(ast::foreign_function_declaration const &)
void apply(ast::foreign_function_declaration const &) {}
void apply(ast::return_statement const &) {}
void apply(ast::field_definition const &) {}
void apply(ast::struct_definition const &) {}
void apply(ast::statement_list const & node)
{
// Already handled by populate_globals
lcontext.scopes.emplace_back();
populate_symbols_visitor{{}, lcontext}.apply(node);
for (auto const & statement : node.statements)
apply(*statement);
// Don't pop_back entry point scope
if (lcontext.scopes.size() > 1)
lcontext.scopes.pop_back();
}
};
@ -977,18 +1102,36 @@ namespace pslang::jit::aarch64
void compile(program_context & pcontext, ast::statement_list_ptr const & statements)
{
// Add a fake AST node for the entry point
auto root = std::make_shared<ast::statement>(ast::function_definition{
{
{},
{},
std::make_shared<ast::type>(types::unit_type{}),
{},
},
statements,
{},
});
local_context lcontext;
populate_constants_visitor{{}, {}, pcontext, lcontext}.apply(*statements);
compile_visitor{{}, pcontext, lcontext}.apply(*statements);
instruction_builder builder{pcontext.code};
populate_constants_visitor{{}, {}, pcontext, lcontext}.apply(*statements);
compile_visitor{{}, pcontext, lcontext, builder}.apply(*root);
for (auto const & resolve : lcontext.resolve)
{
builder.adr_inject(pcontext.code.data() + resolve.instruction_offset, pcontext.symbols.at(resolve.name) - resolve.instruction_offset);
}
builder.adr_inject(pcontext.code.data() + resolve.instruction_offset, lcontext.functions.at(resolve.node) - resolve.instruction_offset);
for (auto const & foreign : lcontext.foreign_address)
pcontext.foreign_resolve.push_back({foreign.first, foreign.second});
for (auto const & function : lcontext.scopes.front().functions)
pcontext.symbols[function.first] = lcontext.functions.at(function.second);
pcontext.entry_point = lcontext.functions.at(std::get_if<ast::function_definition>(root.get()));
}
}

View file

@ -1,4 +1,5 @@
Future plans:
* Globals (requires a separate mmaped segment in JIT compiler)
* Pointers: pointer types, address-of operator (&), dereferencing, scope-based lifetime tracking in interpreter
* Function overloading: separate functions from values (again) in interpreter, allow casting to specific function type to take function value
* Const propagation: annotate expression AST nodes that are computable at compile time
@ -14,11 +15,10 @@ Interpreter backlog:
* C FFI (foreign functions)
Aarch64 compiler backlog:
* Inner functions
* Struct values
* Structs
* Arrays
General backlog:
* Separate single-line location & full statement list location for function definitions, if chains, while blocks, structs, etc
* Mutually recursive structs (relevant only with pointers)
* Empty array expression
* Calling functions as methods