Implement IR + Aarch64 pointers

This commit is contained in:
Nikita Lisitsa 2026-03-25 16:30:43 +03:00
parent 42e7f7961e
commit 5d7968c30b
7 changed files with 181 additions and 29 deletions

View file

@ -7,14 +7,23 @@ func print32(n: u32):
print32(n / 10u)
print('0' + ((n % 10u) as u8))
func factorial(n: u32) -> u32:
if n == 0u:
return 1u
return n * factorial(n - 1u)
func alloc(size: u64) -> unit mut*:
foreign func malloc(size: u64) -> unit mut*
return malloc(size)
foreign func sinf(x: f32) -> f32
foreign func free(ptr: unit*)
print32(factorial(10u)) // 3628800
let count = 30
let array = alloc(4 * count as u64) as u32 mut*
array[0] = 0u
array[1] = 1u
print32(array[0])
print('\n')
print32(sinf(1.0) * 1000000.0 as u32) // 841471
print32(array[1])
print('\n')
mut i = 2
while i < count:
array[i] = array[i - 1] + array[i - 2]
print32(array[i])
print('\n')
i = i + 1

View file

@ -59,4 +59,6 @@ namespace pslang::ast
using type_impl::type_impl;
};
std::size_t type_size(types::type const & type);
}

View file

@ -1,6 +1,8 @@
#include <pslang/ast/type.hpp>
#include <pslang/ast/type_visitor.hpp>
#include <pslang/ast/struct.hpp>
#include <pslang/types/type.hpp>
#include <pslang/types/type_visitor.hpp>
namespace pslang::ast
{
@ -44,6 +46,42 @@ namespace pslang::ast
}
};
struct size_visitor
: types::const_visitor<size_visitor>
{
using const_visitor::apply;
std::size_t apply(types::unit_type const & type)
{
return 0;
}
std::size_t apply(types::primitive_type const & type)
{
return types::type_size(type);
}
std::size_t apply(types::array_type const & type)
{
return apply(*type.element_type) * type.size;
}
std::size_t apply(types::function_type const &)
{
return 8;
}
std::size_t apply(types::pointer_type const &)
{
return 8;
}
std::size_t apply(types::struct_type const & type)
{
return type.node->layout.size;
}
};
}
types::type_ptr get_type(type const & type)
@ -51,4 +89,9 @@ namespace pslang::ast
return get_type_visitor{}.apply(type);
}
std::size_t type_size(types::type const & type)
{
return size_visitor{}.apply(type);
}
}

View file

@ -349,14 +349,18 @@ namespace pslang::ast
node.inferred_type = arg1_type;
return;
}
if (types::is_pointer_type(*arg1_type) && types::is_integer_type(*arg2_type))
if (auto pointer_type = std::get_if<types::pointer_type>(arg1_type.get()); pointer_type && types::is_integer_type(*arg2_type))
{
if (ast::type_size(*pointer_type->referenced_type) == 0)
throw type_error("Pointer arithmetic with empty types is not allowed", node.location);
node.inferred_type = arg1_type;
return;
}
if (types::is_integer_type(*arg1_type) && types::is_pointer_type(*arg2_type))
if (auto pointer_type = std::get_if<types::pointer_type>(arg2_type.get()); pointer_type && types::is_integer_type(*arg1_type))
{
node.inferred_type = arg1_type;
if (ast::type_size(*pointer_type->referenced_type) == 0)
throw type_error("Pointer arithmetic with empty types is not allowed", node.location);
node.inferred_type = arg2_type;
return;
}
break;
@ -366,13 +370,17 @@ namespace pslang::ast
node.inferred_type = arg1_type;
return;
}
if (types::is_pointer_type(*arg1_type) && types::is_integer_type(*arg2_type))
if (auto pointer_type = std::get_if<types::pointer_type>(arg1_type.get()); pointer_type && types::is_integer_type(*arg2_type))
{
if (ast::type_size(*pointer_type->referenced_type) == 0)
throw type_error("Pointer arithmetic with empty types is not allowed", node.location);
node.inferred_type = arg1_type;
return;
}
if (types::is_pointer_type(*arg1_type) && types::is_pointer_type(*arg2_type))
if (types::is_pointer_type(*arg1_type) && types::is_pointer_type(*arg2_type) && types::equal(*arg1_type, *arg2_type))
{
if (ast::type_size(*std::get<types::pointer_type>(*arg1_type).referenced_type) == 0)
throw type_error("Pointer arithmetic with empty types is not allowed", node.location);
node.inferred_type = std::make_unique<types::type>(types::primitive_type{types::i64_type{}});
return;
}

