Refactor jit-compiling interface: introduce common global program context with shared code & symbol tables
This commit is contained in:
parent
668851f6bf
commit
cb32ec3459
9 changed files with 88 additions and 118 deletions
|
|
@ -161,16 +161,20 @@ int main(int argc, char ** argv)
|
||||||
|
|
||||||
if (jit)
|
if (jit)
|
||||||
{
|
{
|
||||||
auto abi = jit::host_abi();
|
jit::program_context pcontext
|
||||||
std::vector<jit::compiled_module> modules;
|
{
|
||||||
for (auto const & ast : parsed)
|
.abi = jit::host_abi(),
|
||||||
modules.push_back(jit::make_host_executable(jit::compile(ast, abi)));
|
};
|
||||||
|
|
||||||
|
for (auto const & ast : parsed)
|
||||||
|
jit::compile(pcontext, ast);
|
||||||
|
|
||||||
|
auto executable = jit::make_host_executable(pcontext.code);
|
||||||
|
|
||||||
for (auto const & module : modules)
|
|
||||||
{
|
{
|
||||||
// TODO: remove, testing-only code; should execute entry point instead
|
// TODO: remove, testing-only code; should execute entry point instead
|
||||||
auto offset = module.code.symbol_table.at("test");
|
auto offset = pcontext.symbols.at("test");
|
||||||
auto fptr = (float(*)())(module.code.memory.data.get() + offset);
|
auto fptr = (float(*)())(executable.data.get() + offset);
|
||||||
auto x = fptr();
|
auto x = fptr();
|
||||||
std::cout << "Result: " << std::boolalpha << x << std::endl;
|
std::cout << "Result: " << std::boolalpha << x << std::endl;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,6 @@
|
||||||
namespace pslang::jit::aarch64
|
namespace pslang::jit::aarch64
|
||||||
{
|
{
|
||||||
|
|
||||||
compiled_module compile(ast::statement_list_ptr const & statements);
|
void compile(program_context & context, ast::statement_list_ptr const & statements);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
@ -1,29 +0,0 @@
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <pslang/jit/abi.hpp>
|
|
||||||
#include <pslang/jit/blob.hpp>
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
#include <unordered_map>
|
|
||||||
|
|
||||||
namespace pslang::jit
|
|
||||||
{
|
|
||||||
|
|
||||||
struct compiled_module
|
|
||||||
{
|
|
||||||
struct segment
|
|
||||||
{
|
|
||||||
blob memory;
|
|
||||||
std::unordered_map<std::string, std::size_t> symbol_table;
|
|
||||||
};
|
|
||||||
|
|
||||||
// TODO: How do we reference data segment from the code segment?
|
|
||||||
// Maybe use the same relocation/dynamic linking mechanism
|
|
||||||
// as for imported modules?
|
|
||||||
segment data;
|
|
||||||
segment code;
|
|
||||||
std::size_t entry_point; // in code segment
|
|
||||||
jit::abi abi;
|
|
||||||
};
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
@ -1,12 +1,13 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <pslang/jit/compiled_module.hpp>
|
#include <pslang/jit/blob.hpp>
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
namespace pslang::jit
|
namespace pslang::jit
|
||||||
{
|
{
|
||||||
|
|
||||||
blob allocate(std::size_t size);
|
blob make_host_executable(std::vector<std::uint8_t> const & code);
|
||||||
|
|
||||||
compiled_module make_host_executable(compiled_module module);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,11 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <pslang/ast/statement_fwd.hpp>
|
#include <pslang/ast/statement_fwd.hpp>
|
||||||
#include <pslang/jit/compiled_module.hpp>
|
#include <pslang/jit/program_context.hpp>
|
||||||
#include <pslang/jit/abi.hpp>
|
|
||||||
|
|
||||||
namespace pslang::jit
|
namespace pslang::jit
|
||||||
{
|
{
|
||||||
|
|
||||||
compiled_module compile(ast::statement_list_ptr const & statements, abi abi);
|
void compile(program_context & context, ast::statement_list_ptr const & statements);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
20
libs/jit/include/pslang/jit/program_context.hpp
Normal file
20
libs/jit/include/pslang/jit/program_context.hpp
Normal file
|
|
@ -0,0 +1,20 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <pslang/jit/abi.hpp>
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
namespace pslang::jit
|
||||||
|
{
|
||||||
|
|
||||||
|
struct program_context
|
||||||
|
{
|
||||||
|
jit::abi abi;
|
||||||
|
std::vector<std::uint8_t> code = {};
|
||||||
|
std::unordered_map<std::string, std::int32_t> symbols = {};
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
|
@ -16,11 +16,8 @@ namespace pslang::jit::aarch64
|
||||||
namespace
|
namespace
|
||||||
{
|
{
|
||||||
|
|
||||||
struct context
|
struct local_context
|
||||||
{
|
{
|
||||||
std::vector<std::uint8_t> code;
|
|
||||||
std::unordered_map<std::string, std::size_t> code_symbol_table;
|
|
||||||
|
|
||||||
std::unordered_map<float, std::int32_t> f16_constants;
|
std::unordered_map<float, std::int32_t> f16_constants;
|
||||||
std::unordered_map<float, std::int32_t> f32_constants;
|
std::unordered_map<float, std::int32_t> f32_constants;
|
||||||
std::unordered_map<double, std::int32_t> f64_constants;
|
std::unordered_map<double, std::int32_t> f64_constants;
|
||||||
|
|
@ -42,7 +39,8 @@ namespace pslang::jit::aarch64
|
||||||
using const_expression_visitor::apply;
|
using const_expression_visitor::apply;
|
||||||
using const_statement_visitor::apply;
|
using const_statement_visitor::apply;
|
||||||
|
|
||||||
context & context;
|
program_context & pcontext;
|
||||||
|
local_context & lcontext;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
requires(!std::is_floating_point_v<T>)
|
requires(!std::is_floating_point_v<T>)
|
||||||
|
|
@ -51,27 +49,27 @@ namespace pslang::jit::aarch64
|
||||||
|
|
||||||
void apply(ast::f16_literal const & node)
|
void apply(ast::f16_literal const & node)
|
||||||
{
|
{
|
||||||
if (!context.f16_constants.contains(node.value.repr))
|
if (!lcontext.f16_constants.contains(node.value.repr))
|
||||||
{
|
{
|
||||||
context.f16_constants[node.value.repr] = context.code.size();
|
lcontext.f16_constants[node.value.repr] = pcontext.code.size();
|
||||||
push_bytes(node.value.repr);
|
push_bytes(node.value.repr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void apply(ast::f32_literal const & node)
|
void apply(ast::f32_literal const & node)
|
||||||
{
|
{
|
||||||
if (!context.f32_constants.contains(node.value))
|
if (!lcontext.f32_constants.contains(node.value))
|
||||||
{
|
{
|
||||||
context.f32_constants[node.value] = context.code.size();
|
lcontext.f32_constants[node.value] = pcontext.code.size();
|
||||||
push_bytes(node.value);
|
push_bytes(node.value);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void apply(ast::f64_literal const & node)
|
void apply(ast::f64_literal const & node)
|
||||||
{
|
{
|
||||||
if (!context.f64_constants.contains(node.value))
|
if (!lcontext.f64_constants.contains(node.value))
|
||||||
{
|
{
|
||||||
context.f64_constants[node.value] = context.code.size();
|
lcontext.f64_constants[node.value] = pcontext.code.size();
|
||||||
push_bytes(node.value);
|
push_bytes(node.value);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -182,7 +180,7 @@ namespace pslang::jit::aarch64
|
||||||
{
|
{
|
||||||
auto begin = (std::uint8_t const *)(&value);
|
auto begin = (std::uint8_t const *)(&value);
|
||||||
auto end = begin + sizeof(value);
|
auto end = begin + sizeof(value);
|
||||||
context.code.insert(context.code.end(), begin, end);
|
pcontext.code.insert(pcontext.code.end(), begin, end);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -236,8 +234,9 @@ namespace pslang::jit::aarch64
|
||||||
using const_statement_visitor::apply;
|
using const_statement_visitor::apply;
|
||||||
using const_expression_visitor::apply;
|
using const_expression_visitor::apply;
|
||||||
|
|
||||||
context & context;
|
program_context & pcontext;
|
||||||
instruction_builder builder{context.code};
|
local_context & lcontext;
|
||||||
|
instruction_builder builder{pcontext.code};
|
||||||
|
|
||||||
// Difference between initial stack pointer at function enter
|
// Difference between initial stack pointer at function enter
|
||||||
// and current virtual stack pointer value. The actual stack pointer
|
// and current virtual stack pointer value. The actual stack pointer
|
||||||
|
|
@ -307,23 +306,23 @@ namespace pslang::jit::aarch64
|
||||||
|
|
||||||
void apply(ast::f16_literal const & node)
|
void apply(ast::f16_literal const & node)
|
||||||
{
|
{
|
||||||
auto offset = context.f16_constants.at(node.value.repr);
|
auto offset = lcontext.f16_constants.at(node.value.repr);
|
||||||
std::int32_t current = context.code.size();
|
std::int32_t current = pcontext.code.size();
|
||||||
builder.ldr_fp_pc(0, 0, (offset - current) / 4);
|
builder.ldr_fp_pc(0, 0, (offset - current) / 4);
|
||||||
builder.fcvt(0, 0b10, 0, 0b01);
|
builder.fcvt(0, 0b10, 0, 0b01);
|
||||||
}
|
}
|
||||||
|
|
||||||
void apply(ast::f32_literal const & node)
|
void apply(ast::f32_literal const & node)
|
||||||
{
|
{
|
||||||
auto offset = context.f32_constants.at(node.value);
|
auto offset = lcontext.f32_constants.at(node.value);
|
||||||
std::int32_t current = context.code.size();
|
std::int32_t current = pcontext.code.size();
|
||||||
builder.ldr_fp_pc(0, 0, (offset - current) / 4);
|
builder.ldr_fp_pc(0, 0, (offset - current) / 4);
|
||||||
}
|
}
|
||||||
|
|
||||||
void apply(ast::f64_literal const & node)
|
void apply(ast::f64_literal const & node)
|
||||||
{
|
{
|
||||||
auto offset = context.f64_constants.at(node.value);
|
auto offset = lcontext.f64_constants.at(node.value);
|
||||||
std::int32_t current = context.code.size();
|
std::int32_t current = pcontext.code.size();
|
||||||
builder.ldr_fp_pc(0, 1, (offset - current) / 4);
|
builder.ldr_fp_pc(0, 1, (offset - current) / 4);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -631,7 +630,7 @@ namespace pslang::jit::aarch64
|
||||||
if (block.condition)
|
if (block.condition)
|
||||||
{
|
{
|
||||||
apply(*block.condition);
|
apply(*block.condition);
|
||||||
branch_skip = context.code.size();
|
branch_skip = pcontext.code.size();
|
||||||
builder.cbz(0, 0);
|
builder.cbz(0, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -642,30 +641,30 @@ namespace pslang::jit::aarch64
|
||||||
|
|
||||||
if (i + 1 < node.blocks.size())
|
if (i + 1 < node.blocks.size())
|
||||||
{
|
{
|
||||||
branch_to_end.push_back(context.code.size());
|
branch_to_end.push_back(pcontext.code.size());
|
||||||
builder.b(0);
|
builder.b(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (branch_skip)
|
if (branch_skip)
|
||||||
{
|
{
|
||||||
auto branch_offset = context.code.size() - *branch_skip;
|
auto branch_offset = pcontext.code.size() - *branch_skip;
|
||||||
builder.cb_inject(context.code.data() + *branch_skip, branch_offset / 4);
|
builder.cb_inject(pcontext.code.data() + *branch_skip, branch_offset / 4);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto end = context.code.size();
|
auto end = pcontext.code.size();
|
||||||
for (auto instruction : branch_to_end)
|
for (auto instruction : branch_to_end)
|
||||||
{
|
{
|
||||||
auto delta = end - instruction;
|
auto delta = end - instruction;
|
||||||
builder.b_inject(context.code.data() + instruction, delta / 4);
|
builder.b_inject(pcontext.code.data() + instruction, delta / 4);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void apply(ast::while_block const & node)
|
void apply(ast::while_block const & node)
|
||||||
{
|
{
|
||||||
std::int32_t start = context.code.size();
|
std::int32_t start = pcontext.code.size();
|
||||||
apply(*node.condition);
|
apply(*node.condition);
|
||||||
std::int32_t skip = context.code.size();
|
std::int32_t skip = pcontext.code.size();
|
||||||
builder.cbz(0, 0);
|
builder.cbz(0, 0);
|
||||||
|
|
||||||
scopes.emplace_back();
|
scopes.emplace_back();
|
||||||
|
|
@ -673,12 +672,12 @@ namespace pslang::jit::aarch64
|
||||||
scope_cleanup();
|
scope_cleanup();
|
||||||
scopes.pop_back();
|
scopes.pop_back();
|
||||||
|
|
||||||
std::int32_t loop = context.code.size();
|
std::int32_t loop = pcontext.code.size();
|
||||||
builder.b(0);
|
builder.b(0);
|
||||||
std::int32_t end = context.code.size();
|
std::int32_t end = pcontext.code.size();
|
||||||
|
|
||||||
builder.cb_inject(context.code.data() + skip, (end - skip) / 4);
|
builder.cb_inject(pcontext.code.data() + skip, (end - skip) / 4);
|
||||||
builder.b_inject(context.code.data() + loop, (start - loop) / 4);
|
builder.b_inject(pcontext.code.data() + loop, (start - loop) / 4);
|
||||||
}
|
}
|
||||||
|
|
||||||
void apply(ast::return_statement const & node)
|
void apply(ast::return_statement const & node)
|
||||||
|
|
@ -793,8 +792,9 @@ namespace pslang::jit::aarch64
|
||||||
{
|
{
|
||||||
using const_statement_visitor::apply;
|
using const_statement_visitor::apply;
|
||||||
|
|
||||||
context & context;
|
program_context & pcontext;
|
||||||
instruction_builder builder{context.code};
|
local_context & lcontext;
|
||||||
|
instruction_builder builder{pcontext.code};
|
||||||
|
|
||||||
template <typename Statement>
|
template <typename Statement>
|
||||||
void apply(Statement const &)
|
void apply(Statement const &)
|
||||||
|
|
@ -804,34 +804,21 @@ namespace pslang::jit::aarch64
|
||||||
|
|
||||||
void apply(ast::function_definition const & node)
|
void apply(ast::function_definition const & node)
|
||||||
{
|
{
|
||||||
context.code_symbol_table[node.name] = context.code.size();
|
pcontext.symbols[node.name] = pcontext.code.size();
|
||||||
compile_function_visitor visitor{{}, {}, context};
|
compile_function_visitor visitor{{}, {}, pcontext, lcontext};
|
||||||
visitor.do_apply(node);
|
visitor.do_apply(node);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
compiled_module compile(ast::statement_list_ptr const & statements)
|
void compile(program_context & pcontext, ast::statement_list_ptr const & statements)
|
||||||
{
|
{
|
||||||
context context;
|
local_context lcontext;
|
||||||
|
|
||||||
populate_constants_visitor{{}, {}, context}.apply(*statements);
|
populate_constants_visitor{{}, {}, pcontext, lcontext}.apply(*statements);
|
||||||
|
|
||||||
compile_visitor{{}, context}.apply(*statements);
|
compile_visitor{{}, pcontext, lcontext}.apply(*statements);
|
||||||
|
|
||||||
auto code = allocate(context.code.size());
|
|
||||||
std::copy(context.code.data(), context.code.data() + context.code.size(), code.data.get());
|
|
||||||
|
|
||||||
return compiled_module {
|
|
||||||
.data = {},
|
|
||||||
.code = {
|
|
||||||
.memory = std::move(code),
|
|
||||||
.symbol_table = std::move(context.code_symbol_table),
|
|
||||||
},
|
|
||||||
.entry_point = 0,
|
|
||||||
.abi = abi::armv8,
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
@ -15,28 +15,15 @@
|
||||||
namespace pslang::jit
|
namespace pslang::jit
|
||||||
{
|
{
|
||||||
|
|
||||||
blob allocate(std::size_t size)
|
blob make_host_executable(std::vector<std::uint8_t> const & code)
|
||||||
{
|
{
|
||||||
#if defined(__linux__) || defined(__APPLE__)
|
#if defined(__linux__) || defined(__APPLE__)
|
||||||
|
auto size = code.size();
|
||||||
auto ptr = (std::uint8_t *)mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, 0, 0);
|
auto ptr = (std::uint8_t *)mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, 0, 0);
|
||||||
auto shared_ptr = std::shared_ptr<std::uint8_t>(ptr, [size](void * ptr){ munmap(ptr, size); });
|
std::copy(code.begin(), code.end(), ptr);
|
||||||
return blob(shared_ptr, size);
|
if (mprotect(ptr, size, PROT_READ | PROT_EXEC) != 0)
|
||||||
#else
|
|
||||||
throw std::runtime_error("Allocate not supported for this platform");
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
compiled_module make_host_executable(compiled_module module)
|
|
||||||
{
|
|
||||||
#if defined(__linux__) || defined(__APPLE__)
|
|
||||||
if (module.abi != abi::itanium && module.abi != abi::armv8)
|
|
||||||
throw std::runtime_error("Abi mismatch");
|
|
||||||
|
|
||||||
// Assume the module code memory was obtained via mmap
|
|
||||||
if (mprotect(module.code.memory.data.get(), module.code.memory.size, PROT_READ | PROT_EXEC) != 0)
|
|
||||||
throw std::system_error(errno, std::generic_category());
|
throw std::system_error(errno, std::generic_category());
|
||||||
|
return blob(std::shared_ptr<std::uint8_t>(ptr, [size](void * ptr){ munmap(ptr, size); }), size);
|
||||||
return module;
|
|
||||||
#else
|
#else
|
||||||
throw std::runtime_error("Host-executable modules are not supported for this platform");
|
throw std::runtime_error("Host-executable modules are not supported for this platform");
|
||||||
#endif
|
#endif
|
||||||
|
|
|
||||||
|
|
@ -6,16 +6,17 @@
|
||||||
namespace pslang::jit
|
namespace pslang::jit
|
||||||
{
|
{
|
||||||
|
|
||||||
compiled_module compile(ast::statement_list_ptr const & statements, jit::abi abi)
|
void compile(program_context & context, ast::statement_list_ptr const & statements)
|
||||||
{
|
{
|
||||||
switch (abi)
|
switch (context.abi)
|
||||||
{
|
{
|
||||||
case abi::itanium:
|
case abi::itanium:
|
||||||
throw std::runtime_error("Itanium ABI JIT not implemented");
|
throw std::runtime_error("Itanium ABI JIT not implemented");
|
||||||
case abi::msvc:
|
case abi::msvc:
|
||||||
throw std::runtime_error("MSVC ABI JIT not implemented");
|
throw std::runtime_error("MSVC ABI JIT not implemented");
|
||||||
case abi::armv8:
|
case abi::armv8:
|
||||||
return aarch64::compile(statements);
|
aarch64::compile(context, statements);
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue