From 1a0dd2a48f7735879090cc5414074c9e660fcdda Mon Sep 17 00:00:00 2001 From: lisyarus Date: Tue, 17 Mar 2026 18:08:23 +0300 Subject: [PATCH] Aarch64 JIT compiler wip: partial struct support --- examples/jit_test.psl | 34 +- libs/ast/source/type_check.cpp | 41 +- .../jit/arch/aarch64/instruction_builder.hpp | 49 +- libs/jit/source/arch/aarch64/compiler.cpp | 442 +++++++++++++++--- .../arch/aarch64/instruction_builder.cpp | 52 ++- libs/types/include/pslang/types/type_fwd.hpp | 2 +- libs/types/source/type.cpp | 9 +- plans.txt | 20 +- 8 files changed, 570 insertions(+), 79 deletions(-) diff --git a/examples/jit_test.psl b/examples/jit_test.psl index 815aa5c..fd736e4 100644 --- a/examples/jit_test.psl +++ b/examples/jit_test.psl @@ -15,8 +15,6 @@ foreign func putchar(c: i32) -> i32 func print(c: u8): putchar(c as i32) -let x = 0 - func test1(): print('H') print('e') @@ -44,20 +42,28 @@ func test() -> i32: else: return b() -struct interval: - min: f32 - max: f32 +struct vec2: + x : f32 + y : f32 -struct box2f: - x: interval - y: interval +struct vecX: + a: f32 + b: f32 + c: f64 -struct weird: - a: i32 - b: u8 - c: f32 -> f32 - d: f64 - e: f16 +func test2(): + //let v = vec2(1.0, 2.0) + vec2(1.0, 2.0) + +//test2() + +mut v = vecX(6.0, 9.0, 4.0l) +v = vecX(1.0, 2.0, 5.0l) +v.a = 1.0 +//v.b = 2.0 +v.c = 5.0l +print('0' + (v.b as u8)) +print('\n') //func test1(): // let str = ['H', 'e', 'l', 'l', 'o', ',', ' ', 'w', 'o', 'r', 'l', 'd', '!', '\n'] diff --git a/libs/ast/source/type_check.cpp b/libs/ast/source/type_check.cpp index 4e99efa..bb68879 100644 --- a/libs/ast/source/type_check.cpp +++ b/libs/ast/source/type_check.cpp @@ -63,12 +63,12 @@ namespace pslang::ast size_and_alignment apply(types::unit_type const &) { - return {.size = 1, .alignment = 1}; + return {.size = 0, .alignment = 1}; } size_and_alignment apply(types::primitive_type const & type) { - auto size = types::builtin_type_size(type); + auto size = types::type_size(type); return {.size = size, .alignment = 
size}; } @@ -561,7 +561,7 @@ namespace pslang::ast os << "Cannot create built-in type "; types::print(os, *type); os << ": expected 0 arguments, but got " << node.arguments.size(); - throw std::runtime_error(os.str()); + throw type_error(os.str(), node.location); } else if (auto named_type = std::get_if(type.get())) { @@ -702,6 +702,17 @@ namespace pslang::ast { apply(node.lhs); apply(node.rhs); + + if (auto lvalue = classify_lvalue(node.lhs)) + { + if (*lvalue != ast::value_category::_mutable) + throw type_error("Cannot assign to a non-mutable object", node.location); + } + else + { + throw type_error("Cannot assign to a non-lvalue", node.location); + } + auto ltype = get_type(*node.lhs); auto rtype = get_type(*node.rhs); // TODO: check lvalue @@ -863,6 +874,30 @@ namespace pslang::ast for (auto & struct_data : scopes.back().structs) compute_layout(struct_data.second, scopes); } + + private: + std::optional classify_lvalue(expression_ptr const & node) + { + if (auto identifier = std::get_if(node.get())) + { + auto const & scope = scopes[identifier->level]; + if (scope.functions.contains(identifier->name)) + return ast::value_category::constant; + if (auto it = scope.variables.find(identifier->name); it != scope.variables.end()) + return it->second.category; + return std::nullopt; + } + else if (auto field_access = std::get_if(node.get())) + { + return classify_lvalue(field_access->object); + } + else if (auto array_access = std::get_if(node.get())) + { + return classify_lvalue(array_access->array); + } + + return std::nullopt; + } }; } diff --git a/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp b/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp index 59dacf5..97749fc 100644 --- a/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp +++ b/libs/jit/include/pslang/jit/arch/aarch64/instruction_builder.hpp @@ -20,7 +20,7 @@ namespace pslang::jit::aarch64 // @shift must be 0, 1, 2 or 3 void movk(std::uint8_t reg, 
std::uint16_t val, std::uint8_t shift = 0); - // Load the value of the register @reg_src at the address specified by the value of + // Store the value of the register @reg_src at the address specified by the value of // register @reg_addr plus an unsigned 12-bit offset multiplied by 8. void str(std::uint8_t reg_src, std::uint8_t reg_addr, std::uint16_t offset); @@ -28,6 +28,22 @@ namespace pslang::jit::aarch64 // plus a signed 9-bit offset. Store the new address value in @reg_addr void str_pre(std::uint8_t reg_src, std::uint8_t reg_addr, std::int16_t offset); + // Store the value of the register @reg_src at the address specified by the value of + // register @reg_addr plus a signed 9-bit offset. + void stur(std::uint8_t reg_src, std::uint8_t reg_addr, std::int16_t offset); + + // Store the lowest 32 bits of the value of the register @reg_src at the address specified by the value of + // register @reg_addr plus a signed 9-bit offset. + void sturw(std::uint8_t reg_src, std::uint8_t reg_addr, std::int16_t offset); + + // Store the lowest 16 bits of the value of the register @reg_src at the address specified by the value of + // register @reg_addr plus a signed 9-bit offset. + void sturh(std::uint8_t reg_src, std::uint8_t reg_addr, std::int16_t offset); + + // Store the lowest 8 bits of the value of the register @reg_src at the address specified by the value of + // register @reg_addr plus a signed 9-bit offset. + void sturb(std::uint8_t reg_src, std::uint8_t reg_addr, std::int16_t offset); + // Load the value at address specified by the value of register @reg_addr // plus an unsigned 12-bit offset multiplied by 8 and store it in the register @reg_dst. 
void ldr(std::uint8_t reg_dst, std::uint8_t reg_addr, std::uint16_t offset); @@ -42,6 +58,22 @@ namespace pslang::jit::aarch64 // Store the address plus a signed 9-bit offset in @reg_addr void ldr_post(std::uint8_t reg_dst, std::uint8_t reg_addr, std::int16_t offset); + // Load the value at address specified by the value of register @reg_addr + // plus a signed 9-bit offset and store it in the register @reg_dst. + void ldur(std::uint8_t reg_dst, std::uint8_t reg_addr, std::int16_t offset); + + // Load the lowest 32 bits of the value at address specified by the value of register @reg_addr + // plus a signed 9-bit offset and store it in the register @reg_dst. + void ldurw(std::uint8_t reg_dst, std::uint8_t reg_addr, std::int16_t offset); + + // Load the lowest 16 bits of the value at address specified by the value of register @reg_addr + // plus a signed 9-bit offset and store it in the register @reg_dst. + void ldurh(std::uint8_t reg_dst, std::uint8_t reg_addr, std::int16_t offset); + + // Load the lowest 8 bits of the value at address specified by the value of register @reg_addr + // plus a signed 9-bit offset and store it in the register @reg_dst. + void ldurb(std::uint8_t reg_dst, std::uint8_t reg_addr, std::int16_t offset); + // Load the value at address specified by the value of the program counter (PC) // plus a signed 19-bit @offset multiplied by 4, and store it into register @reg_dst void ldr_pc(std::uint8_t reg_dst, std::int32_t offset); @@ -135,15 +167,25 @@ namespace pslang::jit::aarch64 void ldr_fp_pc(std::uint8_t reg_dst, std::uint8_t mode, std::int32_t offset); // Load a floating-point value from the address stored in @reg_addr plus a - // 12-bit unsigned @offset multiplied by 4, and store it in floating-point + // 12-bit unsigned @offset multiplied by sizeof the type (2, 4 or 8), and store it in floating-point // register @reg_dst. 
@mode should be 1 for 16-bit values, 2 for 32-bit, 3 for 64-bit void ldr_fp(std::uint8_t reg_dst, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset); + // Load a floating-point value from the address stored in @reg_addr plus a + // 9-bit signed @offset, and store it in floating-point + // register @reg_dst. @mode should be 1 for 16-bit values, 2 for 32-bit, 3 for 64-bit + void ldur_fp(std::uint8_t reg_dst, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset); + // Store a floating-point value from @reg_src into the address stored in - // @reg_addr plus a 12-bit unsigned @offset multiplied by 4. + // @reg_addr plus a 12-bit unsigned @offset multiplied by sizeof the type (2, 4 or 8). // @mode should be 1 for 16-bit values, 2 for 32-bit, 3 for 64-bit void str_fp(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset); + // Store a floating-point value from @reg_src into the address stored in + // @reg_addr plus a 9-bit signed @offset. + // @mode should be 1 for 16-bit values, 2 for 32-bit, 3 for 64-bit + void stur_fp(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_addr, std::int16_t offset); + // Convert a floating-point value from @reg_src stored in format @mode_src // into a floating-point value in @reg_dst in format @mode_dst // Modes are the same as for ldr_fp/str_fp @@ -161,7 +203,6 @@ namespace pslang::jit::aarch64 // If @op is 0, move the value of a floating-point register @reg_src in format @mode into a general register @reg_dst // If @op is 1, move the value of a general register @reg_src into a floating-point register @reg_dst in format @mode - // NB: mode = 2 (single precision) isn't supported void fmov(std::uint8_t reg_src, std::uint8_t reg_dst, std::uint8_t mode, std::uint8_t op); // Convert signed integer in floating-point register @reg_src into floating-point format @mode and diff --git a/libs/jit/source/arch/aarch64/compiler.cpp b/libs/jit/source/arch/aarch64/compiler.cpp index 
661f0e3..403bd7f 100644 --- a/libs/jit/source/arch/aarch64/compiler.cpp +++ b/libs/jit/source/arch/aarch64/compiler.cpp @@ -18,6 +18,14 @@ namespace pslang::jit::aarch64 namespace { + // Homogeneous floating-point aggregate: up to 4 floating-point members + // of the same type (after struct flattening) + struct hfa_data + { + types::type_ptr type; + std::size_t count; + }; + struct local_context { std::unordered_map f16_constants; @@ -28,10 +36,18 @@ namespace pslang::jit::aarch64 std::unordered_map functions; + struct struct_data + { + std::optional hfa = {}; + }; + + std::unordered_map structs; + struct scope { std::unordered_set foreign_functions; std::unordered_map functions; + std::unordered_map structs; }; std::vector scopes; @@ -54,6 +70,9 @@ namespace pslang::jit::aarch64 if (it->functions.contains(name)) return false; + + if (it->structs.contains(name)) + return false; } return false; } @@ -62,11 +81,30 @@ namespace pslang::jit::aarch64 { for (auto it = scopes.rbegin(); it != scopes.rend(); ++it) { + if (auto jt = it->functions.find(name); jt != it->functions.end()) + return jt->second; + if (it->foreign_functions.contains(name)) return nullptr; - if (auto jt = it->functions.find(name); jt != it->functions.end()) + if (it->structs.contains(name)) + return nullptr; + } + return nullptr; + } + + ast::struct_definition const * is_struct(std::string const & name) + { + for (auto it = scopes.rbegin(); it != scopes.rend(); ++it) + { + if (auto jt = it->structs.find(name); jt != it->structs.end()) return jt->second; + + if (it->foreign_functions.contains(name)) + return nullptr; + + if (it->functions.contains(name)) + return nullptr; } return nullptr; } @@ -158,7 +196,8 @@ namespace pslang::jit::aarch64 void apply(ast::function_call const & node) { - apply(*node.function); + if (node.function) + apply(*node.function); for (auto const & argument : node.arguments) apply(*argument); } @@ -258,6 +297,56 @@ namespace pslang::jit::aarch64 } }; + std::optional 
get_hfa_data(ast::struct_definition const & node, local_context & lcontext) + { + if (auto it = lcontext.structs.find(&node); it != lcontext.structs.end()) + return it->second.hfa; + + types::type_ptr type = nullptr; + std::size_t count = 0; + + for (auto const & field : node.fields) + { + if (types::is_builtin_type(*field.inferred_type)) + { + if (!types::is_floating_point_type(*field.inferred_type)) + return std::nullopt; + + if (type && !types::equal(*type, *field.inferred_type)) + return std::nullopt; + + type = field.inferred_type; + ++count; + } + else if (auto named_type = std::get_if(field.inferred_type.get())) + { + // NB: recursion must be impossible due to prior checks in type checker + if (auto struct_node = lcontext.is_struct(named_type->name)) + { + if (auto subdata = get_hfa_data(*struct_node, lcontext)) + { + if (type && !types::equal(*type, *subdata->type)) + return std::nullopt; + + type = subdata->type; + count += subdata->count; + } + else + return std::nullopt; + } + else + throw std::runtime_error("Unknown named type: \"" + named_type->name + "\""); + } + else + return std::nullopt; + } + + if (count <= 4) + return hfa_data{type, count}; + + return std::nullopt; + } + // Iterate over a single scope (i.e. 
not visiting subscopes recursively) // and add all defined functions & foreign functions to the current scope struct populate_symbols_visitor @@ -297,7 +386,16 @@ namespace pslang::jit::aarch64 void apply(ast::field_definition const &) {} - void apply(ast::struct_definition const &) {} + void apply(ast::struct_definition const & node) + { + lcontext.scopes.back().structs[node.name] = &node; + if (!lcontext.structs.contains(&node)) + { + // NB: make sure not to add struct to lcontext.structs before computing hfa data + auto hfa = get_hfa_data(node, lcontext); + lcontext.structs[&node].hfa = hfa; + } + } }; struct reg_extend_visitor @@ -363,9 +461,9 @@ namespace pslang::jit::aarch64 struct variable_data { - // Difference between initial stack pointer at scope enter + // Difference between initial stack pointer at function enter // and the variable address - // Must be a multiple of 8 + // Must be a multiple of 16 std::uint32_t frame_offset; }; @@ -450,8 +548,30 @@ namespace pslang::jit::aarch64 { if (auto jt = it->variables.find(node.name); jt != it->variables.end()) { - if (types::is_floating_point_type(*node.inferred_type)) - builder.ldr_fp(0, fp_mode_for(*node.inferred_type), 31, (stack_offset - jt->second.frame_offset) / builtin_type_size(*node.inferred_type)); + if (auto named_type = std::get_if(node.inferred_type.get())) + { + if (auto struct_node = lcontext.is_struct(named_type->name)) + { + std::size_t stack_size = ((struct_node->layout.size + 15) / 16) * 16; + builder.sub_imm(31, 31, stack_size); + stack_offset += stack_size; + scopes.back().stack_offset += stack_size; + std::size_t variable_offset = stack_offset - jt->second.frame_offset; + for (std::size_t offset = 0; offset < stack_size; offset += 16) + { + builder.ldr(0, 31, (variable_offset + offset) / 8); + builder.ldr(1, 31, (variable_offset + offset) / 8 + 1); + builder.str(0, 31, offset / 8); + builder.str(1, 31, offset / 8 + 1); + } + } + else + throw std::runtime_error("Unknown type \"" + 
named_type->name + "\""); + } + else if (types::is_unit_type(*node.inferred_type)) + {} + else if (types::is_floating_point_type(*node.inferred_type)) + builder.ldr_fp(0, fp_mode_for(*node.inferred_type), 31, (stack_offset - jt->second.frame_offset) / type_size(*node.inferred_type)); else builder.ldr(0, 31, (stack_offset - jt->second.frame_offset) / 8); return; @@ -767,71 +887,228 @@ namespace pslang::jit::aarch64 void apply(ast::function_call const & node) { - apply(*node.function); - push(0); - - for (std::size_t i = node.arguments.size(); i --> 0;) + if (node.function) { - auto const & arg = node.arguments[i]; - apply(*arg); - auto type = ast::get_type(*arg); - if (types::is_bool_type(*type) || types::is_integer_type(*type)) + apply(*node.function); + push(0); + + for (std::size_t i = node.arguments.size(); i --> 0;) { - push(0); + auto const & arg = node.arguments[i]; + apply(*arg); + auto type = ast::get_type(*arg); + if (types::is_bool_type(*type) || types::is_integer_type(*type)) + { + push(0); + } + else if (types::is_floating_point_type(*type)) + { + push_fp(0, fp_mode_for(*type)); + } } - else if (types::is_floating_point_type(*type)) + + std::uint8_t reg = 0; + std::uint8_t fp_reg = 0; + for (auto const & arg : node.arguments) { - push_fp(0, fp_mode_for(*type)); + auto type = ast::get_type(*arg); + if (types::is_bool_type(*type) || types::is_integer_type(*type)) + { + pop(reg); + ++reg; + } + else if (types::is_floating_point_type(*type)) + { + pop_fp(fp_reg, fp_mode_for(*type)); + ++fp_reg; + } + } + + pop(reg); + push(30); + builder.b_reg(reg); + pop(30); + } + else // if (node.type) + { + if (types::is_unit_type(*node.inferred_type)) + { + // Do nothing + } + else if (types::is_bool_type(*node.inferred_type) || types::is_integer_type(*node.inferred_type) || types::is_function_type(*node.inferred_type)) + { + builder.xor_reg(0, 0, 0); + } + else if (types::is_floating_point_type(*node.inferred_type)) + { + builder.xor_reg(0, 0, 0); + builder.fmov(0, 0, 
fp_mode_for(*node.inferred_type), 1); + } + else if (auto named_type = std::get_if(node.inferred_type.get())) + { + if (auto struct_node = lcontext.is_struct(named_type->name)) + { + // Allocate stack space for the struct + std::size_t stack_size = ((struct_node->layout.size + 15) / 16) * 16; + auto offset = stack_offset; + stack_offset += stack_size; + scopes.back().stack_offset += stack_size; + builder.sub_imm(31, 31, stack_size); + + // Evaluate each field of the struct (i.e. each constructor argument) + // and copy it to the corresponding place in the struct + for (std::size_t i = 0; i < node.arguments.size(); ++i) + { + auto type = ast::get_type(*node.arguments[i]); + apply(*node.arguments[i]); + + if (std::get_if(type.get())) + { + // TODO: struct field + throw std::runtime_error("Not implemented"); + } + else if (types::is_floating_point_type(*type)) + { + builder.stur_fp(0, fp_mode_for(*type), 31, struct_node->fields[i].layout.offset); + } + else + { + auto size = types::type_size(*type); + + if (size == 1) + builder.sturb(0, 31, struct_node->fields[i].layout.offset); + else if (size == 2) + builder.sturh(0, 31, struct_node->fields[i].layout.offset); + else if (size == 4) + builder.sturw(0, 31, struct_node->fields[i].layout.offset); + else if (size == 8) + builder.stur(0, 31, struct_node->fields[i].layout.offset); + } + } + } + else + throw std::runtime_error("Unknown type \"" + named_type->name + "\""); + } + } + } + + void apply(ast::field_access const & node) + { + auto struct_type = get_type(*node.object); + if (auto named_type = std::get_if(struct_type.get())) + { + if (auto struct_node = lcontext.is_struct(named_type->name)) + { + std::size_t field_id = -1; + for (std::size_t i = 0; i < struct_node->fields.size(); ++i) + { + if (struct_node->fields[i].name == node.field_name) + { + field_id = i; + break; + } + } + + if (field_id == -1) + throw std::runtime_error("Unknown field \"" + node.field_name + "\" in struct \"" + named_type->name + "\""); + + 
apply(*node.object); + + auto stack_size = ((struct_node->layout.size + 15) / 16) * 16; + + auto const & field = struct_node->fields[field_id]; + if (types::is_unit_type(*field.inferred_type)) + {} + else if (types::is_floating_point_type(*field.inferred_type)) + { + builder.ldur_fp(0, fp_mode_for(*field.inferred_type), 31, field.layout.offset); + + builder.add_imm(31, 31, stack_size); + stack_offset -= stack_size; + scopes.back().stack_offset -= stack_size; + } + else if (types::is_bool_type(*field.inferred_type) || types::is_integer_type(*field.inferred_type) || types::is_function_type(*field.inferred_type)) + { + auto size = types::type_size(*field.inferred_type); + if (size == 1) + builder.ldurb(0, 31, field.layout.offset); + else if (size == 2) + builder.ldurh(0, 31, field.layout.offset); + else if (size == 4) + builder.ldurw(0, 31, field.layout.offset); + else if (size == 8) + builder.ldur(0, 31, field.layout.offset); + + builder.add_imm(31, 31, stack_size); + stack_offset -= stack_size; + scopes.back().stack_offset -= stack_size; + } + else if (auto named_type = std::get_if(field.inferred_type.get())) + { + if (auto field_struct_node = lcontext.is_struct(named_type->name)) + { + // TODO: copy the struct-typed field on stack, overriding + // the struct itself, and update the stack offset + throw std::runtime_error("Not implemented"); + } + else + throw std::runtime_error("Unknown type \"" + named_type->name + "\""); + } + + return; } } - std::uint8_t reg = 0; - std::uint8_t fp_reg = 0; - for (auto const & arg : node.arguments) - { - auto type = ast::get_type(*arg); - if (types::is_bool_type(*type) || types::is_integer_type(*type)) - { - pop(reg); - ++reg; - } - else if (types::is_floating_point_type(*type)) - { - pop_fp(fp_reg, fp_mode_for(*type)); - ++fp_reg; - } - } - - pop(reg); - push(30); - builder.b_reg(reg); - pop(30); + throw std::runtime_error("Unknown object in field access"); } void apply(ast::expression_ptr const & node) { + auto 
stack_offset_before = stack_offset; apply(*node); + + // Restore stack offset in case the expression evaluated to a struct + // (in which case the struct would be placed on the stack) + auto stack_delta = stack_offset - stack_offset_before; + if (stack_delta > 0) + { + builder.add_imm(31, 31, stack_delta); + stack_offset -= stack_delta; + scopes.back().stack_offset -= stack_delta; + } } void apply(ast::assignment const & node) { - auto identifier = std::get_if(node.lhs.get()); - if (!identifier) - throw std::runtime_error("assignment for non-identifier lhs is not implemented"); + auto frame_offset = lvalue_offset(node.lhs); apply(*node.rhs); - - for (auto it = scopes.rbegin(); it != scopes.rend(); ++it) + auto type = ast::get_type(*node.rhs); + if (types::is_unit_type(*type)) + {} + else if (types::is_floating_point_type(*type)) + builder.str_fp(0, fp_mode_for(*type), 31, (stack_offset - frame_offset) / type_size(*type)); + else if (types::is_bool_type(*type) || types::is_integer_type(*type) || types::is_function_type(*type)) { - if (auto jt = it->variables.find(identifier->name); jt != it->variables.end()) + auto size = types::type_size(*type); + if (size == 1) + builder.sturb(0, 31, stack_offset - frame_offset); + else if (size == 2) + builder.sturh(0, 31, stack_offset - frame_offset); + else if (size == 4) + builder.sturw(0, 31, stack_offset - frame_offset); + else if (size == 8) + builder.stur(0, 31, stack_offset - frame_offset); + } + else if (auto named_type = std::get_if(type.get())) + { + if (auto struct_node = lcontext.is_struct(named_type->name)) { - auto type = ast::get_type(*node.rhs); - if (types::is_floating_point_type(*type)) - builder.str_fp(0, fp_mode_for(*type), 31, (stack_offset - jt->second.frame_offset) / builtin_type_size(*type)); - else - builder.str(0, 31, (stack_offset - jt->second.frame_offset) / 8); - break; + // TODO: whole-struct assignment + throw std::runtime_error("Not implemented"); } + else + throw std::runtime_error("Unknown type 
\"" + named_type->name + "\""); } } @@ -839,8 +1116,18 @@ namespace pslang::jit::aarch64 { apply(*node.initializer); auto type = ast::get_type(*node.initializer); - if (types::is_floating_point_type(*type)) + if (std::get_if(type.get())) + { + // Nothing to be done: the struct is already on the stack + // Just record the stack offset as variable location + } + else if (types::is_floating_point_type(*type)) push_fp(0, fp_mode_for(*type)); + else if (types::is_unit_type(*type)) + { + // Nothing to be done: unit type has zero size + // Its stack position is recorded but de facto unused + } else push(0); scopes.back().variables[node.name] = {.frame_offset = stack_offset}; @@ -909,6 +1196,7 @@ namespace pslang::jit::aarch64 void apply(ast::return_statement const & node) { + // TODO: struct return value if (node.value) apply(*node.value); do_return(); @@ -924,6 +1212,11 @@ namespace pslang::jit::aarch64 // Must be handled prior to that in populate_symbols_visitor } + void apply(ast::struct_definition const &) + { + // Must be handled prior to that in populate_symbols_visitor + } + void apply(ast::statement_list const & node) { lcontext.scopes.emplace_back(); @@ -1036,6 +1329,51 @@ namespace pslang::jit::aarch64 stack_offset -= scopes.back().stack_offset; } } + + // Returns offset from function entry stack frame + std::size_t lvalue_offset(ast::expression_ptr const & node) + { + if (auto identifier = std::get_if(node.get())) + { + auto const & scope = scopes[identifier->level]; + if (auto it = scope.variables.find(identifier->name); it != scope.variables.end()) + return it->second.frame_offset; + throw std::runtime_error("Non-lvalue identifier: \"" + identifier->name + "\""); + } + else if (auto field_access = std::get_if(node.get())) + { + auto base_offset = lvalue_offset(field_access->object); + auto type = ast::get_type(*field_access->object); + if (auto named_type = std::get_if(type.get())) + { + if (auto struct_node = lcontext.is_struct(named_type->name)) + { + 
std::size_t field_id = -1; + for (std::size_t i = 0; i < struct_node->fields.size(); ++i) + if (struct_node->fields[i].name == field_access->field_name) + { + field_id = i; + break; + } + + if (field_id == -1) + throw std::runtime_error("Invalid field \"" + field_access->field_name + "\""); + + return base_offset - struct_node->fields[field_id].layout.offset; + } + else + throw std::runtime_error("Invalid struct \"" + named_type->name + "\""); + } + else + throw std::runtime_error("Invalid field access node"); + } + else if (auto array_access = std::get_if(node.get())) + { + throw std::runtime_error("Not implemented"); + } + + throw std::runtime_error("Unknown lvalue node"); + } }; // Main compilation visitor diff --git a/libs/jit/source/arch/aarch64/instruction_builder.cpp b/libs/jit/source/arch/aarch64/instruction_builder.cpp index 63e4311..d5e48e1 100644 --- a/libs/jit/source/arch/aarch64/instruction_builder.cpp +++ b/libs/jit/source/arch/aarch64/instruction_builder.cpp @@ -25,6 +25,26 @@ namespace pslang::jit::aarch64 do_push(0xf8000c00u | (reg_src & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0x1ffu) << 12)); } + void instruction_builder::stur(std::uint8_t reg_src, std::uint8_t reg_addr, std::int16_t offset) + { + do_push(0xf8000000u | (reg_src & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0x1ffu) << 12)); + } + + void instruction_builder::sturw(std::uint8_t reg_src, std::uint8_t reg_addr, std::int16_t offset) + { + do_push(0xb8000000u | (reg_src & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0x1ffu) << 12)); + } + + void instruction_builder::sturh(std::uint8_t reg_src, std::uint8_t reg_addr, std::int16_t offset) + { + do_push(0x78000000u | (reg_src & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0x1ffu) << 12)); + } + + void instruction_builder::sturb(std::uint8_t reg_src, std::uint8_t reg_addr, std::int16_t offset) + { + do_push(0x38000000u | (reg_src & REG_MASK) 
| ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0x1ffu) << 12)); + } + void instruction_builder::ldr(std::uint8_t reg_dst, std::uint8_t reg_addr, std::uint16_t offset) { do_push(0xf9400000u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0xfffu) << 10)); @@ -40,6 +60,26 @@ namespace pslang::jit::aarch64 do_push(0xf8400400u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0x1ffu) << 12)); } + void instruction_builder::ldur(std::uint8_t reg_dst, std::uint8_t reg_addr, std::int16_t offset) + { + do_push(0xf8400000u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0x1ffu) << 12)); + } + + void instruction_builder::ldurw(std::uint8_t reg_dst, std::uint8_t reg_addr, std::int16_t offset) + { + do_push(0xb8400000u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0x1ffu) << 12)); + } + + void instruction_builder::ldurh(std::uint8_t reg_dst, std::uint8_t reg_addr, std::int16_t offset) + { + do_push(0x78400000u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0x1ffu) << 12)); + } + + void instruction_builder::ldurb(std::uint8_t reg_dst, std::uint8_t reg_addr, std::int16_t offset) + { + do_push(0x38400000u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0x1ffu) << 12)); + } + void instruction_builder::ldr_pc(std::uint8_t reg_dst, std::int32_t offset) { do_push(0x58000000u | (reg_dst & REG_MASK) | ((((std::uint32_t)offset) & 0x7ffffu) << 5)); @@ -194,11 +234,21 @@ namespace pslang::jit::aarch64 do_push(0x3d400000u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((offset & 0xfffu) << 10) | ((mode & 0x3u) << 30)); } + void instruction_builder::ldur_fp(std::uint8_t reg_dst, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset) + { + do_push(0x3c400000u | (reg_dst & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0x1ffu) << 12) | ((mode 
& 0x3u) << 30)); + } + void instruction_builder::str_fp(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_addr, std::uint16_t offset) { do_push(0x3d000000u | (reg_src & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((offset & 0xfffu) << 10) | ((mode & 0x3u) << 30)); } + void instruction_builder::stur_fp(std::uint8_t reg_src, std::uint8_t mode, std::uint8_t reg_addr, std::int16_t offset) + { + do_push(0x3c000000u | (reg_src & REG_MASK) | ((reg_addr & REG_MASK) << 5) | ((std::uint16_t(offset) & 0x1ffu) << 12) | ((mode & 0x3u) << 30)); + } + void instruction_builder::fcvt(std::uint8_t reg_src, std::uint8_t mode_src, std::uint8_t reg_dst, std::uint8_t mode_dst) { do_push(0x1e224000u | (reg_dst & REG_MASK) | ((reg_src & REG_MASK) << 5) | (((mode_dst & 0x3u) ^ 0x2u) << 15) | (((mode_src & 0x3u) ^ 0x2u) << 22)); @@ -236,7 +286,7 @@ namespace pslang::jit::aarch64 void instruction_builder::fmov(std::uint8_t reg_src, std::uint8_t reg_dst, std::uint8_t mode, std::uint8_t op) { - do_push(0x9e260000u | (reg_dst & REG_MASK) | ((reg_src & REG_MASK) << 5) | (((mode & 0x3u) ^ 0x2u) << 22) | ((op & 0x1u) << 16)); + do_push(0x1e260000u | (reg_dst & REG_MASK) | ((reg_src & REG_MASK) << 5) | (mode == 2 ? 
0u : 0x8000000u) | (((mode & 0x3u) ^ 0x2u) << 22) | ((op & 0x1u) << 16)); } void instruction_builder::scvtf(std::uint8_t reg_src, std::uint8_t reg_dst, std::uint8_t mode) diff --git a/libs/types/include/pslang/types/type_fwd.hpp b/libs/types/include/pslang/types/type_fwd.hpp index 5d81453..2a1a519 100644 --- a/libs/types/include/pslang/types/type_fwd.hpp +++ b/libs/types/include/pslang/types/type_fwd.hpp @@ -24,6 +24,6 @@ namespace pslang::types bool is_builtin_type(type const & type); bool is_function_type(type const & type); - std::size_t builtin_type_size(type const & type); + std::size_t type_size(type const & type); } diff --git a/libs/types/source/type.cpp b/libs/types/source/type.cpp index b660953..f7b343b 100644 --- a/libs/types/source/type.cpp +++ b/libs/types/source/type.cpp @@ -130,10 +130,13 @@ namespace pslang::types return false; } - std::size_t builtin_type_size(type const & type) + std::size_t type_size(type const & type) { if (std::get_if(&type)) - return 1; + return 0; + + if (std::get_if(&type)) + return 8; if (auto ptype = std::get_if(&type)) { @@ -148,7 +151,7 @@ namespace pslang::types }, *ptype); } - return 0; + throw std::runtime_error("Unknown type"); } } diff --git a/plans.txt b/plans.txt index eb4a6e2..5bf4e82 100644 --- a/plans.txt +++ b/plans.txt @@ -1,4 +1,5 @@ Future plans: +* Split compiler into IR generation, IR optimization, and target-specific bytecode emitter * Globals (requires a separate mmaped segment in JIT compiler) * Pointers: pointer types, address-of operator (&), dereferencing, scope-based lifetime tracking in interpreter * Function overloading: separate functions from values (again) in interpreter, allow casting to specific function type to take function value @@ -15,9 +16,26 @@ Interpreter backlog: * C FFI (foreign functions) Aarch64 compiler backlog: -* Structs +* Struct fields in structs (initialization & field access) +* Struct function arguments & return values * Arrays +IR outline: +* Doubly-linked list (std::list 
will do) of instruction nodes +* Nodes reference other nodes for input data or branching target +* No temporary variables, nodes themselves represent data flow +* Assignment expressions turn into nodes that assign to previously-defined nodes +* Jump targets (first `while` node, node after an `if`, etc) are generated as nop nodes, to prevent them from being removed by the optimizer +* Structs & arrays can be split into several independent nodes per-field (before pointers are introduced) + +IR optimizations: +* Inlining: track function IR size (number of nodes will do), substitute its code instead of calling if it is small enough (pay attention to recursion) +* Constant folding & propagation: if all node arguments are `const` nodes, replace the current node with the computed value +* Arithmetic simplification: replace a+0 with a, etc +* Branch removal: `if` nodes with condition nodes being `const` are replaced with unconditional jumps +* Jump removal: `jump` that jumps to the immediate successor node is removed +* Dead code elimination: DFS from function `return` nodes & impure function calls, remove all nodes not visited (make sure to not remove function arguments) + General backlog: * Mutually recursive structs (relevant only with pointers) * Empty array expression