View file

@ -105,6 +105,8 @@ namespace pslang::ir
{
auto arg1 = apply(*node.arg1);
// Handle short-circuit operators
if (node.type == ast::binary_operation_type::logical_and)
{
mcontext.nodes->emplace_back(jump_if_zero{arg1});
@ -130,7 +132,57 @@ namespace pslang::ir
return arg1;
}
// Handle pointer arithmetic
auto arg2 = apply(*node.arg2);
auto arg1_type = get_type(*node.arg1);
auto arg2_type = get_type(*node.arg2);
auto arg1_is_pointer = types::is_pointer_type(*arg1_type);
auto arg2_is_pointer = types::is_pointer_type(*arg2_type);
if ((node.type == ast::binary_operation_type::addition || node.type == ast::binary_operation_type::subtraction)
&& (arg1_is_pointer || arg2_is_pointer))
{
// Pointer types are equal and referenced types are non-empty - guaranteed by type checker
std::int64_t element_size = 0;
if (arg1_is_pointer)
element_size = ast::type_size(*std::get<types::pointer_type>(*arg1_type).referenced_type);
else
element_size = ast::type_size(*std::get<types::pointer_type>(*arg2_type).referenced_type);
auto i64_type = std::make_shared<types::type>(types::primitive_type{types::i64_type{}});
mcontext.nodes->emplace_back(literal{ast::i64_literal{element_size}}, i64_type);
auto element_size_node = last();
if (node.type == ast::binary_operation_type::addition)
{
if (arg1_is_pointer)
{
mcontext.nodes->emplace_back(cast_operation{arg2, i64_type}, i64_type);
mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::multiplication, last(), element_size_node}, i64_type);
mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::addition, arg1, last()}, node.inferred_type);
}
else // if (arg2_is_pointer)
{
mcontext.nodes->emplace_back(cast_operation{arg1, i64_type}, i64_type);
mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::multiplication, last(), element_size_node}, i64_type);
mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::addition, arg2, last()}, node.inferred_type);
}
}
else if (node.type == ast::binary_operation_type::subtraction)
{
mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::subtraction, arg1, arg2}, node.inferred_type);
mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::division, last(), element_size_node}, i64_type);
}
return last();
}
// General case
mcontext.nodes->emplace_back(binary_operation{node.type, arg1, arg2}, node.inferred_type);
return last();
}
@ -170,13 +222,11 @@ namespace pslang::ir
node_ref apply(ast::array_access const & node)
{
auto array = apply(*node.array);
auto index = apply(*node.index);
auto array_type = ast::get_type(*node.array);
if (std::get_if<types::pointer_type>(array_type.get()))
if (types::is_pointer_type(*array_type))
{
mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::addition, array, index}, array_type);
mcontext.nodes->emplace_back(load{last()}, node.inferred_type);
auto new_ptr = apply(ast::binary_operation{ast::binary_operation_type::addition, node.array, node.index, {}, array_type});
mcontext.nodes->emplace_back(load{new_ptr}, node.inferred_type);
return last();
}
throw std::runtime_error("Unknown array access left-hand side");
@ -205,11 +255,9 @@ namespace pslang::ir
if (auto array_access = std::get_if<ast::array_access>(node.lhs.get()))
{
auto array_type = get_type(*array_access->array);
if (std::get_if<types::pointer_type>(array_type.get()))
if (types::is_pointer_type(*array_type))
{
auto base_ptr = apply(*array_access->array);
auto index = apply(*array_access->index);
mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::addition, base_ptr, index}, array_type);
apply(ast::binary_operation{ast::binary_operation_type::addition, array_access->array, array_access->index, {}, array_type});
mcontext.nodes->emplace_back(store{last(), rhs}, ast::get_type(*node.rhs));
return last();
}

