Refactor jit-compiling interface: introduce common global program context with shared code & symbol tables

This commit is contained in:
Nikita Lisitsa 2026-01-06 19:10:52 +03:00
parent 668851f6bf
commit cb32ec3459
9 changed files with 88 additions and 118 deletions

View file

@ -161,16 +161,20 @@ int main(int argc, char ** argv)
if (jit)
{
auto abi = jit::host_abi();
std::vector<jit::compiled_module> modules;
for (auto const & ast : parsed)
modules.push_back(jit::make_host_executable(jit::compile(ast, abi)));
jit::program_context pcontext
{
.abi = jit::host_abi(),
};
for (auto const & ast : parsed)
jit::compile(pcontext, ast);
auto executable = jit::make_host_executable(pcontext.code);
for (auto const & module : modules)
{
// TODO: remove, testing-only code; should execute entry point instead
auto offset = module.code.symbol_table.at("test");
auto fptr = (float(*)())(module.code.memory.data.get() + offset);
auto offset = pcontext.symbols.at("test");
auto fptr = (float(*)())(executable.data.get() + offset);
auto x = fptr();
std::cout << "Result: " << std::boolalpha << x << std::endl;
}

View file

@ -5,6 +5,6 @@
namespace pslang::jit::aarch64
{
compiled_module compile(ast::statement_list_ptr const & statements);
void compile(program_context & context, ast::statement_list_ptr const & statements);
}

View file

@ -1,29 +0,0 @@
#pragma once
#include <pslang/jit/abi.hpp>
#include <pslang/jit/blob.hpp>
#include <string>
#include <unordered_map>
namespace pslang::jit
{
struct compiled_module
{
struct segment
{
blob memory;
std::unordered_map<std::string, std::size_t> symbol_table;
};
// TODO: How do we reference data segment from the code segment?
// Maybe use the same relocation/dynamic linking mechanism
// as for imported modules?
segment data;
segment code;
std::size_t entry_point; // in code segment
jit::abi abi;
};
}

View file

@ -1,12 +1,13 @@
#pragma once
#include <pslang/jit/compiled_module.hpp>
#include <pslang/jit/blob.hpp>
#include <vector>
#include <cstdint>
namespace pslang::jit
{
blob allocate(std::size_t size);
compiled_module make_host_executable(compiled_module module);
blob make_host_executable(std::vector<std::uint8_t> const & code);
}

View file

@ -1,12 +1,11 @@
#pragma once
#include <pslang/ast/statement_fwd.hpp>
#include <pslang/jit/compiled_module.hpp>
#include <pslang/jit/abi.hpp>
#include <pslang/jit/program_context.hpp>
namespace pslang::jit
{
compiled_module compile(ast::statement_list_ptr const & statements, abi abi);
void compile(program_context & context, ast::statement_list_ptr const & statements);
}

View file

@ -0,0 +1,20 @@
#pragma once
#include <pslang/jit/abi.hpp>
#include <vector>
#include <cstdint>
#include <string>
#include <unordered_map>
namespace pslang::jit
{
struct program_context
{
jit::abi abi;
std::vector<std::uint8_t> code = {};
std::unordered_map<std::string, std::int32_t> symbols = {};
};
}

View file

@ -16,11 +16,8 @@ namespace pslang::jit::aarch64
namespace
{
struct context
struct local_context
{
std::vector<std::uint8_t> code;
std::unordered_map<std::string, std::size_t> code_symbol_table;
std::unordered_map<float, std::int32_t> f16_constants;
std::unordered_map<float, std::int32_t> f32_constants;
std::unordered_map<double, std::int32_t> f64_constants;
@ -42,7 +39,8 @@ namespace pslang::jit::aarch64
using const_expression_visitor::apply;
using const_statement_visitor::apply;
context & context;
program_context & pcontext;
local_context & lcontext;
template <typename T>
requires(!std::is_floating_point_v<T>)
@ -51,27 +49,27 @@ namespace pslang::jit::aarch64
void apply(ast::f16_literal const & node)
{
if (!context.f16_constants.contains(node.value.repr))
if (!lcontext.f16_constants.contains(node.value.repr))
{
context.f16_constants[node.value.repr] = context.code.size();
lcontext.f16_constants[node.value.repr] = pcontext.code.size();
push_bytes(node.value.repr);
}
}
void apply(ast::f32_literal const & node)
{
if (!context.f32_constants.contains(node.value))
if (!lcontext.f32_constants.contains(node.value))
{
context.f32_constants[node.value] = context.code.size();
lcontext.f32_constants[node.value] = pcontext.code.size();
push_bytes(node.value);
}
}
void apply(ast::f64_literal const & node)
{
if (!context.f64_constants.contains(node.value))
if (!lcontext.f64_constants.contains(node.value))
{
context.f64_constants[node.value] = context.code.size();
lcontext.f64_constants[node.value] = pcontext.code.size();
push_bytes(node.value);
}
}
@ -182,7 +180,7 @@ namespace pslang::jit::aarch64
{
auto begin = (std::uint8_t const *)(&value);
auto end = begin + sizeof(value);
context.code.insert(context.code.end(), begin, end);
pcontext.code.insert(pcontext.code.end(), begin, end);
}
};
@ -236,8 +234,9 @@ namespace pslang::jit::aarch64
using const_statement_visitor::apply;
using const_expression_visitor::apply;
context & context;
instruction_builder builder{context.code};
program_context & pcontext;
local_context & lcontext;
instruction_builder builder{pcontext.code};
// Difference between initial stack pointer at function enter
// and current virtual stack pointer value. The actual stack pointer
@ -307,23 +306,23 @@ namespace pslang::jit::aarch64
void apply(ast::f16_literal const & node)
{
auto offset = context.f16_constants.at(node.value.repr);
std::int32_t current = context.code.size();
auto offset = lcontext.f16_constants.at(node.value.repr);
std::int32_t current = pcontext.code.size();
builder.ldr_fp_pc(0, 0, (offset - current) / 4);
builder.fcvt(0, 0b10, 0, 0b01);
}
void apply(ast::f32_literal const & node)
{
auto offset = context.f32_constants.at(node.value);
std::int32_t current = context.code.size();
auto offset = lcontext.f32_constants.at(node.value);
std::int32_t current = pcontext.code.size();
builder.ldr_fp_pc(0, 0, (offset - current) / 4);
}
void apply(ast::f64_literal const & node)
{
auto offset = context.f64_constants.at(node.value);
std::int32_t current = context.code.size();
auto offset = lcontext.f64_constants.at(node.value);
std::int32_t current = pcontext.code.size();
builder.ldr_fp_pc(0, 1, (offset - current) / 4);
}
@ -631,7 +630,7 @@ namespace pslang::jit::aarch64
if (block.condition)
{
apply(*block.condition);
branch_skip = context.code.size();
branch_skip = pcontext.code.size();
builder.cbz(0, 0);
}
@ -642,30 +641,30 @@ namespace pslang::jit::aarch64
if (i + 1 < node.blocks.size())
{
branch_to_end.push_back(context.code.size());
branch_to_end.push_back(pcontext.code.size());
builder.b(0);
}
if (branch_skip)
{
auto branch_offset = context.code.size() - *branch_skip;
builder.cb_inject(context.code.data() + *branch_skip, branch_offset / 4);
auto branch_offset = pcontext.code.size() - *branch_skip;
builder.cb_inject(pcontext.code.data() + *branch_skip, branch_offset / 4);
}
}
auto end = context.code.size();
auto end = pcontext.code.size();
for (auto instruction : branch_to_end)
{
auto delta = end - instruction;
builder.b_inject(context.code.data() + instruction, delta / 4);
builder.b_inject(pcontext.code.data() + instruction, delta / 4);
}
}
void apply(ast::while_block const & node)
{
std::int32_t start = context.code.size();
std::int32_t start = pcontext.code.size();
apply(*node.condition);
std::int32_t skip = context.code.size();
std::int32_t skip = pcontext.code.size();
builder.cbz(0, 0);
scopes.emplace_back();
@ -673,12 +672,12 @@ namespace pslang::jit::aarch64
scope_cleanup();
scopes.pop_back();
std::int32_t loop = context.code.size();
std::int32_t loop = pcontext.code.size();
builder.b(0);
std::int32_t end = context.code.size();
std::int32_t end = pcontext.code.size();
builder.cb_inject(context.code.data() + skip, (end - skip) / 4);
builder.b_inject(context.code.data() + loop, (start - loop) / 4);
builder.cb_inject(pcontext.code.data() + skip, (end - skip) / 4);
builder.b_inject(pcontext.code.data() + loop, (start - loop) / 4);
}
void apply(ast::return_statement const & node)
@ -793,8 +792,9 @@ namespace pslang::jit::aarch64
{
using const_statement_visitor::apply;
context & context;
instruction_builder builder{context.code};
program_context & pcontext;
local_context & lcontext;
instruction_builder builder{pcontext.code};
template <typename Statement>
void apply(Statement const &)
@ -804,34 +804,21 @@ namespace pslang::jit::aarch64
void apply(ast::function_definition const & node)
{
context.code_symbol_table[node.name] = context.code.size();
compile_function_visitor visitor{{}, {}, context};
pcontext.symbols[node.name] = pcontext.code.size();
compile_function_visitor visitor{{}, {}, pcontext, lcontext};
visitor.do_apply(node);
}
};
}
compiled_module compile(ast::statement_list_ptr const & statements)
void compile(program_context & pcontext, ast::statement_list_ptr const & statements)
{
context context;
local_context lcontext;
populate_constants_visitor{{}, {}, context}.apply(*statements);
populate_constants_visitor{{}, {}, pcontext, lcontext}.apply(*statements);
compile_visitor{{}, context}.apply(*statements);
auto code = allocate(context.code.size());
std::copy(context.code.data(), context.code.data() + context.code.size(), code.data.get());
return compiled_module {
.data = {},
.code = {
.memory = std::move(code),
.symbol_table = std::move(context.code_symbol_table),
},
.entry_point = 0,
.abi = abi::armv8,
};
compile_visitor{{}, pcontext, lcontext}.apply(*statements);
}
}

