diff --git a/libs/ast/include/pslang/ast/operation.hpp b/libs/ast/include/pslang/ast/operation.hpp index abc8f76..898b71e 100644 --- a/libs/ast/include/pslang/ast/operation.hpp +++ b/libs/ast/include/pslang/ast/operation.hpp @@ -32,6 +32,22 @@ namespace pslang::ast greater_equals, }; + inline bool is_comparison(binary_operation_type type) + { + switch (type) + { + case binary_operation_type::equals: + case binary_operation_type::not_equals: + case binary_operation_type::less: + case binary_operation_type::greater: + case binary_operation_type::less_equals: + case binary_operation_type::greater_equals: + return true; + default: + return false; + } + } + template Ostream & operator << (Ostream & out, unary_operation_type type) { diff --git a/libs/ast/source/type_check.cpp b/libs/ast/source/type_check.cpp index 73d1152..ed1e562 100644 --- a/libs/ast/source/type_check.cpp +++ b/libs/ast/source/type_check.cpp @@ -340,6 +340,7 @@ namespace pslang::ast auto arg2_type = get_type(*node.arg2); bool equal = types::equal(*arg1_type, *arg2_type); + bool both_integers = types::is_integer_type(*arg1_type) && types::is_integer_type(*arg2_type); switch (node.type) { @@ -442,42 +443,42 @@ namespace pslang::ast } break; case binary_operation_type::equals: - if (equal) + if (equal || both_integers) { node.inferred_type = std::make_unique(types::bool_type{}); return; } break; case binary_operation_type::not_equals: - if (equal) + if (equal || both_integers) { node.inferred_type = std::make_unique(types::bool_type{}); return; } break; case binary_operation_type::less: - if (equal) + if (equal || both_integers) { node.inferred_type = std::make_unique(types::bool_type{}); return; } break; case binary_operation_type::greater: - if (equal) + if (equal || both_integers) { node.inferred_type = std::make_unique(types::bool_type{}); return; } break; case binary_operation_type::less_equals: - if (equal) + if (equal || both_integers) { node.inferred_type = std::make_unique(types::bool_type{}); return; } break; case binary_operation_type::greater_equals: - if (equal) + if (equal || both_integers) { node.inferred_type = std::make_unique(types::bool_type{}); return; diff --git a/libs/ir/include/pslang/ir/node.hpp b/libs/ir/include/pslang/ir/node.hpp index 95f5522..6b396c5 100644 --- a/libs/ir/include/pslang/ir/node.hpp +++ b/libs/ir/include/pslang/ir/node.hpp @@ -89,6 +89,13 @@ namespace pslang::ir node_ref target; }; + struct jump_if_nonzero + { + // Condition must be a bool-valued node + node_ref condition; + node_ref target; + }; + // Call to a specific node struct call { @@ -123,6 +130,7 @@ namespace pslang::ir assignment, jump, jump_if_zero, + jump_if_nonzero, call, call_pointer, return_value diff --git a/libs/ir/source/compiler.cpp b/libs/ir/source/compiler.cpp index ed3f472..9417af6 100644 --- a/libs/ir/source/compiler.cpp +++ b/libs/ir/source/compiler.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include @@ -39,6 +40,27 @@ namespace pslang::ir std::vector resolve_call; }; + struct zero_literal_visitor + : types::const_visitor + { + module_context & mcontext; + + using const_visitor::apply; + + template + void apply(types::primitive_type_base) + { + mcontext.nodes->emplace_back(literal{ast::literal{ast::primitive_literal_base{.value = {}}}}, + std::make_shared(types::primitive_type{types::primitive_type_base{}})); + } + + template + void apply(T const &) + { + throw std::runtime_error("Invalid type for zero_literal_visitor"); + } + }; + // Compile a single function and store the entry point node_ref // in local_context struct compile_function_visitor @@ -181,6 +203,120 @@ namespace pslang::ir return last(); } + // Different-type integer comparison + + if (!types::equal(*arg1_type, *arg2_type) + && types::is_integer_type(*arg1_type) + && types::is_integer_type(*arg2_type) + && ast::is_comparison(node.type)) + { + bool arg1_unsigned = types::is_unsigned_integer_type(*arg1_type); + bool arg2_unsigned = types::is_unsigned_integer_type(*arg2_type); + std::size_t arg1_size = types::type_size(*arg1_type); + std::size_t arg2_size = types::type_size(*arg2_type); + std::size_t max_size = std::max(arg1_size, arg2_size); + + types::type_ptr max_type = (arg1_size > arg2_size) ? arg1_type : arg2_type; + + if ((arg1_unsigned && arg2_unsigned) || (!arg1_unsigned && !arg2_unsigned)) + { + // Both signed or both unsigned: just cast the smaller one to the larger type + if (arg1_size < arg2_size) + { + mcontext.nodes->emplace_back(cast_operation{arg1, max_type}, max_type); + mcontext.nodes->emplace_back(binary_operation{node.type, last(), arg2}, node.inferred_type); + return last(); + } + else + { + mcontext.nodes->emplace_back(cast_operation{arg2, max_type}, max_type); + mcontext.nodes->emplace_back(binary_operation{node.type, arg1, last()}, node.inferred_type); + return last(); + } + } + else + { + // Different signedness + + // Swap arg1 and arg2 if arg1 is unsigned, and reverse the operation + auto type = node.type; + if (arg1_unsigned) + { + std::swap(arg1, arg2); + std::swap(arg1_type, arg2_type); + std::swap(arg1_size, arg2_size); + std::swap(arg1_is_pointer, arg2_is_pointer); + + if (type == ast::binary_operation_type::less) + type = ast::binary_operation_type::greater; + else if (type == ast::binary_operation_type::greater) + type = ast::binary_operation_type::less; + else if (type == ast::binary_operation_type::less_equals) + type = ast::binary_operation_type::greater_equals; + else if (type == ast::binary_operation_type::greater_equals) + type = ast::binary_operation_type::less_equals; + } + + // Compare with zero first + zero_literal_visitor{{}, mcontext}.apply(*arg1_type); + switch (type) + { + case ast::binary_operation_type::equals: + mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::greater_equals, arg1, last()}, node.inferred_type); + mcontext.nodes->emplace_back(jump_if_zero{last(), {}}); + break; + case ast::binary_operation_type::not_equals: + mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::less, arg1, last()}, node.inferred_type); + mcontext.nodes->emplace_back(jump_if_nonzero{last(), {}}); + break; + case ast::binary_operation_type::less: + mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::less, arg1, last()}, node.inferred_type); + mcontext.nodes->emplace_back(jump_if_nonzero{last(), {}}); + break; + case ast::binary_operation_type::greater: + mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::greater, arg1, last()}, node.inferred_type); + mcontext.nodes->emplace_back(jump_if_zero{last(), {}}); + break; + case ast::binary_operation_type::less_equals: + mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::less_equals, arg1, last()}, node.inferred_type); + mcontext.nodes->emplace_back(jump_if_nonzero{last(), {}}); + break; + case ast::binary_operation_type::greater_equals: + mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::greater_equals, arg1, last()}, node.inferred_type); + mcontext.nodes->emplace_back(jump_if_zero{last(), {}}); + break; + default: + break; + } + auto result = std::prev(last()); + auto jump_node = last(); + + // Less-than-zero case handled, here arg1 is nonnegative - cast both to the largest unsigned type + types::type_ptr max_unsigned_type; + if (max_size == 1) + max_unsigned_type = std::make_unique(types::primitive_type{types::u8_type{}}); + else if (max_size == 2) + max_unsigned_type = std::make_unique(types::primitive_type{types::u16_type{}}); + else if (max_size == 4) + max_unsigned_type = std::make_unique(types::primitive_type{types::u32_type{}}); + else if (max_size == 8) + max_unsigned_type = std::make_unique(types::primitive_type{types::u64_type{}}); + + mcontext.nodes->emplace_back(cast_operation{arg1, max_unsigned_type}, max_unsigned_type); + auto new_arg1 = last(); + mcontext.nodes->emplace_back(cast_operation{arg2, max_unsigned_type}, max_unsigned_type); + auto new_arg2 = last(); + mcontext.nodes->emplace_back(binary_operation{type, new_arg1, new_arg2}, node.inferred_type); + mcontext.nodes->emplace_back(assignment{result, last()}); + mcontext.nodes->emplace_back(label{}); + if (auto jump_if_zero = std::get_if(&jump_node->instruction)) + jump_if_zero->target = last(); + else if (auto jump_if_nonzero = std::get_if(&jump_node->instruction)) + jump_if_nonzero->target = last(); + return result; + } + } + // General case mcontext.nodes->emplace_back(binary_operation{node.type, arg1, arg2}, node.inferred_type); diff --git a/libs/ir/source/print.cpp b/libs/ir/source/print.cpp index 2dfc27d..6de6774 100644 --- a/libs/ir/source/print.cpp +++ b/libs/ir/source/print.cpp @@ -261,6 +261,11 @@ namespace pslang::ir out << "jz $" << get_index(instruction.condition) << " $" << get_index(instruction.target); } + void operator()(jump_if_nonzero const & instruction) + { + out << "jnz $" << get_index(instruction.condition) << " $" << get_index(instruction.target); + } + void operator()(call const & instruction) { out << "call $" << get_index(instruction.target); diff --git a/libs/jit/source/arch/aarch64/compiler_v2.cpp b/libs/jit/source/arch/aarch64/compiler_v2.cpp index c77b154..1e7150c 100644 --- a/libs/jit/source/arch/aarch64/compiler_v2.cpp +++ b/libs/jit/source/arch/aarch64/compiler_v2.cpp @@ -620,6 +620,13 @@ namespace pslang::jit::aarch64 builder.cbz(0, 0); } + void apply(ir::node_ref, ir::jump_if_nonzero const & node, types::type_ptr const &) + { + load(node.condition, 0); + lcontext.cbranch_resolve.emplace_back(pcontext.code.size(), node.target); + builder.cbnz(0, 0); + } + void apply(ir::node_ref it, ir::call const & node, types::type_ptr const & type) { // TODO: struct/array arguments?