View file

@ -40,7 +40,7 @@ namespace pslang::jit::aarch64
// register @reg_addr plus an signed 9-bit offset.
void sturh(std::uint8_t reg_src, std::uint8_t reg_addr, std::int16_t offset);
// Store the lowest 9 bits of the value of the register @reg_src at the address specified by the value of
// Store the lowest 8 bits of the value of the register @reg_src at the address specified by the value of
// register @reg_addr plus an signed 9-bit offset.
void sturb(std::uint8_t reg_src, std::uint8_t reg_addr, std::int16_t offset);

View file

@ -2,6 +2,7 @@
#include <pslang/jit/arch/aarch64/instruction_builder.hpp>
#include <pslang/ir/node.hpp>
#include <pslang/ir/compiler.hpp>
#include <pslang/ast/type.hpp>
#include <pslang/types/type_visitor.hpp>
#include <sstream>
@ -172,6 +173,11 @@ namespace pslang::jit::aarch64
std::uint8_t reg;
void apply(types::bool_type const &)
{
builder.ubfm(reg, reg, 8);
}
void apply(types::f16_type const &)
{}
void apply(types::f32_type const &)
@ -235,14 +241,42 @@ namespace pslang::jit::aarch64
store(it, 0);
}
void apply(ir::node_ref, ir::load const &, types::type_ptr const &)
void apply(ir::node_ref it, ir::load const & node, types::type_ptr const & type)
{
throw std::runtime_error("Not implemented");
// TODO: struct/array load?
load(node.ptr, 0);
auto size = ast::type_size(*type);
if (size == 1)
builder.ldurb(0, 0, 0);
else if (size == 2)
builder.ldurh(0, 0, 0);
else if (size == 4)
builder.ldurw(0, 0, 0);
else if (size == 8)
builder.ldur(0, 0, 0);
else
throw std::runtime_error(std::format("Unsupported load size: {}", size));
store(it, 0);
}
void apply(ir::node_ref, ir::store const &, types::type_ptr const &)
void apply(ir::node_ref, ir::store const & node, types::type_ptr const & type)
{
throw std::runtime_error("Not implemented");
// TODO: struct/array store?
load(node.ptr, 0);
load(node.value, 1);
auto size = ast::type_size(*type);
if (size == 1)
builder.sturb(1, 0, 0);
else if (size == 2)
builder.sturh(1, 0, 0);
else if (size == 4)
builder.sturw(1, 0, 0);
else if (size == 8)
builder.stur(1, 0, 0);
else
throw std::runtime_error(std::format("Unsupported store size: {}", size));
}
void apply(ir::node_ref it, ir::unary_operation const & node, types::type_ptr const & type)
@ -274,8 +308,14 @@ namespace pslang::jit::aarch64
break;
case ast::unary_operation_type::address_of:
case ast::unary_operation_type::mutable_address_of:
builder.add_imm(31, 0, stack_size - stack_position.at(node.arg1));
store(it, 0);
break;
case ast::unary_operation_type::dereference:
throw std::runtime_error("Not implemented");
load(node.arg1, 0);
builder.ldr(0, 0 ,0);
store(it, 0);
break;
}
}
@ -304,7 +344,8 @@ namespace pslang::jit::aarch64
else
{
builder.add_reg(0, 1, 0);
extend(0, type);
if (!types::is_pointer_type(*type))
extend(0, type);
}
break;
case ast::binary_operation_type::subtraction:
@ -313,7 +354,8 @@ namespace pslang::jit::aarch64
else
{
builder.sub_reg(0, 1, 0);
extend(0, type);
if (!types::is_pointer_type(*type))
extend(0, type);
}
break;
case ast::binary_operation_type::multiplication: