From 42e7f7961e39f991b61d8ffc1ba06edea214d21e Mon Sep 17 00:00:00 2001 From: lisyarus Date: Wed, 25 Mar 2026 00:22:50 +0300 Subject: [PATCH] New IR -> Aarch64 compiler wip: basic operations done (no pointers, structs, arrays) --- apps/interpreter/source/main.cpp | 62 +- examples/ir_test.psl | 28 +- libs/ast/include/pslang/ast/preprocess.hpp | 6 +- libs/ast/source/resolve_identifiers.cpp | 4 +- libs/ast/source/type_check.cpp | 4 +- libs/ast/source/validate.cpp | 4 +- libs/ir/include/pslang/ir/compiler.hpp | 10 +- libs/ir/include/pslang/ir/node.hpp | 21 +- libs/ir/source/compiler.cpp | 61 +- libs/ir/source/node.cpp | 55 ++ libs/ir/source/print.cpp | 4 +- .../pslang/jit/arch/aarch64/compiler.hpp | 7 +- .../jit/arch/aarch64/instruction_builder.hpp | 17 +- libs/jit/include/pslang/jit/jit.hpp | 2 + .../include/pslang/jit/program_context.hpp | 9 +- libs/jit/source/arch/aarch64/compiler.cpp | 4 +- libs/jit/source/arch/aarch64/compiler_v2.cpp | 753 ++++++++++++++++++ .../arch/aarch64/instruction_builder.cpp | 14 +- libs/jit/source/jit.cpp | 28 +- libs/parser/include/pslang/parser/parser.hpp | 2 +- libs/parser/source/parser.cpp | 18 +- libs/types/include/pslang/types/type_fwd.hpp | 4 + libs/types/source/type.cpp | 10 + 23 files changed, 1016 insertions(+), 111 deletions(-) create mode 100644 libs/ir/source/node.cpp create mode 100644 libs/jit/source/arch/aarch64/compiler_v2.cpp diff --git a/apps/interpreter/source/main.cpp b/apps/interpreter/source/main.cpp index 2eada0c..26f285b 100644 --- a/apps/interpreter/source/main.cpp +++ b/apps/interpreter/source/main.cpp @@ -95,7 +95,8 @@ int main(int argc, char ** argv) bool jit = false; std::vector filenames; - std::vector parsed; + std::vector parsed; + std::vector ir_compiled; bool no_more_options = false; @@ -181,11 +182,14 @@ int main(int argc, char ** argv) try { filenames.push_back(argv[arg]); - auto ast = parser::parse(filenames.back()); - ast::resolve_identifiers(ast); - ast::check_and_infer_types(ast); - ast::validate(ast); - parsed.push_back(std::move(ast)); + auto root = parser::parse(filenames.back()); + ast::resolve_identifiers(root); + ast::check_and_infer_types(root); + ast::validate(root); + parsed.push_back(std::move(root)); + + ir_compiled.emplace_back(); + ir::compile(ir_compiled.back(), parsed.back()); } catch (ast::parse_error const & error) { @@ -230,52 +234,52 @@ int main(int argc, char ** argv) for (std::size_t i = 0; i < filenames.size(); ++i) { std::cout << "Input file " << filenames[i] << " AST dump:\n\n"; - ast::print(std::cout, *parsed[i]); + if (auto function_definition = std::get_if(parsed[i].get())) + ast::print(std::cout, *function_definition->statements); std::cout << "\n"; } std::cout << std::flush; } - // NB: IR isn't used in JIT or interpreter right now. if (dump_ir) { for (std::size_t i = 0; i < filenames.size(); ++i) { std::cout << "Input file " << filenames[i] << " IR dump:\n\n"; - ir::module_context context; - ir::compile(context, parsed[i]); - ir::print(std::cout, context); + ir::print(std::cout, ir_compiled[i]); std::cout << "\n"; } + std::cout << std::flush; } if (jit) { - jit::program_context pcontext + // TODO: treat all input files as modules combined into a single program + for (std::size_t i = 0; i < filenames.size(); ++i) { - .abi = jit::host_abi(), - }; + jit::program_context pcontext + { + .abi = jit::host_abi(), + }; - for (auto const & ast : parsed) - jit::compile(pcontext, ast); + jit::compile(pcontext, ir_compiled[i]); - for (auto const & resolve : pcontext.foreign_resolve) - { - auto fptr = jit::load_foreign(resolve.name); - std::copy_n((std::uint8_t const *)(&fptr), 8, pcontext.code.data() + resolve.offset); + for (auto const & resolve : pcontext.foreign_resolve) + { + auto fptr = jit::load_foreign(resolve.name); + std::copy_n((std::uint8_t const *)(&fptr), 8, pcontext.code.data() + resolve.offset); + } + + auto executable = jit::make_host_executable(pcontext.code); + + auto entry_point = (void(*)())(executable.data.get() + pcontext.entry_point); + entry_point(); } - - auto executable = jit::make_host_executable(pcontext.code); - - // TODO: multiple input files => multiple entry points? - // Should probably compile them as separate modules. - auto entry_point = (void(*)())(executable.data.get() + pcontext.entry_point); - entry_point(); } else { - for (auto const & ast : parsed) - interpreter::exec(context, ast); + // for (auto const & ast : parsed) + // interpreter::exec(context, ast); if (dump) interpreter::dump(std::cout, context); diff --git a/examples/ir_test.psl b/examples/ir_test.psl index 93bbd54..8644e6a 100644 --- a/examples/ir_test.psl +++ b/examples/ir_test.psl @@ -1,12 +1,20 @@ -func alloc(size: u64) -> unit mut*: - foreign func malloc(size: u64) -> unit mut* - return malloc(size) +func print(c: u8): + foreign func putchar(c: i32) -> i32 + putchar(c as i32) -foreign func free(ptr: unit*) +func print32(n: u32): + if n >= 10u: + print32(n / 10u) + print('0' + ((n % 10u) as u8)) -let array = alloc(400ul) as i32 mut* -*array = 10 -*(array + 1) = 20 -let test = array[10] -array[15] = 1500 -free(array as unit*) +func factorial(n: u32) -> u32: + if n == 0u: + return 1u + return n * factorial(n - 1u) + +foreign func sinf(x: f32) -> f32 + +print32(factorial(10u)) // 3628800 +print('\n') +print32(sinf(1.0) * 1000000.0 as u32) // 841471 +print('\n') diff --git a/libs/ast/include/pslang/ast/preprocess.hpp b/libs/ast/include/pslang/ast/preprocess.hpp index 8899676..1e2510d 100644 --- a/libs/ast/include/pslang/ast/preprocess.hpp +++ b/libs/ast/include/pslang/ast/preprocess.hpp @@ -5,8 +5,8 @@ namespace pslang::ast { - void resolve_identifiers(statement_list_ptr & statements); - void check_and_infer_types(statement_list_ptr & statements); - void validate(statement_list_ptr & statements); + void resolve_identifiers(statement_ptr & statements); + void check_and_infer_types(statement_ptr & statements); + void validate(statement_ptr & statements); } diff --git a/libs/ast/source/resolve_identifiers.cpp b/libs/ast/source/resolve_identifiers.cpp index 879d010..a33d402 100644 --- a/libs/ast/source/resolve_identifiers.cpp +++ b/libs/ast/source/resolve_identifiers.cpp @@ -375,12 +375,12 @@ namespace pslang::ast } - void resolve_identifiers(statement_list_ptr & statements) + void resolve_identifiers(statement_ptr & root) { std::vector scopes; scopes.emplace_back(); resolve_identifiers_visitor visitor{{}, {}, {}, scopes}; - visitor.apply(*statements); + visitor.apply(*root); } } diff --git a/libs/ast/source/type_check.cpp b/libs/ast/source/type_check.cpp index 8d876f8..61c5d80 100644 --- a/libs/ast/source/type_check.cpp +++ b/libs/ast/source/type_check.cpp @@ -903,11 +903,11 @@ namespace pslang::ast } - void check_and_infer_types(statement_list_ptr & statements) + void check_and_infer_types(statement_ptr & root) { local_context lcontext; check_visitor visitor{{}, {}, lcontext}; - visitor.apply(*statements); + visitor.apply(*root); for (auto & struct_data : lcontext.structs) compute_layout(lcontext, struct_data.first, struct_data.second); diff --git a/libs/ast/source/validate.cpp b/libs/ast/source/validate.cpp index 56e183a..90b8675 100644 --- a/libs/ast/source/validate.cpp +++ b/libs/ast/source/validate.cpp @@ -52,9 +52,9 @@ namespace pslang::ast } - void validate(statement_list_ptr & statements) + void validate(statement_ptr & root) { - validate_visitor{}.apply(*statements); + validate_visitor{}.apply(*root); } } diff --git a/libs/ir/include/pslang/ir/compiler.hpp b/libs/ir/include/pslang/ir/compiler.hpp index 2eb98b6..8164b79 100644 --- a/libs/ir/include/pslang/ir/compiler.hpp +++ b/libs/ir/include/pslang/ir/compiler.hpp @@ -21,10 +21,16 @@ namespace pslang::ir std::unordered_map labels; - std::unordered_map symbols; + struct symbol_info + { + node_ref begin; + node_ref end; + }; + + std::unordered_map symbols; node_ref entry_point; }; - void compile(module_context & context, ast::statement_list_ptr const & statements); + void compile(module_context & context, ast::statement_ptr const & root); } diff --git a/libs/ir/include/pslang/ir/node.hpp b/libs/ir/include/pslang/ir/node.hpp index 522960b..95f5522 100644 --- a/libs/ir/include/pslang/ir/node.hpp +++ b/libs/ir/include/pslang/ir/node.hpp @@ -6,12 +6,13 @@ #include #include +#include namespace pslang::ir { // Used primarily as jump target - struct nop + struct label {}; struct literal @@ -108,7 +109,7 @@ namespace pslang::ir }; using instruction = std::variant< - nop, + label, literal, copy, load, @@ -135,4 +136,20 @@ namespace pslang::ir types::type_ptr inferred_type = nullptr; }; + bool is_value_instruction(instruction const & instruction); + +} + +namespace std +{ + + template <> + struct hash<::pslang::ir::node_ref> + { + std::size_t operator()(pslang::ir::node_ref const & ref) const + { + return std::hash()(ref.operator->()); + } + }; + } diff --git a/libs/ir/source/compiler.cpp b/libs/ir/source/compiler.cpp index 5cd3572..68d937e 100644 --- a/libs/ir/source/compiler.cpp +++ b/libs/ir/source/compiler.cpp @@ -13,8 +13,7 @@ namespace pslang::ir struct local_context { - std::unordered_map functions; - std::unordered_map foreign_functions; + std::unordered_map> functions; std::unordered_map variables; struct scope @@ -113,7 +112,7 @@ namespace pslang::ir auto arg2 = apply(*node.arg2); mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::binary_and, arg1, arg2}, node.inferred_type); mcontext.nodes->emplace_back(assignment{arg1, last()}); - mcontext.nodes->emplace_back(nop{}); + mcontext.nodes->emplace_back(label{}); std::get(jump->instruction).target = last(); return arg1; } @@ -126,7 +125,7 @@ namespace pslang::ir auto arg2 = apply(*node.arg2); mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::binary_or, arg1, arg2}, node.inferred_type); mcontext.nodes->emplace_back(assignment{arg1, last()}); - mcontext.nodes->emplace_back(nop{}); + mcontext.nodes->emplace_back(label{}); std::get(jump->instruction).target = last(); return arg1; } @@ -266,7 +265,7 @@ namespace pslang::ir mcontext.nodes->emplace_back(jump{}); jumps_to_end.push_back(last()); - mcontext.nodes->emplace_back(nop{}); + mcontext.nodes->emplace_back(label{}); if (jump_to_next) std::get((*jump_to_next)->instruction).target = last(); } @@ -280,7 +279,8 @@ namespace pslang::ir node_ref apply(ast::while_block const & node) { - auto before = last(); + mcontext.nodes->emplace_back(label{}); + auto begin = last(); auto condition = apply(*node.condition); mcontext.nodes->emplace_back(jump_if_zero{condition, {}}); auto jump1 = last(); @@ -289,8 +289,8 @@ namespace pslang::ir apply(*node.statements); lcontext.scopes.pop_back(); - mcontext.nodes->emplace_back(jump{std::next(before)}); - mcontext.nodes->emplace_back(nop{}); + mcontext.nodes->emplace_back(jump{begin}); + mcontext.nodes->emplace_back(label{}); std::get(jump1->instruction).target = last(); return last(); } @@ -326,7 +326,8 @@ namespace pslang::ir void do_apply(ast::function_definition const & node) { - auto before = last(); + mcontext.nodes->emplace_back(label{}); + auto begin = last(); lcontext.scopes.emplace_back(); for (std::size_t i = 0; i < node.arguments.size(); ++i) @@ -339,11 +340,12 @@ namespace pslang::ir if (types::equal(*ast::get_type(*node.return_type), types::unit_type{})) mcontext.nodes->emplace_back(return_value{}, std::make_shared(types::unit_type{})); - lcontext.scopes.pop_back(); + auto end = last(); - auto entry = std::next(before); - lcontext.functions[&node] = entry; - mcontext.labels[&(*entry)] = lcontext.scopes.back().label_prefix + node.name; + lcontext.functions[&node] = {begin, end}; + mcontext.labels[&(*begin)] = lcontext.scopes.back().label_prefix + node.name; + + lcontext.scopes.pop_back(); } private: @@ -408,40 +410,31 @@ namespace pslang::ir } - void compile(module_context & mcontext, ast::statement_list_ptr const & statements) + void compile(module_context & mcontext, ast::statement_ptr const & root) { if (!mcontext.nodes) mcontext.nodes = std::make_shared(); - // Add a fake AST node for the entry point - auto root = std::make_shared(ast::function_definition{ - { - {}, - {}, - std::make_shared(types::unit_type{}), - {}, - }, - statements, - {}, - }); - - mcontext.nodes->emplace_back(nop{}); - auto extra_nop = std::prev(mcontext.nodes->end()); local_context lcontext; + + mcontext.nodes->emplace_back(label{}); + auto extra_label = std::prev(mcontext.nodes->end()); compile_visitor{{}, mcontext, lcontext}.apply(*root); - mcontext.labels[&(*std::next(extra_nop))] = "[entry point]"; - mcontext.nodes->erase(extra_nop); + mcontext.nodes->erase(extra_label); + + for (auto & symbol : lcontext.functions) + symbol.second.second++; for (auto const & resolve : lcontext.resolve_address) - std::get(resolve.address->instruction).target = lcontext.functions.at(resolve.target); + std::get(resolve.address->instruction).target = lcontext.functions.at(resolve.target).first; for (auto const & resolve : lcontext.resolve_call) - std::get(resolve.call->instruction).target = lcontext.functions.at(resolve.target); + std::get(resolve.call->instruction).target = lcontext.functions.at(resolve.target).first; for (auto const & symbol : lcontext.functions) - mcontext.symbols[symbol.first] = symbol.second; + mcontext.symbols[symbol.first] = {.begin = symbol.second.first, .end = symbol.second.second}; - mcontext.entry_point = lcontext.functions.at(std::get_if(root.get())); + mcontext.entry_point = lcontext.functions.at(std::get_if(root.get())).first; } } diff --git a/libs/ir/source/node.cpp b/libs/ir/source/node.cpp new file mode 100644 index 0000000..9228d29 --- /dev/null +++ b/libs/ir/source/node.cpp @@ -0,0 +1,55 @@ +#include + +namespace pslang::ir +{ + + namespace + { + + struct visitor + { + template + bool operator()(Node const &) + { + return true; + } + + bool operator()(label const &) + { + return false; + } + + bool operator()(store const &) + { + return false; + } + + bool operator()(assignment const &) + { + return false; + } + + bool operator()(jump const &) + { + return false; + } + + bool operator()(jump_if_zero const &) + { + return false; + } + + bool operator()(return_value const &) + { + return false; + } + }; + + } + + bool is_value_instruction(instruction const & instruction) + { + return std::visit(visitor{}, instruction); + } + +} diff --git a/libs/ir/source/print.cpp b/libs/ir/source/print.cpp index a33cd3f..2dfc27d 100644 --- a/libs/ir/source/print.cpp +++ b/libs/ir/source/print.cpp @@ -188,9 +188,9 @@ namespace pslang::ir out << std::right << std::setfill(' ') << std::setw(indent); } - void operator()(nop const &) + void operator()(label const &) { - out << "nop"; + out << "label"; } void operator()(literal const & instruction) diff --git a/libs/jit/include/pslang/jit/arch/aarch64/compiler.hpp b/libs/jit/include/pslang/jit/arch/aarch64/compiler.hpp index ce8946e..550cd82 100644 --- a/libs/jit/include/pslang/jit/arch/aarch64/compiler.hpp +++ b/libs/jit/include/pslang/jit/arch/aarch64/compiler.hpp @@ -1,10 +1,13 @@ #pragma once #include +#include namespace pslang::jit::aarch64 { - void compile(program_context & context, ast::statement_list_ptr const & statements); + void compile(program_context & context, ast::statement_list_ptr const & statements); -} \ No newline at end of file + void compile(program_context & pcontext, ir::module_context const & mcontext); + +} diff --git a/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp b/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp index 97749fc..107b01d 100644 --- a/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp +++ b/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp @@ -149,10 +149,19 @@ namespace pslang::jit::aarch64 // 26-bit signed @offset multiplied by 4 void b(std::int32_t offset); + // Unconditionally move the program counter to the value of + // 26-bit signed @offset multiplied by 4, and store the address of the next + // instruction in link register (x30) + void bl(std::int32_t offset); + // Unconditionally move the program counter to the value of the register @reg void b_reg(std::uint8_t reg); - // Inject the 26-bit signed @offset into the opcode of b instruction + // Unconditionally move the program counter to the value of the register @reg + // and store the address of the next instruction in link register (x30) + void bl_reg(std::uint8_t reg); + + // Inject the 26-bit signed @offset into the opcode of b or bl instruction // starting at @opcode void b_inject(std::uint8_t * opcode, std::int32_t offset); @@ -176,12 +185,12 @@ namespace pslang::jit::aarch64 // register @reg_dst. @mode should be 1 for 16-bit values, 2 for 32-bit, 3 for 64-bit void ldur_fp(std::uint8_t reg_dst, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset); - // Store a floating-point value from @reg_src into the address stored in + // Store a floating-point value from @reg_src into the address stored in // @reg_addr plus a 12-bit unsigned @offset multiplied by sizeof the type (2, 4 or 8). // @mode should be 1 for 16-bit values, 2 for 32-bit, 3 for 64-bit void str_fp(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset); - // Store a floating-point value from @reg_src into the address stored in + // Store a floating-point value from @reg_src into the address stored in // @reg_addr plus a 9-bit signed @offset. // @mode should be 1 for 16-bit values, 2 for 32-bit, 3 for 64-bit void stur_fp(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_addr, std::int16_t offset); @@ -226,4 +235,4 @@ namespace pslang::jit::aarch64 void do_push(std::uint32_t opcode); }; -} \ No newline at end of file +} diff --git a/libs/jit/include/pslang/jit/jit.hpp b/libs/jit/include/pslang/jit/jit.hpp index f1e4ea7..1db055e 100644 --- a/libs/jit/include/pslang/jit/jit.hpp +++ b/libs/jit/include/pslang/jit/jit.hpp @@ -2,10 +2,12 @@ #include #include +#include namespace pslang::jit { void compile(program_context & context, ast::statement_list_ptr const & statements); + void compile(program_context & pcontext, ir::module_context const & mcontext); } diff --git a/libs/jit/include/pslang/jit/program_context.hpp b/libs/jit/include/pslang/jit/program_context.hpp index 90a9516..8ea9240 100644 --- a/libs/jit/include/pslang/jit/program_context.hpp +++ b/libs/jit/include/pslang/jit/program_context.hpp @@ -7,6 +7,13 @@ #include #include +namespace pslang::ast +{ + + struct function_definition; + +} + namespace pslang::jit { @@ -22,7 +29,7 @@ namespace pslang::jit jit::abi abi; std::vector code = {}; - std::unordered_map symbols = {}; + std::unordered_map symbols = {}; std::int32_t entry_point = 0; std::vector foreign_resolve = {}; }; diff --git a/libs/jit/source/arch/aarch64/compiler.cpp b/libs/jit/source/arch/aarch64/compiler.cpp index 4b273af..e07eaf8 100644 --- a/libs/jit/source/arch/aarch64/compiler.cpp +++ b/libs/jit/source/arch/aarch64/compiler.cpp @@ -901,7 +901,7 @@ namespace pslang::jit::aarch64 pop(reg); push(30); - builder.b_reg(reg); + builder.bl_reg(reg); pop(30); } else // if (node.type) @@ -1417,7 +1417,7 @@ namespace pslang::jit::aarch64 pcontext.foreign_resolve.push_back({foreign.first, foreign.second}); for (auto const & function : lcontext.scopes.front().functions) - pcontext.symbols[function.first] = lcontext.functions.at(function.second); + pcontext.symbols[function.second] = lcontext.functions.at(function.second); pcontext.entry_point = lcontext.functions.at(std::get_if(root.get())); } diff --git a/libs/jit/source/arch/aarch64/compiler_v2.cpp b/libs/jit/source/arch/aarch64/compiler_v2.cpp new file mode 100644 index 0000000..4bb45be --- /dev/null +++ b/libs/jit/source/arch/aarch64/compiler_v2.cpp @@ -0,0 +1,753 @@ +#include +#include +#include +#include +#include + +#include + +namespace pslang::jit::aarch64 +{ + + namespace + { + + struct local_context + { + std::unordered_map extern_symbols; + std::unordered_map nodes; + + std::unordered_map f16_constants; + std::unordered_map f32_constants; + std::unordered_map f64_constants; + + struct resolve_data + { + std::int32_t offset; + ir::node_ref target; + }; + + std::vector branch_resolve; + std::vector cbranch_resolve; + std::vector adr_resolve; + }; + + std::uint8_t fp_mode_for(types::type const & type) + { + if (types::equal(type, types::primitive_type(types::f16_type{}))) + return 1; + if (types::equal(type, types::primitive_type(types::f32_type{}))) + return 2; + return 3; + } + + std::int32_t fp_size(std::uint8_t mode) + { + return 1 << mode; + } + + struct populate_const_data_visitor + { + program_context & pcontext; + local_context & lcontext; + + template + void apply(Node const & node, types::type_ptr const &) + {} + + void apply(ir::literal const & node, types::type_ptr const &) + { + if (auto f16_literal = std::get_if(&node.value)) + { + lcontext.f16_constants[f16_literal->value.repr] = pcontext.code.size(); + push_bytes(f16_literal->value.repr); + } + else if (auto f32_literal = std::get_if(&node.value)) + { + lcontext.f32_constants[f32_literal->value] = pcontext.code.size(); + push_bytes(f32_literal->value); + } + else if (auto f64_literal = std::get_if(&node.value)) + { + lcontext.f32_constants[f64_literal->value] = pcontext.code.size(); + push_bytes(f64_literal->value); + } + } + + void apply(ir::extern_symbol const & node, types::type_ptr const &) + { + std::int32_t offset = pcontext.code.size(); + lcontext.extern_symbols[node.name] = offset; + pcontext.foreign_resolve.push_back({node.name, offset}); + push_bytes(nullptr); + } + + private: + template + void push_bytes(T const & value) + { + auto begin = (std::uint8_t const *)(&value); + auto end = begin + sizeof(value); + pcontext.code.insert(pcontext.code.end(), begin, end); + } + }; + + // Set register @reg to -1 (all bits = 1) + void set_m1(instruction_builder & builder, std::uint8_t reg) + { + builder.or_not_reg(31, 31, reg); + } + + struct literal_visitor + { + program_context & pcontext; + local_context & lcontext; + instruction_builder & builder; + + void operator()(ast::bool_literal const & node) + { + if (node.value) + set_m1(builder, 0); + else + builder.movz(0, 0); + } + + template + requires(std::is_integral_v && !std::is_same_v) + void operator()(ast::primitive_literal_base const & node) + { + for (std::size_t i = 0; i < sizeof(T); i += 2) + { + if (i == 0) + { + builder.movz(0, std::uint64_t(node.value)); + } + else + { + auto val = std::uint16_t(std::uint64_t(node.value) >> (i * 8)); + if (val != 0) + builder.movk(0, val, i / 2); + } + } + + if (sizeof(T) < 8) + { + if constexpr (std::is_signed_v) + { + if (node.value < 0) + builder.sbfm(0, 0, sizeof(T) * 8); + } + } + } + + void operator()(ast::f16_literal const & node) + { + auto offset = lcontext.f16_constants.at(node.value.repr); + std::int32_t current = pcontext.code.size(); + builder.ldr_fp_pc(0, 0, (offset - current) / 4); + builder.fcvt(0, 0b10, 0, 0b01); + } + + void operator()(ast::f32_literal const & node) + { + auto offset = lcontext.f32_constants.at(node.value); + std::int32_t current = pcontext.code.size(); + builder.ldr_fp_pc(0, 0, (offset - current) / 4); + } + + void operator()(ast::f64_literal const & node) + { + auto offset = lcontext.f64_constants.at(node.value); + std::int32_t current = pcontext.code.size(); + builder.ldr_fp_pc(0, 1, (offset - current) / 4); + } + }; + + struct reg_extend_visitor + : types::const_visitor + { + using const_visitor::apply; + + instruction_builder & builder; + std::uint8_t reg; + + void apply(types::bool_type const &) + {} + + void apply(types::f32_type const &) + {} + + void apply(types::f64_type const &) + {} + + template + void apply(types::primitive_type_base const &) + { + if constexpr (sizeof(T) == 8) + { + return; + } + + if constexpr (std::is_signed_v) + { + builder.sbfm(reg, reg, sizeof(T) * 8); + } + + if constexpr (std::is_unsigned_v) + { + builder.ubfm(reg, reg, sizeof(T) * 8); + } + } + + template + void apply(T const &) + { + throw std::runtime_error(std::string("reg_extend_visitor is not implemented for ") + typeid(T).name()); + } + }; + + struct compile_visitor + { + program_context & pcontext; + ir::module_context const & mcontext; + local_context & lcontext; + instruction_builder & builder; + + std::unordered_map stack_position; + std::int32_t stack_size = 0; + + void apply(ir::node_ref, ir::label const &, types::type_ptr const &) + {} + + void apply(ir::node_ref it, ir::literal const & node, types::type_ptr const & type) + { + std::visit(literal_visitor{pcontext, lcontext, builder}, node.value); + if (types::is_integer_like_type(*type)) + store(it, 0); + else if (types::is_floating_point_type(*type)) + store_fp(it, 0, fp_mode_for(*type)); + } + + void apply(ir::node_ref it, ir::copy const & node, types::type_ptr const &) + { + // TODO: struct/array copy? + load(node.source, 0); + store(it, 0); + } + + void apply(ir::node_ref, ir::load const &, types::type_ptr const &) + { + throw std::runtime_error("Not implemented"); + } + + void apply(ir::node_ref, ir::store const &, types::type_ptr const &) + { + throw std::runtime_error("Not implemented"); + } + + void apply(ir::node_ref it, ir::unary_operation const & node, types::type_ptr const & type) + { + switch (node.type) + { + case ast::unary_operation_type::negation: + if (types::is_integer_type(*type)) + { + load(node.arg1, 0); + builder.sub_reg(31, 0, 0); + extend(0, type); + store(it, 0); + } + else if (types::is_floating_point_type(*type)) + { + auto mode = fp_mode_for(*type); + load_fp(it, 0, mode); + builder.fneg(0, mode, 0); + store_fp(it, 0, mode); + } + break; + case ast::unary_operation_type::logical_not: + load(node.arg1, 0); + builder.or_not_reg(31, 0, 0); + if (types::is_integer_type(*type)) + extend(0, type); + store(it, 0); + break; + case ast::unary_operation_type::address_of: + case ast::unary_operation_type::mutable_address_of: + case ast::unary_operation_type::dereference: + throw std::runtime_error("Not implemented"); + } + } + + void apply(ir::node_ref it, ir::binary_operation const & node, types::type_ptr const & type) + { + auto arg1_type = node.arg1->inferred_type; + bool const is_fp = types::is_floating_point_type(*arg1_type); + std::uint8_t const fp_mode = fp_mode_for(*arg1_type); + + if (is_fp) + { + load_fp(node.arg1, 0, fp_mode); + load_fp(node.arg2, 1, fp_mode); + } + else + { + load(node.arg1, 0); + load(node.arg2, 1); + } + + switch (node.type) + { + case ast::binary_operation_type::addition: + if (is_fp) + builder.fadd(0, 1, fp_mode, 0); + else + { + builder.add_reg(0, 1, 0); + extend(0, type); + } + break; + case ast::binary_operation_type::subtraction: + if (is_fp) + builder.fsub(0, 1, fp_mode, 0); + else + { + builder.sub_reg(0, 1, 0); + extend(0, type); + } + break; + case ast::binary_operation_type::multiplication: + if (is_fp) + builder.fmul(0, 1, fp_mode, 0); + else + { + builder.mul_reg(0, 1, 0); + extend(0, type); + } + break; + case ast::binary_operation_type::division: + if (is_fp) + builder.fdiv(0, 1, fp_mode, 0); + else + { + if (types::is_signed_integer_type(*type)) + builder.sdiv_reg(0, 1, 0); + else + builder.udiv_reg(0, 1, 0); + extend(0, type); + } + break; + case ast::binary_operation_type::remainder: + if (types::is_signed_integer_type(*type)) + { + builder.sdiv_reg(0, 1, 2); + builder.mul_reg(1, 2, 1); + builder.sub_reg(0, 1, 0); + } + else if (types::is_unsigned_integer_type(*type)) + { + builder.udiv_reg(0, 1, 2); + builder.mul_reg(1, 2, 1); + builder.sub_reg(0, 1, 0); + } + break; + case ast::binary_operation_type::binary_and: + builder.and_reg(0, 1, 0); + break; + case ast::binary_operation_type::logical_and: + throw std::runtime_error("Short-circuiting operators must have been unwrapped in IR compiler"); + case ast::binary_operation_type::binary_or: + builder.or_reg(0, 1, 0); + break; + case ast::binary_operation_type::logical_or: + throw std::runtime_error("Short-circuiting operators must have been unwrapped in IR compiler"); + case ast::binary_operation_type::logical_xor: + builder.xor_reg(0, 1, 0); + break; + case ast::binary_operation_type::equals: + if (is_fp) + { + builder.fcmp(0, 1, fp_mode); + builder.csetm(0, 0b0000); + } + else + { + builder.cmp_reg(0, 1); + builder.csetm(0, 0b0000); + } + break; + case ast::binary_operation_type::not_equals: + if (is_fp) + { + builder.fcmp(0, 1, fp_mode); + builder.csetm(0, 0b0001); + } + else + { + builder.cmp_reg(0, 1); + builder.csetm(0, 0b0001); + } + break; + case ast::binary_operation_type::less: + if (is_fp) + { + builder.fcmp(0, 1, fp_mode); + builder.csetm(0, 0b0100); + } + else + { + builder.cmp_reg(1, 0); + if (types::is_bool_type(*node.arg1->inferred_type) || types::is_unsigned_integer_type(*node.arg1->inferred_type)) + builder.csetm(0, 0b1000); + else + builder.csetm(0, 0b1100); + } + break; + case ast::binary_operation_type::greater: + if (is_fp) + { + builder.fcmp(1, 0, fp_mode); + builder.csetm(0, 0b0100); + } + else + { + builder.cmp_reg(0, 1); + if (types::is_bool_type(*node.arg1->inferred_type) || types::is_unsigned_integer_type(*node.arg1->inferred_type)) + builder.csetm(0, 0b1000); + else + builder.csetm(0, 0b1100); + } + break; + case ast::binary_operation_type::less_equals: + if (is_fp) + { + builder.fcmp(0, 1, fp_mode); + builder.csetm(0, 0b1001); + } + else + { + builder.cmp_reg(0, 1); + if (types::is_bool_type(*node.arg1->inferred_type) || types::is_unsigned_integer_type(*node.arg1->inferred_type)) + builder.csetm(0, 0b1001); + else + builder.csetm(0, 0b1101); + } + break; + case ast::binary_operation_type::greater_equals: + if (is_fp) + { + builder.fcmp(1, 0, fp_mode); + builder.csetm(0, 0b1001); + } + else + { + builder.cmp_reg(1, 0); + if (types::is_bool_type(*node.arg1->inferred_type) || types::is_unsigned_integer_type(*node.arg1->inferred_type)) + builder.csetm(0, 0b1001); + else + builder.csetm(0, 0b1101); + } + break; + default: + { + std::ostringstream os; + os << "binary operation " << node.type << " is not implemented"; + throw std::runtime_error(os.str()); + } + } + + if (is_fp) + store_fp(it, 0, fp_mode); + else + store(it, 0); + } + + void apply(ir::node_ref it, ir::cast_operation const & node, types::type_ptr const &) + { + auto src_type = node.arg1->inferred_type; + auto dst_type = node.target_type; + + if (types::equal(*src_type, *dst_type) || (types::is_pointer_type(*src_type) && types::is_pointer_type(*dst_type))) + { + load(node.arg1, 0); + store(it, 0); + return; + } + + if (types::is_integer_type(*src_type)) + { + load(node.arg1, 0); + if (types::is_integer_type(*dst_type)) + { + extend(0, dst_type); + } + else if (types::is_floating_point_type(*dst_type)) + { + auto dst_mode = fp_mode_for(*dst_type); + if (types::is_signed_integer_type(*src_type)) + { + builder.fmov(0, 0, 3, 1); + builder.scvtf(0, 0, 3); + if (dst_mode != 3) + builder.fcvt(0, 3, 0, dst_mode); + } + else if (types::is_unsigned_integer_type(*src_type)) + { + builder.fmov(0, 0, 3, 1); + builder.ucvtf(0, 0, 3); + if (dst_mode != 3) + builder.fcvt(0, 3, 0, dst_mode); + } + } + } + else if (types::is_floating_point_type(*src_type)) + { + auto src_mode = fp_mode_for(*src_type); + if (types::is_integer_type(*dst_type)) + { + if (types::is_signed_integer_type(*dst_type)) + { + builder.fcvtns(0, 0, src_mode); + extend(0, dst_type); + } + else if (types::is_unsigned_integer_type(*dst_type)) + { + builder.fcvtnu(0, 0, src_mode); + extend(0, dst_type); + } + } + else if (types::is_floating_point_type(*dst_type)) + { + auto dst_mode = fp_mode_for(*dst_type); + builder.fcvt(0, src_mode, 0, dst_mode); + } + } + + if (types::is_integer_type(*dst_type)) + { + store(it, 0); + } + else if (types::is_floating_point_type(*dst_type)) + { + store_fp(it, 0, fp_mode_for(*dst_type)); + } + } + + void apply(ir::node_ref it, ir::argument const & node, types::type_ptr const & type) + { + // TODO: compute argument layout before compiling the function? + // TODO: floating-point arguments? struct/array arguments? + store(it, node.index); + } + + void apply(ir::node_ref it, ir::instruction_address const & node, types::type_ptr const &) + { + lcontext.adr_resolve.emplace_back(pcontext.code.size(), node.target); + builder.adr(0, 0); + store(it, 0); + } + + void apply(ir::node_ref it, ir::extern_symbol const & node, types::type_ptr const &) + { + builder.ldr_pc(0, (lcontext.extern_symbols[node.name] - (std::int32_t)pcontext.code.size()) / 4); + store(it, 0); + } + + void apply(ir::node_ref, ir::assignment const & node, types::type_ptr const &) + { + // TODO: struct/array assignment? + load(node.rhs, 0); + store(node.lhs, 0); + } + + void apply(ir::node_ref, ir::jump const & node, types::type_ptr const &) + { + lcontext.branch_resolve.emplace_back(pcontext.code.size(), node.target); + builder.b(0); + } + + void apply(ir::node_ref, ir::jump_if_zero const & node, types::type_ptr const &) + { + load(node.condition, 0); + lcontext.cbranch_resolve.emplace_back(pcontext.code.size(), node.target); + builder.cbz(0, 0); + } + + void apply(ir::node_ref it, ir::call const & node, types::type_ptr const & type) + { + // TODO: struct/array arguments? + std::uint8_t reg = 0; + std::uint8_t fp_reg = 0; + for (auto const & argument : node.arguments) + { + if (types::is_integer_like_type(*argument->inferred_type)) + load(argument, reg++); + else if (types::is_floating_point_type(*argument->inferred_type)) + load_fp(argument, fp_reg++, fp_mode_for(*argument->inferred_type)); + else + throw std::runtime_error("Unsupported function argument type"); + } + builder.sub_imm(31, 31, 16); + builder.str(30, 31, 0); + lcontext.branch_resolve.emplace_back(pcontext.code.size(), node.target); + builder.bl(0); + builder.ldr(30, 31, 0); + builder.add_imm(31, 31, 16); + + // TODO: struct/array return value? + if (types::is_unit_type(*type)) + {} + else if (types::is_integer_like_type(*type)) + store(it, 0); + else if (types::is_floating_point_type(*type)) + store_fp(it, 0, fp_mode_for(*type)); + else + throw std::runtime_error("Unsupported return value type"); + } + + void apply(ir::node_ref it, ir::call_pointer const & node, types::type_ptr const & type) + { + // TODO: struct/array arguments? + std::uint8_t reg = 0; + std::uint8_t fp_reg = 0; + for (auto const & argument : node.arguments) + { + if (types::is_integer_like_type(*argument->inferred_type)) + load(argument, reg++); + else if (types::is_floating_point_type(*argument->inferred_type)) + load_fp(argument, fp_reg++, fp_mode_for(*argument->inferred_type)); + else + throw std::runtime_error("Unsupported function argument type"); + } + load(node.pointer, reg); + builder.sub_imm(31, 31, 16); + builder.str(30, 31, 0); + builder.bl_reg(reg); + builder.ldr(30, 31, 0); + builder.add_imm(31, 31, 16); + + // TODO: struct/array return value? + if (types::is_unit_type(*type)) + {} + else if (types::is_integer_like_type(*type)) + store(it, 0); + else if (types::is_floating_point_type(*type)) + store_fp(it, 0, fp_mode_for(*type)); + else + throw std::runtime_error("Unsupported return value type"); + } + + void apply(ir::node_ref, ir::return_value const & node, types::type_ptr const &) + { + // TODO: struct/array return value? + if (node.value) + { + auto type = (*node.value)->inferred_type; + if (types::is_integer_like_type(*type)) + load(*node.value, 0); + else if (types::is_floating_point_type(*type)) + load_fp(*node.value, 0, fp_mode_for(*type)); + else + throw std::runtime_error("Unsupported return value type"); + } + if (stack_size > 0) + builder.add_imm(31, 31, stack_size); + builder.ret(); + } + + void compile(ir::node_ref begin, ir::node_ref end) + { + stack_size = 0; + for (auto it = begin; it != end; ++it) + { + if (ir::is_value_instruction(it->instruction)) + { + stack_size += 8; + stack_position[it] = stack_size; + } + } + stack_size = ((stack_size + 15) / 16) * 16; + + if (!std::holds_alternative(begin->instruction)) + throw std::runtime_error("First IR node of a function must be a label"); + + auto it = begin; + + lcontext.nodes[it] = pcontext.code.size(); + if (stack_size > 0) + builder.sub_imm(31, 31, stack_size); + ++it; + + for (; it != end; ++it) + { + lcontext.nodes[it] = pcontext.code.size(); + std::visit([&](auto const & instruction){ apply(it, instruction, it->inferred_type); }, it->instruction); + } + } + + private: + void load(ir::node_ref ir, std::uint8_t reg) + { + std::int32_t offset = stack_size - stack_position.at(ir); + builder.ldr(reg, 31, offset / 8); + } + + void load_fp(ir::node_ref ir, std::uint8_t reg, std::uint8_t mode) + { + std::int32_t offset = stack_size - stack_position.at(ir); + builder.ldr_fp(reg, mode, 31, offset / fp_size(mode)); + } + + void store(ir::node_ref ir, std::uint8_t reg) + { + std::int32_t offset = stack_size - stack_position.at(ir); + builder.str(reg, 31, offset / 8); + } + + void store_fp(ir::node_ref ir, std::uint8_t reg, std::uint8_t mode) + { + std::int32_t offset = stack_size - stack_position.at(ir); + builder.str_fp(reg, mode, 31, offset / fp_size(mode)); + } + + // Sign- or zero-extend the register depending on the exact type + void extend(std::uint8_t reg, types::type_ptr const & type) + { + reg_extend_visitor{{}, builder, reg}.apply(*type); + } + }; + + } + + void compile(program_context & pcontext, ir::module_context const & mcontext) + { + local_context lcontext; + + instruction_builder builder{pcontext.code}; + + { + populate_const_data_visitor visitor{pcontext, lcontext}; + for (auto it = mcontext.nodes->begin(); it != mcontext.nodes->end(); ++it) + std::visit([&](auto const & instruction){ visitor.apply(instruction, it->inferred_type); }, it->instruction); + } + + for (auto const & symbol : mcontext.symbols) + { + pcontext.symbols[symbol.first] = pcontext.code.size(); + compile_visitor visitor{pcontext, mcontext, lcontext, builder}; + visitor.compile(symbol.second.begin, symbol.second.end); + } + + pcontext.entry_point = lcontext.nodes.at(mcontext.entry_point); + + for (auto const & resolve : lcontext.branch_resolve) + builder.b_inject(pcontext.code.data() + resolve.offset, (lcontext.nodes.at(resolve.target) - resolve.offset) / 4); + + for (auto const & resolve : lcontext.cbranch_resolve) + builder.cb_inject(pcontext.code.data() + resolve.offset, (lcontext.nodes.at(resolve.target) - resolve.offset) / 4); + + for (auto const & resolve : lcontext.adr_resolve) + builder.adr_inject(pcontext.code.data() + resolve.offset, (lcontext.nodes.at(resolve.target) - resolve.offset) / 4); + } + +} diff --git a/libs/jit/source/arch/aarch64/instruction_builder.cpp b/libs/jit/source/arch/aarch64/instruction_builder.cpp index d5e48e1..b88fc54 100644 --- a/libs/jit/source/arch/aarch64/instruction_builder.cpp +++ b/libs/jit/source/arch/aarch64/instruction_builder.cpp @@ -197,7 +197,17 @@ namespace pslang::jit::aarch64 do_push(0x14000000u | (std::uint32_t(offset) & 0x3ffffffu)); } + void instruction_builder::bl(std::int32_t offset) + { + do_push(0x94000000u | (std::uint32_t(offset) & 0x3ffffffu)); + } + void instruction_builder::b_reg(std::uint8_t reg) + { + do_push(0xd61f0000u | ((reg & REG_MASK) << 5)); + } + + void instruction_builder::bl_reg(std::uint8_t reg) { do_push(0xd63f0000u | ((reg & REG_MASK) << 5)); } @@ -286,7 +296,7 @@ namespace pslang::jit::aarch64 void instruction_builder::fmov(std::uint8_t reg_src, std::uint8_t reg_dst, std::uint8_t mode, std::uint8_t op) { - do_push(0x1e260000u | (reg_dst & REG_MASK) | ((reg_src & REG_MASK) << 5) | (mode == 2 ? 0u : 0x8000000u) | (((mode & 0x3u) ^ 0x2u) << 22) | ((op & 0x1u) << 16)); + do_push(0x9e260000u | (reg_dst & REG_MASK) | ((reg_src & REG_MASK) << 5) | (((mode & 0x3u) ^ 0x2u) << 22) | ((op & 0x1u) << 16)); } void instruction_builder::scvtf(std::uint8_t reg_src, std::uint8_t reg_dst, std::uint8_t mode) @@ -328,4 +338,4 @@ namespace pslang::jit::aarch64 code.push_back((opcode >> 24) & 0xffu); } -} \ No newline at end of file +} diff --git a/libs/jit/source/jit.cpp b/libs/jit/source/jit.cpp index 57ba72e..f520c39 100644 --- a/libs/jit/source/jit.cpp +++ b/libs/jit/source/jit.cpp @@ -10,13 +10,27 @@ namespace pslang::jit { switch (context.abi) { - case abi::itanium: - throw std::runtime_error("Itanium ABI JIT not implemented"); - case abi::msvc: - throw std::runtime_error("MSVC ABI JIT not implemented"); - case abi::armv8: - aarch64::compile(context, statements); - break; + case abi::itanium: + throw std::runtime_error("Itanium ABI JIT not implemented"); + case abi::msvc: + throw std::runtime_error("MSVC ABI JIT not implemented"); + case abi::armv8: + aarch64::compile(context, statements); + break; + } + } + + void compile(program_context & pcontext, ir::module_context const & mcontext) + { + switch (pcontext.abi) + { + case abi::itanium: + throw std::runtime_error("Itanium ABI JIT not implemented"); + case abi::msvc: + throw std::runtime_error("MSVC ABI JIT not implemented"); + case abi::armv8: + aarch64::compile(pcontext, mcontext); + break; } } diff --git a/libs/parser/include/pslang/parser/parser.hpp b/libs/parser/include/pslang/parser/parser.hpp index 6090dbb..ed46a89 100644 --- a/libs/parser/include/pslang/parser/parser.hpp +++ b/libs/parser/include/pslang/parser/parser.hpp @@ -7,6 +7,6 @@ namespace pslang::parser { - ast::statement_list_ptr parse(std::string_view path); + ast::statement_ptr parse(std::string_view path); } diff --git a/libs/parser/source/parser.cpp b/libs/parser/source/parser.cpp index be45284..886567d 100644 --- a/libs/parser/source/parser.cpp +++ b/libs/parser/source/parser.cpp @@ -8,15 +8,15 @@ namespace pslang::parser { - ast::statement_list_ptr parse(std::string_view path) + ast::statement_ptr parse(std::string_view path) { yyin = fopen(path.data(), "r"); if (!yyin) throw std::system_error(std::make_error_code(static_cast(errno))); ast::location location{.begin = {.filename = path}, .end = {.filename = path}}; - indented_statement_list result; - context ctx{location, result}; + indented_statement_list statements; + context ctx{location, statements}; bison::parser parser(ctx); @@ -24,7 +24,17 @@ namespace pslang::parser fclose(yyin); - return finalize(std::move(result)); + // Add a fake AST node for the entry point + return std::make_shared(ast::function_definition{ + { + "[entry point]", + {}, + std::make_shared(types::unit_type{}), + {}, + }, + finalize(std::move(statements)), + {}, + }); } } diff --git a/libs/types/include/pslang/types/type_fwd.hpp b/libs/types/include/pslang/types/type_fwd.hpp index 06c8c3b..a66049c 100644 --- a/libs/types/include/pslang/types/type_fwd.hpp +++ b/libs/types/include/pslang/types/type_fwd.hpp @@ -25,6 +25,10 @@ namespace pslang::types bool is_function_type(type const & type); bool is_pointer_type(type const & type); + // Everything that is probably stored in non-fp registers: + // bool, integers, pointers, function points + bool is_integer_like_type(types::type const & type); + std::size_t type_size(type const & type); } diff --git a/libs/types/source/type.cpp b/libs/types/source/type.cpp index 4b020c8..533c62e 100644 --- a/libs/types/source/type.cpp +++ b/libs/types/source/type.cpp @@ -137,6 +137,16 @@ namespace pslang::types return false; } + bool is_integer_like_type(types::type const & type) + { + return false + || is_bool_type(type) + || is_integer_type(type) + || is_function_type(type) + || is_pointer_type(type) + ; + } + std::size_t type_size(type const & type) { if (std::get_if(&type))