View file

@ -15,28 +15,15 @@
namespace pslang::jit
{
blob allocate(std::size_t size)
blob make_host_executable(std::vector<std::uint8_t> const & code)
{
#if defined(__linux__) || defined(__APPLE__)
auto size = code.size();
auto ptr = (std::uint8_t *)mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, 0, 0);
auto shared_ptr = std::shared_ptr<std::uint8_t>(ptr, [size](void * ptr){ munmap(ptr, size); });
return blob(shared_ptr, size);
#else
throw std::runtime_error("Allocate not supported for this platform");
#endif
}
compiled_module make_host_executable(compiled_module module)
{
#if defined(__linux__) || defined(__APPLE__)
if (module.abi != abi::itanium && module.abi != abi::armv8)
throw std::runtime_error("Abi mismatch");
// Assume the module code memory was obtained via mmap
if (mprotect(module.code.memory.data.get(), module.code.memory.size, PROT_READ | PROT_EXEC) != 0)
std::copy(code.begin(), code.end(), ptr);
if (mprotect(ptr, size, PROT_READ | PROT_EXEC) != 0)
throw std::system_error(errno, std::generic_category());
return module;
return blob(std::shared_ptr<std::uint8_t>(ptr, [size](void * ptr){ munmap(ptr, size); }), size);
#else
throw std::runtime_error("Host-executable modules are not supported for this platform");
#endif

View file

@ -6,16 +6,17 @@
namespace pslang::jit
{
compiled_module compile(ast::statement_list_ptr const & statements, jit::abi abi)
void compile(program_context & context, ast::statement_list_ptr const & statements)
{
switch (abi)
switch (context.abi)
{
case abi::itanium:
throw std::runtime_error("Itanium ABI JIT not implemented");
case abi::msvc:
throw std::runtime_error("MSVC ABI JIT not implemented");
case abi::armv8:
return aarch64::compile(statements);
aarch64::compile(context, statements);
break;
}
}