New IR -> Aarch64 compiler wip: basic operations done (no pointers, structs, arrays)

This commit is contained in:
Nikita Lisitsa 2026-03-25 00:22:50 +03:00
parent c3c5010e04
commit 42e7f7961e
23 changed files with 1016 additions and 111 deletions

View file

@ -95,7 +95,8 @@ int main(int argc, char ** argv)
bool jit = false;
std::vector<std::string> filenames;
std::vector<ast::statement_list_ptr> parsed;
std::vector<ast::statement_ptr> parsed;
std::vector<ir::module_context> ir_compiled;
bool no_more_options = false;
@ -181,11 +182,14 @@ int main(int argc, char ** argv)
try
{
filenames.push_back(argv[arg]);
auto ast = parser::parse(filenames.back());
ast::resolve_identifiers(ast);
ast::check_and_infer_types(ast);
ast::validate(ast);
parsed.push_back(std::move(ast));
auto root = parser::parse(filenames.back());
ast::resolve_identifiers(root);
ast::check_and_infer_types(root);
ast::validate(root);
parsed.push_back(std::move(root));
ir_compiled.emplace_back();
ir::compile(ir_compiled.back(), parsed.back());
}
catch (ast::parse_error const & error)
{
@ -230,34 +234,35 @@ int main(int argc, char ** argv)
for (std::size_t i = 0; i < filenames.size(); ++i)
{
std::cout << "Input file " << filenames[i] << " AST dump:\n\n";
ast::print(std::cout, *parsed[i]);
if (auto function_definition = std::get_if<ast::function_definition>(parsed[i].get()))
ast::print(std::cout, *function_definition->statements);
std::cout << "\n";
}
std::cout << std::flush;
}
// NB: IR isn't used in JIT or interpreter right now.
if (dump_ir)
{
for (std::size_t i = 0; i < filenames.size(); ++i)
{
std::cout << "Input file " << filenames[i] << " IR dump:\n\n";
ir::module_context context;
ir::compile(context, parsed[i]);
ir::print(std::cout, context);
ir::print(std::cout, ir_compiled[i]);
std::cout << "\n";
}
std::cout << std::flush;
}
if (jit)
{
// TODO: treat all input files as modules combined into a single program
for (std::size_t i = 0; i < filenames.size(); ++i)
{
jit::program_context pcontext
{
.abi = jit::host_abi(),
};
for (auto const & ast : parsed)
jit::compile(pcontext, ast);
jit::compile(pcontext, ir_compiled[i]);
for (auto const & resolve : pcontext.foreign_resolve)
{
@ -267,15 +272,14 @@ int main(int argc, char ** argv)
auto executable = jit::make_host_executable(pcontext.code);
// TODO: multiple input files => multiple entry points?
// Should probably compile them as separate modules.
auto entry_point = (void(*)())(executable.data.get() + pcontext.entry_point);
entry_point();
}
}
else
{
for (auto const & ast : parsed)
interpreter::exec(context, ast);
// for (auto const & ast : parsed)
// interpreter::exec(context, ast);
if (dump)
interpreter::dump(std::cout, context);

View file

@ -1,12 +1,20 @@
func alloc(size: u64) -> unit mut*:
foreign func malloc(size: u64) -> unit mut*
return malloc(size)
func print(c: u8):
foreign func putchar(c: i32) -> i32
putchar(c as i32)
foreign func free(ptr: unit*)
func print32(n: u32):
if n >= 10u:
print32(n / 10u)
print('0' + ((n % 10u) as u8))
let array = alloc(400ul) as i32 mut*
*array = 10
*(array + 1) = 20
let test = array[10]
array[15] = 1500
free(array as unit*)
func factorial(n: u32) -> u32:
if n == 0u:
return 1u
return n * factorial(n - 1u)
foreign func sinf(x: f32) -> f32
print32(factorial(10u)) // 3628800
print('\n')
print32(sinf(1.0) * 1000000.0 as u32) // 841471
print('\n')

View file

@ -5,8 +5,8 @@
namespace pslang::ast
{
void resolve_identifiers(statement_list_ptr & statements);
void check_and_infer_types(statement_list_ptr & statements);
void validate(statement_list_ptr & statements);
void resolve_identifiers(statement_ptr & statements);
void check_and_infer_types(statement_ptr & statements);
void validate(statement_ptr & statements);
}

View file

@ -375,12 +375,12 @@ namespace pslang::ast
}
void resolve_identifiers(statement_list_ptr & statements)
void resolve_identifiers(statement_ptr & root)
{
std::vector<scope> scopes;
scopes.emplace_back();
resolve_identifiers_visitor visitor{{}, {}, {}, scopes};
visitor.apply(*statements);
visitor.apply(*root);
}
}

View file

@ -903,11 +903,11 @@ namespace pslang::ast
}
void check_and_infer_types(statement_list_ptr & statements)
void check_and_infer_types(statement_ptr & root)
{
local_context lcontext;
check_visitor visitor{{}, {}, lcontext};
visitor.apply(*statements);
visitor.apply(*root);
for (auto & struct_data : lcontext.structs)
compute_layout(lcontext, struct_data.first, struct_data.second);

View file

@ -52,9 +52,9 @@ namespace pslang::ast
}
void validate(statement_list_ptr & statements)
void validate(statement_ptr & root)
{
validate_visitor{}.apply(*statements);
validate_visitor{}.apply(*root);
}
}

View file

@ -21,10 +21,16 @@ namespace pslang::ir
std::unordered_map<node const *, std::string> labels;
std::unordered_map<ast::function_definition const *, node_ref> symbols;
struct symbol_info
{
node_ref begin;
node_ref end;
};
std::unordered_map<ast::function_definition const *, symbol_info> symbols;
node_ref entry_point;
};
void compile(module_context & context, ast::statement_list_ptr const & statements);
void compile(module_context & context, ast::statement_ptr const & root);
}

View file

@ -6,12 +6,13 @@
#include <pslang/types/type.hpp>
#include <variant>
#include <functional>
namespace pslang::ir
{
// Used primarily as jump target
struct nop
struct label
{};
struct literal
@ -108,7 +109,7 @@ namespace pslang::ir
};
using instruction = std::variant<
nop,
label,
literal,
copy,
load,
@ -135,4 +136,20 @@ namespace pslang::ir
types::type_ptr inferred_type = nullptr;
};
bool is_value_instruction(instruction const & instruction);
}
namespace std
{
template <>
struct hash<::pslang::ir::node_ref>
{
std::size_t operator()(pslang::ir::node_ref const & ref) const
{
return std::hash<pslang::ir::node const *>()(ref.operator->());
}
};
}

View file

@ -13,8 +13,7 @@ namespace pslang::ir
struct local_context
{
std::unordered_map<ast::function_definition const *, node_ref> functions;
std::unordered_map<ast::foreign_function_declaration const *, node_ref> foreign_functions;
std::unordered_map<ast::function_definition const *, std::pair<node_ref, node_ref>> functions;
std::unordered_map<ast::variable_base const *, node_ref> variables;
struct scope
@ -113,7 +112,7 @@ namespace pslang::ir
auto arg2 = apply(*node.arg2);
mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::binary_and, arg1, arg2}, node.inferred_type);
mcontext.nodes->emplace_back(assignment{arg1, last()});
mcontext.nodes->emplace_back(nop{});
mcontext.nodes->emplace_back(label{});
std::get<jump_if_zero>(jump->instruction).target = last();
return arg1;
}
@ -126,7 +125,7 @@ namespace pslang::ir
auto arg2 = apply(*node.arg2);
mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::binary_or, arg1, arg2}, node.inferred_type);
mcontext.nodes->emplace_back(assignment{arg1, last()});
mcontext.nodes->emplace_back(nop{});
mcontext.nodes->emplace_back(label{});
std::get<jump_if_zero>(jump->instruction).target = last();
return arg1;
}
@ -266,7 +265,7 @@ namespace pslang::ir
mcontext.nodes->emplace_back(jump{});
jumps_to_end.push_back(last());
mcontext.nodes->emplace_back(nop{});
mcontext.nodes->emplace_back(label{});
if (jump_to_next)
std::get<jump_if_zero>((*jump_to_next)->instruction).target = last();
}
@ -280,7 +279,8 @@ namespace pslang::ir
node_ref apply(ast::while_block const & node)
{
auto before = last();
mcontext.nodes->emplace_back(label{});
auto begin = last();
auto condition = apply(*node.condition);
mcontext.nodes->emplace_back(jump_if_zero{condition, {}});
auto jump1 = last();
@ -289,8 +289,8 @@ namespace pslang::ir
apply(*node.statements);
lcontext.scopes.pop_back();
mcontext.nodes->emplace_back(jump{std::next(before)});
mcontext.nodes->emplace_back(nop{});
mcontext.nodes->emplace_back(jump{begin});
mcontext.nodes->emplace_back(label{});
std::get<jump_if_zero>(jump1->instruction).target = last();
return last();
}
@ -326,7 +326,8 @@ namespace pslang::ir
void do_apply(ast::function_definition const & node)
{
auto before = last();
mcontext.nodes->emplace_back(label{});
auto begin = last();
lcontext.scopes.emplace_back();
for (std::size_t i = 0; i < node.arguments.size(); ++i)
@ -339,11 +340,12 @@ namespace pslang::ir
if (types::equal(*ast::get_type(*node.return_type), types::unit_type{}))
mcontext.nodes->emplace_back(return_value{}, std::make_shared<types::type>(types::unit_type{}));
lcontext.scopes.pop_back();
auto end = last();
auto entry = std::next(before);
lcontext.functions[&node] = entry;
mcontext.labels[&(*entry)] = lcontext.scopes.back().label_prefix + node.name;
lcontext.functions[&node] = {begin, end};
mcontext.labels[&(*begin)] = lcontext.scopes.back().label_prefix + node.name;
lcontext.scopes.pop_back();
}
private:
@ -408,40 +410,31 @@ namespace pslang::ir
}
void compile(module_context & mcontext, ast::statement_list_ptr const & statements)
void compile(module_context & mcontext, ast::statement_ptr const & root)
{
if (!mcontext.nodes)
mcontext.nodes = std::make_shared<node_list>();
// Add a fake AST node for the entry point
auto root = std::make_shared<ast::statement>(ast::function_definition{
{
{},
{},
std::make_shared<ast::type>(types::unit_type{}),
{},
},
statements,
{},
});
mcontext.nodes->emplace_back(nop{});
auto extra_nop = std::prev(mcontext.nodes->end());
local_context lcontext;
mcontext.nodes->emplace_back(label{});
auto extra_label = std::prev(mcontext.nodes->end());
compile_visitor{{}, mcontext, lcontext}.apply(*root);
mcontext.labels[&(*std::next(extra_nop))] = "[entry point]";
mcontext.nodes->erase(extra_nop);
mcontext.nodes->erase(extra_label);
for (auto & symbol : lcontext.functions)
symbol.second.second++;
for (auto const & resolve : lcontext.resolve_address)
std::get<instruction_address>(resolve.address->instruction).target = lcontext.functions.at(resolve.target);
std::get<instruction_address>(resolve.address->instruction).target = lcontext.functions.at(resolve.target).first;
for (auto const & resolve : lcontext.resolve_call)
std::get<call>(resolve.call->instruction).target = lcontext.functions.at(resolve.target);
std::get<call>(resolve.call->instruction).target = lcontext.functions.at(resolve.target).first;
for (auto const & symbol : lcontext.functions)
mcontext.symbols[symbol.first] = symbol.second;
mcontext.symbols[symbol.first] = {.begin = symbol.second.first, .end = symbol.second.second};
mcontext.entry_point = lcontext.functions.at(std::get_if<ast::function_definition>(root.get()));
mcontext.entry_point = lcontext.functions.at(std::get_if<ast::function_definition>(root.get())).first;
}
}

