From 5d7968c30b843476f7923a168ea57b77e4ab28eb Mon Sep 17 00:00:00 2001 From: lisyarus Date: Wed, 25 Mar 2026 16:30:43 +0300 Subject: [PATCH] Implement IR + Aarch64 pointers --- examples/ir_test.psl | 23 +++++-- libs/ast/include/pslang/ast/type.hpp | 2 + libs/ast/source/type.cpp | 43 ++++++++++++ libs/ast/source/type_check.cpp | 18 +++-- libs/ir/source/compiler.cpp | 66 ++++++++++++++++--- .../jit/arch/aarch64/instruction_builder.hpp | 2 +- libs/jit/source/arch/aarch64/compiler_v2.cpp | 56 ++++++++++++++-- 7 files changed, 181 insertions(+), 29 deletions(-) diff --git a/examples/ir_test.psl b/examples/ir_test.psl index 8644e6a..e74e70c 100644 --- a/examples/ir_test.psl +++ b/examples/ir_test.psl @@ -7,14 +7,23 @@ func print32(n: u32): print32(n / 10u) print('0' + ((n % 10u) as u8)) -func factorial(n: u32) -> u32: - if n == 0u: - return 1u - return n * factorial(n - 1u) +func alloc(size: u64) -> unit mut*: + foreign func malloc(size: u64) -> unit mut* + return malloc(size) -foreign func sinf(x: f32) -> f32 +foreign func free(ptr: unit*) -print32(factorial(10u)) // 3628800 +let count = 30 +let array = alloc(4 * count as u64) as u32 mut* +array[0] = 0u +array[1] = 1u +print32(array[0]) print('\n') -print32(sinf(1.0) * 1000000.0 as u32) // 841471 +print32(array[1]) print('\n') +mut i = 2 +while i < count: + array[i] = array[i - 1] + array[i - 2] + print32(array[i]) + print('\n') + i = i + 1 diff --git a/libs/ast/include/pslang/ast/type.hpp b/libs/ast/include/pslang/ast/type.hpp index 789db70..95e9a42 100644 --- a/libs/ast/include/pslang/ast/type.hpp +++ b/libs/ast/include/pslang/ast/type.hpp @@ -59,4 +59,6 @@ namespace pslang::ast using type_impl::type_impl; }; + std::size_t type_size(types::type const & type); + } diff --git a/libs/ast/source/type.cpp b/libs/ast/source/type.cpp index b25fe1c..6b55fb5 100644 --- a/libs/ast/source/type.cpp +++ b/libs/ast/source/type.cpp @@ -1,6 +1,8 @@ #include #include +#include #include +#include namespace pslang::ast { @@ -44,6 +46,42 @@ namespace pslang::ast } }; + struct size_visitor + : types::const_visitor + { + using const_visitor::apply; + + std::size_t apply(types::unit_type const & type) + { + return 0; + } + + std::size_t apply(types::primitive_type const & type) + { + return types::type_size(type); + } + + std::size_t apply(types::array_type const & type) + { + return apply(*type.element_type) * type.size; + } + + std::size_t apply(types::function_type const &) + { + return 8; + } + + std::size_t apply(types::pointer_type const &) + { + return 8; + } + + std::size_t apply(types::struct_type const & type) + { + return type.node->layout.size; + } + }; + } types::type_ptr get_type(type const & type) @@ -51,4 +89,9 @@ namespace pslang::ast return get_type_visitor{}.apply(type); } + std::size_t type_size(types::type const & type) + { + return size_visitor{}.apply(type); + } + } diff --git a/libs/ast/source/type_check.cpp b/libs/ast/source/type_check.cpp index 61c5d80..e180935 100644 --- a/libs/ast/source/type_check.cpp +++ b/libs/ast/source/type_check.cpp @@ -349,14 +349,18 @@ namespace pslang::ast node.inferred_type = arg1_type; return; } - if (types::is_pointer_type(*arg1_type) && types::is_integer_type(*arg2_type)) + if (auto pointer_type = std::get_if(arg1_type.get()); pointer_type && types::is_integer_type(*arg2_type)) { + if (ast::type_size(*pointer_type->referenced_type) == 0) + throw type_error("Pointer arithmetic with empty types is not allowed", node.location); node.inferred_type = arg1_type; return; } - if (types::is_integer_type(*arg1_type) && types::is_pointer_type(*arg2_type)) + if (auto pointer_type = std::get_if(arg2_type.get()); pointer_type && types::is_integer_type(*arg1_type)) { - node.inferred_type = arg1_type; + if (ast::type_size(*pointer_type->referenced_type) == 0) + throw type_error("Pointer arithmetic with empty types is not allowed", node.location); + node.inferred_type = arg2_type; return; } break; @@ -366,13 +370,17 @@ namespace pslang::ast node.inferred_type = arg1_type; return; } - if (types::is_pointer_type(*arg1_type) && types::is_integer_type(*arg2_type)) + if (auto pointer_type = std::get_if(arg1_type.get()); pointer_type && types::is_integer_type(*arg2_type)) { + if (ast::type_size(*pointer_type->referenced_type) == 0) + throw type_error("Pointer arithmetic with empty types is not allowed", node.location); node.inferred_type = arg1_type; return; } - if (types::is_pointer_type(*arg1_type) && types::is_pointer_type(*arg2_type)) + if (types::is_pointer_type(*arg1_type) && types::is_pointer_type(*arg2_type) && types::equal(*arg1_type, *arg2_type)) { + if (ast::type_size(*std::get(*arg1_type).referenced_type) == 0) + throw type_error("Pointer arithmetic with empty types is not allowed", node.location); node.inferred_type = std::make_unique(types::primitive_type{types::i64_type{}}); return; } diff --git a/libs/ir/source/compiler.cpp b/libs/ir/source/compiler.cpp index 68d937e..ed3f472 100644 --- a/libs/ir/source/compiler.cpp +++ b/libs/ir/source/compiler.cpp @@ -105,6 +105,8 @@ namespace pslang::ir { auto arg1 = apply(*node.arg1); + // Handle short-circuit operators + if (node.type == ast::binary_operation_type::logical_and) { mcontext.nodes->emplace_back(jump_if_zero{arg1}); @@ -130,7 +132,57 @@ namespace pslang::ir return arg1; } + // Handle pointer arithmetic + auto arg2 = apply(*node.arg2); + + auto arg1_type = get_type(*node.arg1); + auto arg2_type = get_type(*node.arg2); + + auto arg1_is_pointer = types::is_pointer_type(*arg1_type); + auto arg2_is_pointer = types::is_pointer_type(*arg2_type); + + if ((node.type == ast::binary_operation_type::addition || node.type == ast::binary_operation_type::subtraction) + && (arg1_is_pointer || arg2_is_pointer)) + { + // Pointer types are equal and referenced types are non-empty - guaranteed by type checker + std::int64_t element_size = 0; + if (arg1_is_pointer) + element_size = ast::type_size(*std::get(*arg1_type).referenced_type); + else + element_size = ast::type_size(*std::get(*arg2_type).referenced_type); + + auto i64_type = std::make_shared(types::primitive_type{types::i64_type{}}); + + mcontext.nodes->emplace_back(literal{ast::i64_literal{element_size}}, i64_type); + auto element_size_node = last(); + + if (node.type == ast::binary_operation_type::addition) + { + if (arg1_is_pointer) + { + mcontext.nodes->emplace_back(cast_operation{arg2, i64_type}, i64_type); + mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::multiplication, last(), element_size_node}, i64_type); + mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::addition, arg1, last()}, node.inferred_type); + } + else // if (arg2_is_pointer) + { + mcontext.nodes->emplace_back(cast_operation{arg1, i64_type}, i64_type); + mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::multiplication, last(), element_size_node}, i64_type); + mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::addition, arg2, last()}, node.inferred_type); + } + } + else if (node.type == ast::binary_operation_type::subtraction) + { + mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::subtraction, arg1, arg2}, node.inferred_type); + mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::division, last(), element_size_node}, i64_type); + } + + return last(); + } + + // General case + mcontext.nodes->emplace_back(binary_operation{node.type, arg1, arg2}, node.inferred_type); return last(); } @@ -170,13 +222,11 @@ namespace pslang::ir node_ref apply(ast::array_access const & node) { - auto array = apply(*node.array); - auto index = apply(*node.index); auto array_type = ast::get_type(*node.array); - if (std::get_if(array_type.get())) + if (types::is_pointer_type(*array_type)) { - mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::addition, array, index}, array_type); - mcontext.nodes->emplace_back(load{last()}, node.inferred_type); + auto new_ptr = apply(ast::binary_operation{ast::binary_operation_type::addition, node.array, node.index, {}, array_type}); + mcontext.nodes->emplace_back(load{new_ptr}, node.inferred_type); return last(); } throw std::runtime_error("Unknown array access left-hand side"); @@ -205,11 +255,9 @@ namespace pslang::ir if (auto array_access = std::get_if(node.lhs.get())) { auto array_type = get_type(*array_access->array); - if (std::get_if(array_type.get())) + if (types::is_pointer_type(*array_type)) { - auto base_ptr = apply(*array_access->array); - auto index = apply(*array_access->index); - mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::addition, base_ptr, index}, array_type); + apply(ast::binary_operation{ast::binary_operation_type::addition, array_access->array, array_access->index, {}, array_type}); mcontext.nodes->emplace_back(store{last(), rhs}, ast::get_type(*node.rhs)); return last(); } diff --git a/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp b/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp index 107b01d..bab2792 100644 --- a/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp +++ b/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp @@ -40,7 +40,7 @@ namespace pslang::jit::aarch64 // register @reg_addr plus an signed 9-bit offset. void sturh(std::uint8_t reg_src, std::uint8_t reg_addr, std::int16_t offset); - // Store the lowest 9 bits of the value of the register @reg_src at the address specified by the value of + // Store the lowest 8 bits of the value of the register @reg_src at the address specified by the value of // register @reg_addr plus an signed 9-bit offset. void sturb(std::uint8_t reg_src, std::uint8_t reg_addr, std::int16_t offset); diff --git a/libs/jit/source/arch/aarch64/compiler_v2.cpp b/libs/jit/source/arch/aarch64/compiler_v2.cpp index 4bb45be..69ceefb 100644 --- a/libs/jit/source/arch/aarch64/compiler_v2.cpp +++ b/libs/jit/source/arch/aarch64/compiler_v2.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include @@ -172,6 +173,11 @@ namespace pslang::jit::aarch64 std::uint8_t reg; void apply(types::bool_type const &) + { + builder.ubfm(reg, reg, 8); + } + + void apply(types::f16_type const &) {} void apply(types::f32_type const &) @@ -235,14 +241,42 @@ namespace pslang::jit::aarch64 store(it, 0); } - void apply(ir::node_ref, ir::load const &, types::type_ptr const &) + void apply(ir::node_ref it, ir::load const & node, types::type_ptr const & type) { - throw std::runtime_error("Not implemented"); + // TODO: struct/array load? + load(node.ptr, 0); + auto size = ast::type_size(*type); + + if (size == 1) + builder.ldurb(0, 0, 0); + else if (size == 2) + builder.ldurh(0, 0, 0); + else if (size == 4) + builder.ldurw(0, 0, 0); + else if (size == 8) + builder.ldur(0, 0, 0); + else + throw std::runtime_error(std::format("Unsupported load size: {}", size)); + + store(it, 0); } - void apply(ir::node_ref, ir::store const &, types::type_ptr const &) + void apply(ir::node_ref, ir::store const & node, types::type_ptr const & type) { - throw std::runtime_error("Not implemented"); + // TODO: struct/array store? + load(node.ptr, 0); + load(node.value, 1); + auto size = ast::type_size(*type); + if (size == 1) + builder.sturb(1, 0, 0); + else if (size == 2) + builder.sturh(1, 0, 0); + else if (size == 4) + builder.sturw(1, 0, 0); + else if (size == 8) + builder.stur(1, 0, 0); + else + throw std::runtime_error(std::format("Unsupported store size: {}", size)); } void apply(ir::node_ref it, ir::unary_operation const & node, types::type_ptr const & type) @@ -274,8 +308,14 @@ namespace pslang::jit::aarch64 break; case ast::unary_operation_type::address_of: case ast::unary_operation_type::mutable_address_of: + builder.add_imm(31, 0, stack_size - stack_position.at(node.arg1)); + store(it, 0); + break; case ast::unary_operation_type::dereference: - throw std::runtime_error("Not implemented"); + load(node.arg1, 0); + builder.ldr(0, 0 ,0); + store(it, 0); + break; } } @@ -304,7 +344,8 @@ namespace pslang::jit::aarch64 else { builder.add_reg(0, 1, 0); - extend(0, type); + if (!types::is_pointer_type(*type)) + extend(0, type); } break; case ast::binary_operation_type::subtraction: @@ -313,7 +354,8 @@ namespace pslang::jit::aarch64 else { builder.sub_reg(0, 1, 0); - extend(0, type); + if (!types::is_pointer_type(*type)) + extend(0, type); } break; case ast::binary_operation_type::multiplication: