Aarch64 jit compiler wip: implement branching and loops

This commit is contained in:
Nikita Lisitsa 2026-01-04 13:45:03 +03:00
parent db4a8ac264
commit f28138e5b9
5 changed files with 132 additions and 21 deletions

View file

@ -171,7 +171,7 @@ int main(int argc, char ** argv)
// TODO: remove, testing-only code; should execute entry point instead
auto offset = module.code.symbol_table.at("test");
auto fptr = (int(*)(int))(module.code.memory.data.get() + offset);
auto x = fptr(42);
auto x = fptr(10);
std::cout << "Result: " << std::boolalpha << x << std::endl;
}
}

View file

@ -1,2 +1,6 @@
func test(x : i32) -> bool:
return x <= 42
func factorial(x : i32) -> i32:
mut r = 1
while x > 0:
r = r * x
x = x - 1
return r

View file

@ -94,15 +94,32 @@ namespace pslang::jit::aarch64
// Take @bit_count lowest bits from @reg_src, copy them to @reg_dst, and zero-extend @reg_dst
void ubfm(std::uint8_t reg_src, std::uint8_t reg_dst, std::uint8_t bit_count);
// Perform a bitwise and of @reg_src1 and @reg_src2 and set the flags
void tst(std::uint8_t reg_src1, std::uint8_t reg_src2);
// If the value of @reg_src is non-zero, move the program counter to the value of
// 19-bit signed @offset multiplied by 4
void cbnz(std::uint8_t reg_src, std::int32_t offset);
// If the value of @reg_src is zero, move the program counter to the value of
// 19-bit signed @offset multiplied by 4
void cbz(std::uint8_t reg_src, std::int32_t offset);
// Inject the 19-bit signed @offset into the opcode of a cbz or cbnz instruction
// starting at @opcode
void cb_inject(std::uint8_t * opcode, std::int32_t offset);
// Unconditionally move the program counter to the value of
// 26-bit signed @offset multiplied by 4
void b(std::int32_t offset);
// Inject the 26-bit signed @offset into the opcode of b instruction
// starting at @opcode
void b_inject(std::uint8_t * opcode, std::int32_t offset);
// Return from a subroutine, taking the return address from register @reg
void ret(std::uint8_t reg = 30);
// Helper function: push a 64-bit register @reg to the stack
void push(std::uint8_t reg);
// Helper function: pop a 64-bit register @reg from the stack
void pop(std::uint8_t reg);
private:
void do_push(std::uint32_t opcode);
};

View file

@ -1,4 +1,3 @@
#include "pslang/ast/statement_fwd.hpp"
#include <pslang/jit/arch/aarch64/compiler.hpp>
#include <pslang/jit/arch/aarch64/instruction_builder.hpp>
#include <pslang/jit/executable.hpp>
@ -266,6 +265,7 @@ namespace pslang::jit::aarch64
push(0);
apply(*node.arg2);
pop(1);
builder.cmp_reg(1, 0);
if (types::is_bool_type(*ast::get_type(*node.arg1)) || types::is_unsigned_integer_type(*ast::get_type(*node.arg1)))
builder.csetm(0, 0b1000);
else
@ -323,6 +323,67 @@ namespace pslang::jit::aarch64
scopes.back().variables[node.name] = {.frame_offset = stack_offset};
}
void apply(ast::if_chain const & node)
{
std::vector<std::size_t> branch_to_end;
for (std::size_t i = 0; i < node.blocks.size(); ++i)
{
auto const & block = node.blocks[i];
std::optional<std::size_t> branch_skip;
if (block.condition)
{
apply(*block.condition);
branch_skip = context.code.size();
builder.cbz(0, 0);
}
scopes.emplace_back();
apply(*block.statements);
scope_cleanup();
scopes.pop_back();
if (i + 1 < node.blocks.size())
{
branch_to_end.push_back(context.code.size());
builder.b(0);
}
if (branch_skip)
{
auto branch_offset = context.code.size() - *branch_skip;
builder.cb_inject(context.code.data() + *branch_skip, branch_offset / 4);
}
}
auto end = context.code.size();
for (auto instruction : branch_to_end)
{
auto delta = end - instruction;
builder.b_inject(context.code.data() + instruction, delta / 4);
}
}
void apply(ast::while_block const & node)
{
std::int32_t start = context.code.size();
apply(*node.condition);
std::int32_t skip = context.code.size();
builder.cbz(0, 0);
scopes.emplace_back();
apply(*node.statements);
scope_cleanup();
scopes.pop_back();
std::int32_t loop = context.code.size();
builder.b(0);
std::int32_t end = context.code.size();
builder.cb_inject(context.code.data() + skip, (end - skip) / 4);
builder.b_inject(context.code.data() + loop, (start - loop) / 4);
}
void apply(ast::return_statement const & node)
{
apply(*node.value);
@ -340,7 +401,6 @@ namespace pslang::jit::aarch64
{
for (auto const & statement : node.statements)
apply(*statement);
builder.add_imm(31, 31, scopes.back().stack_offset);
}
void do_apply(ast::function_definition const & node)
@ -385,6 +445,12 @@ namespace pslang::jit::aarch64
{
reg_extend_visitor{{}, builder, reg}.apply(*type);
}
void scope_cleanup()
{
if (scopes.back().stack_offset > 0)
builder.add_imm(31, 31, scopes.back().stack_offset);
}
};
struct compile_visitor

View file

@ -125,21 +125,45 @@ namespace pslang::jit::aarch64
do_push(0xd3400000u | (reg_dst & REG_MASK) | ((reg_src & REG_MASK) << 5) | (((bit_count - 1) & 0x3fu) << 10));
}
void instruction_builder::tst(std::uint8_t reg_src1, std::uint8_t reg_src2)
{
do_push(0xea00001fu | ((reg_src1 & REG_MASK) << 5) | ((reg_src2 & REG_MASK) << 16));
}
void instruction_builder::cbnz(std::uint8_t reg_src, std::int32_t offset)
{
do_push(0xb5000000u | (reg_src & REG_MASK) | ((std::uint32_t(offset) & 0x7ffffu) << 5));
}
void instruction_builder::cbz(std::uint8_t reg_src, std::int32_t offset)
{
do_push(0xb4000000u | (reg_src & REG_MASK) | ((std::uint32_t(offset) & 0x7ffffu) << 5));
}
void instruction_builder::cb_inject(std::uint8_t * opcode, std::int32_t offset)
{
auto dst = (std::uint32_t *)opcode;
*dst &= ~(0x7ffffu << 5);
*dst |= (std::uint32_t(offset) & 0x7ffffu) << 5;
}
void instruction_builder::b(std::int32_t offset)
{
do_push(0x14000000u | (std::uint32_t(offset) & 0x3ffffffu));
}
void instruction_builder::b_inject(std::uint8_t * opcode, std::int32_t offset)
{
auto dst = (std::uint32_t *)opcode;
*dst &= ~0x3ffffffu;
*dst |= (std::uint32_t(offset) & 0x3ffffffu);
}
void instruction_builder::ret(std::uint8_t reg)
{
do_push(0xd65f0000u | ((reg & REG_MASK) << 5));
}
void instruction_builder::push(std::uint8_t reg)
{
str_pre(reg, 31, -8);
}
void instruction_builder::pop(std::uint8_t reg)
{
ldr_post(reg, 31, 8);
}
void instruction_builder::do_push(std::uint32_t opcode)
{
code.push_back((opcode >> 0) & 0xffu);