55
libs/ir/source/node.cpp Normal file
View file

@ -0,0 +1,55 @@
#include <pslang/ir/node.hpp>
namespace pslang::ir
{
namespace
{
struct visitor
{
template <typename Node>
bool operator()(Node const &)
{
return true;
}
bool operator()(label const &)
{
return false;
}
bool operator()(store const &)
{
return false;
}
bool operator()(assignment const &)
{
return false;
}
bool operator()(jump const &)
{
return false;
}
bool operator()(jump_if_zero const &)
{
return false;
}
bool operator()(return_value const &)
{
return false;
}
};
}
bool is_value_instruction(instruction const & instruction)
{
return std::visit(visitor{}, instruction);
}
}

View file

@ -188,9 +188,9 @@ namespace pslang::ir
out << std::right << std::setfill(' ') << std::setw(indent);
}
void operator()(nop const &)
void operator()(label const &)
{
out << "nop";
out << "label";
}
void operator()(literal const & instruction)

View file

@ -1,10 +1,13 @@
#pragma once
#include <pslang/jit/jit.hpp>
#include <pslang/ir/compiler.hpp>
namespace pslang::jit::aarch64
{
void compile(program_context & context, ast::statement_list_ptr const & statements);
void compile(program_context & pcontext, ir::module_context const & mcontext);
}

View file

@ -149,10 +149,19 @@ namespace pslang::jit::aarch64
// 26-bit signed @offset multiplied by 4
void b(std::int32_t offset);
// Unconditionally move the program counter to the value of
// 26-bit signed @offset multiplied by 4, and store the address of the next
// instruction in link register (x30)
void bl(std::int32_t offset);
// Unconditionally move the program counter to the value of the register @reg
void b_reg(std::uint8_t reg);
// Inject the 26-bit signed @offset into the opcode of b instruction
// Unconditionally move the program counter to the value of the register @reg
// and store the address of the next instruction in link register (x30)
void bl_reg(std::uint8_t reg);
// Inject the 26-bit signed @offset into the opcode of b or bl instruction
// starting at @opcode
void b_inject(std::uint8_t * opcode, std::int32_t offset);

View file

@ -2,10 +2,12 @@
#include <pslang/ast/statement_fwd.hpp>
#include <pslang/jit/program_context.hpp>
#include <pslang/ir/compiler.hpp>
namespace pslang::jit
{
void compile(program_context & context, ast::statement_list_ptr const & statements);
void compile(program_context & pcontext, ir::module_context const & mcontext);
}

View file

@ -7,6 +7,13 @@
#include <string>
#include <unordered_map>
namespace pslang::ast
{
struct function_definition;
}
namespace pslang::jit
{
@ -22,7 +29,7 @@ namespace pslang::jit
jit::abi abi;
std::vector<std::uint8_t> code = {};
std::unordered_map<std::string, std::int32_t> symbols = {};
std::unordered_map<ast::function_definition const *, std::int32_t> symbols = {};
std::int32_t entry_point = 0;
std::vector<foreign_resolve_info> foreign_resolve = {};
};

View file

@ -901,7 +901,7 @@ namespace pslang::jit::aarch64
pop(reg);
push(30);
builder.b_reg(reg);
builder.bl_reg(reg);
pop(30);
}
else // if (node.type)
@ -1417,7 +1417,7 @@ namespace pslang::jit::aarch64
pcontext.foreign_resolve.push_back({foreign.first, foreign.second});
for (auto const & function : lcontext.scopes.front().functions)
pcontext.symbols[function.first] = lcontext.functions.at(function.second);
pcontext.symbols[function.second] = lcontext.functions.at(function.second);
pcontext.entry_point = lcontext.functions.at(std::get_if<ast::function_definition>(root.get()));
}

View file

@ -0,0 +1,753 @@
#include <pslang/jit/arch/aarch64/compiler.hpp>
#include <pslang/jit/arch/aarch64/instruction_builder.hpp>
#include <pslang/ir/node.hpp>
#include <pslang/ir/compiler.hpp>
#include <pslang/types/type_visitor.hpp>
#include <sstream>
namespace pslang::jit::aarch64
{
namespace
{
struct local_context
{
std::unordered_map<std::string, std::int32_t> extern_symbols;
std::unordered_map<ir::node_ref, std::int32_t> nodes;
std::unordered_map<float, std::int32_t> f16_constants;
std::unordered_map<float, std::int32_t> f32_constants;
std::unordered_map<double, std::int32_t> f64_constants;
struct resolve_data
{
std::int32_t offset;
ir::node_ref target;
};
std::vector<resolve_data> branch_resolve;
std::vector<resolve_data> cbranch_resolve;
std::vector<resolve_data> adr_resolve;
};
std::uint8_t fp_mode_for(types::type const & type)
{
if (types::equal(type, types::primitive_type(types::f16_type{})))
return 1;
if (types::equal(type, types::primitive_type(types::f32_type{})))
return 2;
return 3;
}
std::int32_t fp_size(std::uint8_t mode)
{
return 1 << mode;
}
struct populate_const_data_visitor
{
program_context & pcontext;
local_context & lcontext;
template <typename Node>
void apply(Node const & node, types::type_ptr const &)
{}
void apply(ir::literal const & node, types::type_ptr const &)
{
if (auto f16_literal = std::get_if<ast::f16_literal>(&node.value))
{
lcontext.f16_constants[f16_literal->value.repr] = pcontext.code.size();
push_bytes(f16_literal->value.repr);
}
else if (auto f32_literal = std::get_if<ast::f32_literal>(&node.value))
{
lcontext.f32_constants[f32_literal->value] = pcontext.code.size();
push_bytes(f32_literal->value);
}
else if (auto f64_literal = std::get_if<ast::f64_literal>(&node.value))
{
lcontext.f32_constants[f64_literal->value] = pcontext.code.size();
push_bytes(f64_literal->value);
}
}
void apply(ir::extern_symbol const & node, types::type_ptr const &)
{
std::int32_t offset = pcontext.code.size();
lcontext.extern_symbols[node.name] = offset;
pcontext.foreign_resolve.push_back({node.name, offset});
push_bytes<void *>(nullptr);
}
private:
template <typename T>
void push_bytes(T const & value)
{
auto begin = (std::uint8_t const *)(&value);
auto end = begin + sizeof(value);
pcontext.code.insert(pcontext.code.end(), begin, end);
}
};
// Set register @reg to -1 (all bits = 1)
void set_m1(instruction_builder & builder, std::uint8_t reg)
{
builder.or_not_reg(31, 31, reg);
}
struct literal_visitor
{
program_context & pcontext;
local_context & lcontext;
instruction_builder & builder;
void operator()(ast::bool_literal const & node)
{
if (node.value)
set_m1(builder, 0);
else
builder.movz(0, 0);
}
template <typename T>
requires(std::is_integral_v<T> && !std::is_same_v<T, bool>)
void operator()(ast::primitive_literal_base<T> const & node)
{
for (std::size_t i = 0; i < sizeof(T); i += 2)
{
if (i == 0)
{
builder.movz(0, std::uint64_t(node.value));
}
else
{
auto val = std::uint16_t(std::uint64_t(node.value) >> (i * 8));
if (val != 0)
builder.movk(0, val, i / 2);
}
}
if (sizeof(T) < 8)
{
if constexpr (std::is_signed_v<T>)
{
if (node.value < 0)
builder.sbfm(0, 0, sizeof(T) * 8);
}
}
}
void operator()(ast::f16_literal const & node)
{
auto offset = lcontext.f16_constants.at(node.value.repr);
std::int32_t current = pcontext.code.size();
builder.ldr_fp_pc(0, 0, (offset - current) / 4);
builder.fcvt(0, 0b10, 0, 0b01);
}
void operator()(ast::f32_literal const & node)
{
auto offset = lcontext.f32_constants.at(node.value);
std::int32_t current = pcontext.code.size();
builder.ldr_fp_pc(0, 0, (offset - current) / 4);
}
void operator()(ast::f64_literal const & node)
{
auto offset = lcontext.f64_constants.at(node.value);
std::int32_t current = pcontext.code.size();
builder.ldr_fp_pc(0, 1, (offset - current) / 4);
}
};
struct reg_extend_visitor
: types::const_visitor<reg_extend_visitor>
{
using const_visitor::apply;
instruction_builder & builder;
std::uint8_t reg;
void apply(types::bool_type const &)
{}
void apply(types::f32_type const &)
{}
void apply(types::f64_type const &)
{}
template <typename T>
void apply(types::primitive_type_base<T> const &)
{
if constexpr (sizeof(T) == 8)
{
return;
}
if constexpr (std::is_signed_v<T>)
{
builder.sbfm(reg, reg, sizeof(T) * 8);
}
if constexpr (std::is_unsigned_v<T>)
{
builder.ubfm(reg, reg, sizeof(T) * 8);
}
}
template <typename T>
void apply(T const &)
{
throw std::runtime_error(std::string("reg_extend_visitor is not implemented for ") + typeid(T).name());
}
};
struct compile_visitor
{
program_context & pcontext;
ir::module_context const & mcontext;
local_context & lcontext;
instruction_builder & builder;
std::unordered_map<ir::node_ref, std::int32_t> stack_position;
std::int32_t stack_size = 0;
void apply(ir::node_ref, ir::label const &, types::type_ptr const &)
{}
void apply(ir::node_ref it, ir::literal const & node, types::type_ptr const & type)
{
std::visit(literal_visitor{pcontext, lcontext, builder}, node.value);
if (types::is_integer_like_type(*type))
store(it, 0);
else if (types::is_floating_point_type(*type))
store_fp(it, 0, fp_mode_for(*type));
}
void apply(ir::node_ref it, ir::copy const & node, types::type_ptr const &)
{
// TODO: struct/array copy?
load(node.source, 0);
store(it, 0);
}
void apply(ir::node_ref, ir::load const &, types::type_ptr const &)
{
throw std::runtime_error("Not implemented");
}
void apply(ir::node_ref, ir::store const &, types::type_ptr const &)
{
throw std::runtime_error("Not implemented");
}
void apply(ir::node_ref it, ir::unary_operation const & node, types::type_ptr const & type)
{
switch (node.type)
{
case ast::unary_operation_type::negation:
if (types::is_integer_type(*type))
{
load(node.arg1, 0);
builder.sub_reg(31, 0, 0);
extend(0, type);
store(it, 0);
}
else if (types::is_floating_point_type(*type))
{
auto mode = fp_mode_for(*type);
load_fp(it, 0, mode);
builder.fneg(0, mode, 0);
store_fp(it, 0, mode);
}
break;
case ast::unary_operation_type::logical_not:
load(node.arg1, 0);
builder.or_not_reg(31, 0, 0);
if (types::is_integer_type(*type))
extend(0, type);
store(it, 0);
break;
case ast::unary_operation_type::address_of:
case ast::unary_operation_type::mutable_address_of:
case ast::unary_operation_type::dereference:
throw std::runtime_error("Not implemented");
}
}
void apply(ir::node_ref it, ir::binary_operation const & node, types::type_ptr const & type)
{
auto arg1_type = node.arg1->inferred_type;
bool const is_fp = types::is_floating_point_type(*arg1_type);
std::uint8_t const fp_mode = fp_mode_for(*arg1_type);
if (is_fp)
{
load_fp(node.arg1, 0, fp_mode);
load_fp(node.arg2, 1, fp_mode);
}
else
{
load(node.arg1, 0);
load(node.arg2, 1);
}
switch (node.type)
{
case ast::binary_operation_type::addition:
if (is_fp)
builder.fadd(0, 1, fp_mode, 0);
else
{
builder.add_reg(0, 1, 0);
extend(0, type);
}
break;
case ast::binary_operation_type::subtraction:
if (is_fp)
builder.fsub(0, 1, fp_mode, 0);
else
{
builder.sub_reg(0, 1, 0);
extend(0, type);
}
break;
case ast::binary_operation_type::multiplication:
if (is_fp)
builder.fmul(0, 1, fp_mode, 0);
else
{
builder.mul_reg(0, 1, 0);
extend(0, type);
}
break;
case ast::binary_operation_type::division:
if (is_fp)
builder.fdiv(0, 1, fp_mode, 0);
else
{
if (types::is_signed_integer_type(*type))
builder.sdiv_reg(0, 1, 0);
else
builder.udiv_reg(0, 1, 0);
extend(0, type);
}
break;
case ast::binary_operation_type::remainder:
if (types::is_signed_integer_type(*type))
{
builder.sdiv_reg(0, 1, 2);
builder.mul_reg(1, 2, 1);
builder.sub_reg(0, 1, 0);
}
else if (types::is_unsigned_integer_type(*type))
{
builder.udiv_reg(0, 1, 2);
builder.mul_reg(1, 2, 1);
builder.sub_reg(0, 1, 0);
}
break;
case ast::binary_operation_type::binary_and:
builder.and_reg(0, 1, 0);
break;
case ast::binary_operation_type::logical_and:
throw std::runtime_error("Short-circuiting operators must have been unwrapped in IR compiler");
case ast::binary_operation_type::binary_or:
builder.or_reg(0, 1, 0);
break;
case ast::binary_operation_type::logical_or:
throw std::runtime_error("Short-circuiting operators must have been unwrapped in IR compiler");
case ast::binary_operation_type::logical_xor:
builder.xor_reg(0, 1, 0);
break;
case ast::binary_operation_type::equals:
if (is_fp)
{
builder.fcmp(0, 1, fp_mode);
builder.csetm(0, 0b0000);
}
else
{
builder.cmp_reg(0, 1);
builder.csetm(0, 0b0000);
}
break;
case ast::binary_operation_type::not_equals:
if (is_fp)
{
builder.fcmp(0, 1, fp_mode);
builder.csetm(0, 0b0001);
}
else
{
builder.cmp_reg(0, 1);
builder.csetm(0, 0b0001);
}
break;
case ast::binary_operation_type::less:
if (is_fp)
{
builder.fcmp(0, 1, fp_mode);
builder.csetm(0, 0b0100);
}
else
{
builder.cmp_reg(1, 0);
if (types::is_bool_type(*node.arg1->inferred_type) || types::is_unsigned_integer_type(*node.arg1->inferred_type))
builder.csetm(0, 0b1000);
else
builder.csetm(0, 0b1100);
}
break;
case ast::binary_operation_type::greater:
if (is_fp)
{
builder.fcmp(1, 0, fp_mode);
builder.csetm(0, 0b0100);
}
else
{
builder.cmp_reg(0, 1);
if (types::is_bool_type(*node.arg1->inferred_type) || types::is_unsigned_integer_type(*node.arg1->inferred_type))
builder.csetm(0, 0b1000);
else
builder.csetm(0, 0b1100);
}
break;
case ast::binary_operation_type::less_equals:
if (is_fp)
{
builder.fcmp(0, 1, fp_mode);
builder.csetm(0, 0b1001);
}
else
{
builder.cmp_reg(0, 1);
if (types::is_bool_type(*node.arg1->inferred_type) || types::is_unsigned_integer_type(*node.arg1->inferred_type))
builder.csetm(0, 0b1001);
else
builder.csetm(0, 0b1101);
}
break;
case ast::binary_operation_type::greater_equals:
if (is_fp)
{
builder.fcmp(1, 0, fp_mode);
builder.csetm(0, 0b1001);
}
else
{
builder.cmp_reg(1, 0);
if (types::is_bool_type(*node.arg1->inferred_type) || types::is_unsigned_integer_type(*node.arg1->inferred_type))
builder.csetm(0, 0b1001);
else
builder.csetm(0, 0b1101);
}
break;
default:
{
std::ostringstream os;
os << "binary operation " << node.type << " is not implemented";
throw std::runtime_error(os.str());
}
}
if (is_fp)
store_fp(it, 0, fp_mode);
else
store(it, 0);
}
void apply(ir::node_ref it, ir::cast_operation const & node, types::type_ptr const &)
{
auto src_type = node.arg1->inferred_type;
auto dst_type = node.target_type;
if (types::equal(*src_type, *dst_type) || (types::is_pointer_type(*src_type) && types::is_pointer_type(*dst_type)))
{
load(node.arg1, 0);
store(it, 0);
return;
}
if (types::is_integer_type(*src_type))
{
load(node.arg1, 0);
if (types::is_integer_type(*dst_type))
{
extend(0, dst_type);
}
else if (types::is_floating_point_type(*dst_type))
{
auto dst_mode = fp_mode_for(*dst_type);
if (types::is_signed_integer_type(*src_type))
{
builder.fmov(0, 0, 3, 1);
builder.scvtf(0, 0, 3);
if (dst_mode != 3)
builder.fcvt(0, 3, 0, dst_mode);
}
else if (types::is_unsigned_integer_type(*src_type))
{
builder.fmov(0, 0, 3, 1);
builder.ucvtf(0, 0, 3);
if (dst_mode != 3)
builder.fcvt(0, 3, 0, dst_mode);
}
}
}
else if (types::is_floating_point_type(*src_type))
{
auto src_mode = fp_mode_for(*src_type);
if (types::is_integer_type(*dst_type))
{
if (types::is_signed_integer_type(*dst_type))
{
builder.fcvtns(0, 0, src_mode);
extend(0, dst_type);
}
else if (types::is_unsigned_integer_type(*dst_type))
{
builder.fcvtnu(0, 0, src_mode);
extend(0, dst_type);
}
}
else if (types::is_floating_point_type(*dst_type))
{
auto dst_mode = fp_mode_for(*dst_type);
builder.fcvt(0, src_mode, 0, dst_mode);
}
}
if (types::is_integer_type(*dst_type))
{
store(it, 0);
}
else if (types::is_floating_point_type(*dst_type))
{
store_fp(it, 0, fp_mode_for(*dst_type));
}
}
void apply(ir::node_ref it, ir::argument const & node, types::type_ptr const & type)
{
// TODO: compute argument layout before compiling the function?
// TODO: floating-point arguments? struct/array arguments?
store(it, node.index);
}
void apply(ir::node_ref it, ir::instruction_address const & node, types::type_ptr const &)
{
lcontext.adr_resolve.emplace_back(pcontext.code.size(), node.target);
builder.adr(0, 0);
store(it, 0);
}
void apply(ir::node_ref it, ir::extern_symbol const & node, types::type_ptr const &)
{
builder.ldr_pc(0, (lcontext.extern_symbols[node.name] - (std::int32_t)pcontext.code.size()) / 4);
store(it, 0);
}
void apply(ir::node_ref, ir::assignment const & node, types::type_ptr const &)
{
// TODO: struct/array assignment?
load(node.rhs, 0);
store(node.lhs, 0);
}
void apply(ir::node_ref, ir::jump const & node, types::type_ptr const &)
{
lcontext.branch_resolve.emplace_back(pcontext.code.size(), node.target);
builder.b(0);
}
void apply(ir::node_ref, ir::jump_if_zero const & node, types::type_ptr const &)
{
load(node.condition, 0);
lcontext.cbranch_resolve.emplace_back(pcontext.code.size(), node.target);
builder.cbz(0, 0);
}
void apply(ir::node_ref it, ir::call const & node, types::type_ptr const & type)
{
// TODO: struct/array arguments?
std::uint8_t reg = 0;
std::uint8_t fp_reg = 0;
for (auto const & argument : node.arguments)
{
if (types::is_integer_like_type(*argument->inferred_type))
load(argument, reg++);
else if (types::is_floating_point_type(*argument->inferred_type))
load_fp(argument, fp_reg++, fp_mode_for(*argument->inferred_type));
else
throw std::runtime_error("Unsupported function argument type");
}
builder.sub_imm(31, 31, 16);
builder.str(30, 31, 0);
lcontext.branch_resolve.emplace_back(pcontext.code.size(), node.target);
builder.bl(0);
builder.ldr(30, 31, 0);
builder.add_imm(31, 31, 16);
// TODO: struct/array return value?
if (types::is_unit_type(*type))
{}
else if (types::is_integer_like_type(*type))
store(it, 0);
else if (types::is_floating_point_type(*type))
store_fp(it, 0, fp_mode_for(*type));
else
throw std::runtime_error("Unsupported return value type");
}
void apply(ir::node_ref it, ir::call_pointer const & node, types::type_ptr const & type)
{
// TODO: struct/array arguments?
std::uint8_t reg = 0;
std::uint8_t fp_reg = 0;
for (auto const & argument : node.arguments)
{
if (types::is_integer_like_type(*argument->inferred_type))
load(argument, reg++);
else if (types::is_floating_point_type(*argument->inferred_type))
load_fp(argument, fp_reg++, fp_mode_for(*argument->inferred_type));
else
throw std::runtime_error("Unsupported function argument type");
}
load(node.pointer, reg);
builder.sub_imm(31, 31, 16);
builder.str(30, 31, 0);
builder.bl_reg(reg);
builder.ldr(30, 31, 0);
builder.add_imm(31, 31, 16);
// TODO: struct/array return value?
if (types::is_unit_type(*type))
{}
else if (types::is_integer_like_type(*type))
store(it, 0);
else if (types::is_floating_point_type(*type))
store_fp(it, 0, fp_mode_for(*type));
else
throw std::runtime_error("Unsupported return value type");
}
void apply(ir::node_ref, ir::return_value const & node, types::type_ptr const &)
{
// TODO: struct/array return value?
if (node.value)
{
auto type = (*node.value)->inferred_type;
if (types::is_integer_like_type(*type))
load(*node.value, 0);
else if (types::is_floating_point_type(*type))
load_fp(*node.value, 0, fp_mode_for(*type));
else
throw std::runtime_error("Unsupported return value type");
}
if (stack_size > 0)
builder.add_imm(31, 31, stack_size);
builder.ret();
}
void compile(ir::node_ref begin, ir::node_ref end)
{
stack_size = 0;
for (auto it = begin; it != end; ++it)
{
if (ir::is_value_instruction(it->instruction))
{
stack_size += 8;
stack_position[it] = stack_size;
}
}
stack_size = ((stack_size + 15) / 16) * 16;
if (!std::holds_alternative<ir::label>(begin->instruction))
throw std::runtime_error("First IR node of a function must be a label");
auto it = begin;
lcontext.nodes[it] = pcontext.code.size();
if (stack_size > 0)
builder.sub_imm(31, 31, stack_size);
++it;
for (; it != end; ++it)
{
lcontext.nodes[it] = pcontext.code.size();
std::visit([&](auto const & instruction){ apply(it, instruction, it->inferred_type); }, it->instruction);
}
}
private:
void load(ir::node_ref ir, std::uint8_t reg)
{
std::int32_t offset = stack_size - stack_position.at(ir);
builder.ldr(reg, 31, offset / 8);
}
void load_fp(ir::node_ref ir, std::uint8_t reg, std::uint8_t mode)
{
std::int32_t offset = stack_size - stack_position.at(ir);
builder.ldr_fp(reg, mode, 31, offset / fp_size(mode));
}
void store(ir::node_ref ir, std::uint8_t reg)
{
std::int32_t offset = stack_size - stack_position.at(ir);
builder.str(reg, 31, offset / 8);
}
void store_fp(ir::node_ref ir, std::uint8_t reg, std::uint8_t mode)
{
std::int32_t offset = stack_size - stack_position.at(ir);
builder.str_fp(reg, mode, 31, offset / fp_size(mode));
}
// Sign- or zero-extend the register depending on the exact type
void extend(std::uint8_t reg, types::type_ptr const & type)
{
reg_extend_visitor{{}, builder, reg}.apply(*type);
}
};
}
void compile(program_context & pcontext, ir::module_context const & mcontext)
{
local_context lcontext;
instruction_builder builder{pcontext.code};
{
populate_const_data_visitor visitor{pcontext, lcontext};
for (auto it = mcontext.nodes->begin(); it != mcontext.nodes->end(); ++it)
std::visit([&](auto const & instruction){ visitor.apply(instruction, it->inferred_type); }, it->instruction);
}
for (auto const & symbol : mcontext.symbols)
{
pcontext.symbols[symbol.first] = pcontext.code.size();
compile_visitor visitor{pcontext, mcontext, lcontext, builder};
visitor.compile(symbol.second.begin, symbol.second.end);
}
pcontext.entry_point = lcontext.nodes.at(mcontext.entry_point);
for (auto const & resolve : lcontext.branch_resolve)
builder.b_inject(pcontext.code.data() + resolve.offset, (lcontext.nodes.at(resolve.target) - resolve.offset) / 4);
for (auto const & resolve : lcontext.cbranch_resolve)
builder.cb_inject(pcontext.code.data() + resolve.offset, (lcontext.nodes.at(resolve.target) - resolve.offset) / 4);
for (auto const & resolve : lcontext.adr_resolve)
builder.adr_inject(pcontext.code.data() + resolve.offset, (lcontext.nodes.at(resolve.target) - resolve.offset) / 4);
}
}

View file

@ -197,7 +197,17 @@ namespace pslang::jit::aarch64
do_push(0x14000000u | (std::uint32_t(offset) & 0x3ffffffu));
}
void instruction_builder::bl(std::int32_t offset)
{
do_push(0x94000000u | (std::uint32_t(offset) & 0x3ffffffu));
}
void instruction_builder::b_reg(std::uint8_t reg)
{
do_push(0xd61f0000u | ((reg & REG_MASK) << 5));
}
void instruction_builder::bl_reg(std::uint8_t reg)
{
do_push(0xd63f0000u | ((reg & REG_MASK) << 5));
}
@ -286,7 +296,7 @@ namespace pslang::jit::aarch64
void instruction_builder::fmov(std::uint8_t reg_src, std::uint8_t reg_dst, std::uint8_t mode, std::uint8_t op)
{
do_push(0x1e260000u | (reg_dst & REG_MASK) | ((reg_src & REG_MASK) << 5) | (mode == 2 ? 0u : 0x8000000u) | (((mode & 0x3u) ^ 0x2u) << 22) | ((op & 0x1u) << 16));
do_push(0x9e260000u | (reg_dst & REG_MASK) | ((reg_src & REG_MASK) << 5) | (((mode & 0x3u) ^ 0x2u) << 22) | ((op & 0x1u) << 16));
}
void instruction_builder::scvtf(std::uint8_t reg_src, std::uint8_t reg_dst, std::uint8_t mode)

View file

@ -20,4 +20,18 @@ namespace pslang::jit
}
}
void compile(program_context & pcontext, ir::module_context const & mcontext)
{
switch (pcontext.abi)
{
case abi::itanium:
throw std::runtime_error("Itanium ABI JIT not implemented");
case abi::msvc:
throw std::runtime_error("MSVC ABI JIT not implemented");
case abi::armv8:
aarch64::compile(pcontext, mcontext);
break;
}
}
}

View file

@ -7,6 +7,6 @@
namespace pslang::parser
{
ast::statement_list_ptr parse(std::string_view path);
ast::statement_ptr parse(std::string_view path);
}

View file

@ -8,15 +8,15 @@
namespace pslang::parser
{
ast::statement_list_ptr parse(std::string_view path)
ast::statement_ptr parse(std::string_view path)
{
yyin = fopen(path.data(), "r");
if (!yyin)
throw std::system_error(std::make_error_code(static_cast<std::errc>(errno)));
ast::location location{.begin = {.filename = path}, .end = {.filename = path}};
indented_statement_list result;
context ctx{location, result};
indented_statement_list statements;
context ctx{location, statements};
bison::parser parser(ctx);
@ -24,7 +24,17 @@ namespace pslang::parser
fclose(yyin);
return finalize(std::move(result));
// Add a fake AST node for the entry point
return std::make_shared<ast::statement>(ast::function_definition{
{
"[entry point]",
{},
std::make_shared<ast::type>(types::unit_type{}),
{},
},
finalize(std::move(statements)),
{},
});
}
}

View file

@ -25,6 +25,10 @@ namespace pslang::types
bool is_function_type(type const & type);
bool is_pointer_type(type const & type);
// Everything that is probably stored in non-fp registers:
// bool, integers, pointers, function points
bool is_integer_like_type(types::type const & type);
std::size_t type_size(type const & type);
}

View file

@ -137,6 +137,16 @@ namespace pslang::types
return false;
}
bool is_integer_like_type(types::type const & type)
{
return false
|| is_bool_type(type)
|| is_integer_type(type)
|| is_function_type(type)
|| is_pointer_type(type)
;
}
std::size_t type_size(type const & type)
{
if (std::get_if<unit_type>(&type))