Support comparison operators for integers of different types

This commit is contained in:
Nikita Lisitsa 2026-03-26 23:49:28 +03:00
parent 81825aee7f
commit 14cb393076
6 changed files with 179 additions and 6 deletions

View file

@ -32,6 +32,22 @@ namespace pslang::ast
greater_equals, greater_equals,
}; };
inline bool is_comparison(binary_operation_type type)
{
switch (type)
{
case binary_operation_type::equals:
case binary_operation_type::not_equals:
case binary_operation_type::less:
case binary_operation_type::greater:
case binary_operation_type::less_equals:
case binary_operation_type::greater_equals:
return true;
default:
return false;
}
}
template <typename Ostream> template <typename Ostream>
Ostream & operator << (Ostream & out, unary_operation_type type) Ostream & operator << (Ostream & out, unary_operation_type type)
{ {

View file

@ -340,6 +340,7 @@ namespace pslang::ast
auto arg2_type = get_type(*node.arg2); auto arg2_type = get_type(*node.arg2);
bool equal = types::equal(*arg1_type, *arg2_type); bool equal = types::equal(*arg1_type, *arg2_type);
bool both_integers = types::is_integer_type(*arg1_type) && types::is_integer_type(*arg2_type);
switch (node.type) switch (node.type)
{ {
@ -442,42 +443,42 @@ namespace pslang::ast
} }
break; break;
case binary_operation_type::equals: case binary_operation_type::equals:
if (equal) if (equal || both_integers)
{ {
node.inferred_type = std::make_unique<types::type>(types::bool_type{}); node.inferred_type = std::make_unique<types::type>(types::bool_type{});
return; return;
} }
break; break;
case binary_operation_type::not_equals: case binary_operation_type::not_equals:
if (equal) if (equal || both_integers)
{ {
node.inferred_type = std::make_unique<types::type>(types::bool_type{}); node.inferred_type = std::make_unique<types::type>(types::bool_type{});
return; return;
} }
break; break;
case binary_operation_type::less: case binary_operation_type::less:
if (equal) if (equal || both_integers)
{ {
node.inferred_type = std::make_unique<types::type>(types::bool_type{}); node.inferred_type = std::make_unique<types::type>(types::bool_type{});
return; return;
} }
break; break;
case binary_operation_type::greater: case binary_operation_type::greater:
if (equal) if (equal || both_integers)
{ {
node.inferred_type = std::make_unique<types::type>(types::bool_type{}); node.inferred_type = std::make_unique<types::type>(types::bool_type{});
return; return;
} }
break; break;
case binary_operation_type::less_equals: case binary_operation_type::less_equals:
if (equal) if (equal || both_integers)
{ {
node.inferred_type = std::make_unique<types::type>(types::bool_type{}); node.inferred_type = std::make_unique<types::type>(types::bool_type{});
return; return;
} }
break; break;
case binary_operation_type::greater_equals: case binary_operation_type::greater_equals:
if (equal) if (equal || both_integers)
{ {
node.inferred_type = std::make_unique<types::type>(types::bool_type{}); node.inferred_type = std::make_unique<types::type>(types::bool_type{});
return; return;

View file

@ -89,6 +89,13 @@ namespace pslang::ir
node_ref target; node_ref target;
}; };
struct jump_if_nonzero
{
// Condition must be a bool-valued node
node_ref condition;
node_ref target;
};
// Call to a specific node // Call to a specific node
struct call struct call
{ {
@ -123,6 +130,7 @@ namespace pslang::ir
assignment, assignment,
jump, jump,
jump_if_zero, jump_if_zero,
jump_if_nonzero,
call, call,
call_pointer, call_pointer,
return_value return_value

View file

@ -2,6 +2,7 @@
#include <pslang/ir/node.hpp> #include <pslang/ir/node.hpp>
#include <pslang/ast/statement_visitor.hpp> #include <pslang/ast/statement_visitor.hpp>
#include <pslang/ast/expression_visitor.hpp> #include <pslang/ast/expression_visitor.hpp>
#include <pslang/types/type_visitor.hpp>
#include <unordered_map> #include <unordered_map>
@ -39,6 +40,27 @@ namespace pslang::ir
std::vector<resolve_call_data> resolve_call; std::vector<resolve_call_data> resolve_call;
}; };
struct zero_literal_visitor
: types::const_visitor<zero_literal_visitor>
{
module_context & mcontext;
using const_visitor::apply;
template <typename T>
void apply(types::primitive_type_base<T>)
{
mcontext.nodes->emplace_back(literal{ast::literal{ast::primitive_literal_base<T>{.value = {}}}},
std::make_shared<types::type>(types::primitive_type{types::primitive_type_base<T>{}}));
}
template <typename T>
void apply(T const &)
{
throw std::runtime_error("Invalid type for zero_literal_visitor");
}
};
// Compile a single function and store the entry point node_ref // Compile a single function and store the entry point node_ref
// in local_context // in local_context
struct compile_function_visitor struct compile_function_visitor
@ -181,6 +203,120 @@ namespace pslang::ir
return last(); return last();
} }
// Different-type integer comparison
if (!types::equal(*arg1_type, *arg2_type)
&& types::is_integer_type(*arg1_type)
&& types::is_integer_type(*arg2_type)
&& ast::is_comparison(node.type))
{
bool arg1_unsigned = types::is_unsigned_integer_type(*arg1_type);
bool arg2_unsigned = types::is_unsigned_integer_type(*arg2_type);
std::size_t arg1_size = types::type_size(*arg1_type);
std::size_t arg2_size = types::type_size(*arg2_type);
std::size_t max_size = std::max(arg1_size, arg2_size);
types::type_ptr max_type = (arg1_size > arg2_size) ? arg1_type : arg2_type;
if ((arg1_unsigned && arg2_unsigned) || (!arg1_unsigned && !arg2_unsigned))
{
// Both signed or both unsigned: just cast the smaller one to the larger type
if (arg1_size < arg2_size)
{
mcontext.nodes->emplace_back(cast_operation{arg1, max_type}, max_type);
mcontext.nodes->emplace_back(binary_operation{node.type, last(), arg2}, node.inferred_type);
return last();
}
else
{
mcontext.nodes->emplace_back(cast_operation{arg2, max_type}, max_type);
mcontext.nodes->emplace_back(binary_operation{node.type, arg1, last()}, node.inferred_type);
return last();
}
}
else
{
// Different signedness
// Swap arg1 and arg2 if arg1 is unsigned, and reverse the operation
auto type = node.type;
if (arg1_unsigned)
{
std::swap(arg1, arg2);
std::swap(arg1_type, arg2_type);
std::swap(arg1_size, arg2_size);
std::swap(arg1_is_pointer, arg2_is_pointer);
if (type == ast::binary_operation_type::less)
type = ast::binary_operation_type::greater;
else if (type == ast::binary_operation_type::greater)
type = ast::binary_operation_type::less;
else if (type == ast::binary_operation_type::less_equals)
type = ast::binary_operation_type::greater_equals;
else if (type == ast::binary_operation_type::greater_equals)
type = ast::binary_operation_type::less_equals;
}
// Compare with zero first
zero_literal_visitor{{}, mcontext}.apply(*arg1_type);
switch (type)
{
case ast::binary_operation_type::equals:
mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::greater_equals, arg1, last()}, node.inferred_type);
mcontext.nodes->emplace_back(jump_if_zero{last(), {}});
break;
case ast::binary_operation_type::not_equals:
mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::less, arg1, last()}, node.inferred_type);
mcontext.nodes->emplace_back(jump_if_nonzero{last(), {}});
break;
case ast::binary_operation_type::less:
mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::less, arg1, last()}, node.inferred_type);
mcontext.nodes->emplace_back(jump_if_nonzero{last(), {}});
break;
case ast::binary_operation_type::greater:
mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::greater, arg1, last()}, node.inferred_type);
mcontext.nodes->emplace_back(jump_if_zero{last(), {}});
break;
case ast::binary_operation_type::less_equals:
mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::less_equals, arg1, last()}, node.inferred_type);
mcontext.nodes->emplace_back(jump_if_nonzero{last(), {}});
break;
case ast::binary_operation_type::greater_equals:
mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::greater_equals, arg1, last()}, node.inferred_type);
mcontext.nodes->emplace_back(jump_if_zero{last(), {}});
break;
default:
break;
}
auto result = std::prev(last());
auto jump_node = last();
// Less-than-zero case handled, here arg1 is nonnegative - cast both to the largest unsigned type
types::type_ptr max_unsigned_type;
if (max_size == 1)
max_unsigned_type = std::make_unique<types::type>(types::primitive_type{types::u8_type{}});
else if (max_size == 2)
max_unsigned_type = std::make_unique<types::type>(types::primitive_type{types::u16_type{}});
else if (max_size == 4)
max_unsigned_type = std::make_unique<types::type>(types::primitive_type{types::u32_type{}});
else if (max_size == 8)
max_unsigned_type = std::make_unique<types::type>(types::primitive_type{types::u64_type{}});
mcontext.nodes->emplace_back(cast_operation{arg1, max_unsigned_type}, max_unsigned_type);
auto new_arg1 = last();
mcontext.nodes->emplace_back(cast_operation{arg2, max_unsigned_type}, max_unsigned_type);
auto new_arg2 = last();
mcontext.nodes->emplace_back(binary_operation{type, new_arg1, new_arg2}, node.inferred_type);
mcontext.nodes->emplace_back(assignment{result, last()});
mcontext.nodes->emplace_back(label{});
if (auto jump_if_zero = std::get_if<ir::jump_if_zero>(&jump_node->instruction))
jump_if_zero->target = last();
else if (auto jump_if_nonzero = std::get_if<ir::jump_if_nonzero>(&jump_node->instruction))
jump_if_nonzero->target = last();
return result;
}
}
// General case // General case
mcontext.nodes->emplace_back(binary_operation{node.type, arg1, arg2}, node.inferred_type); mcontext.nodes->emplace_back(binary_operation{node.type, arg1, arg2}, node.inferred_type);

View file

@ -261,6 +261,11 @@ namespace pslang::ir
out << "jz $" << get_index(instruction.condition) << " $" << get_index(instruction.target); out << "jz $" << get_index(instruction.condition) << " $" << get_index(instruction.target);
} }
void operator()(jump_if_nonzero const & instruction)
{
out << "jnz $" << get_index(instruction.condition) << " $" << get_index(instruction.target);
}
void operator()(call const & instruction) void operator()(call const & instruction)
{ {
out << "call $" << get_index(instruction.target); out << "call $" << get_index(instruction.target);

View file

@ -620,6 +620,13 @@ namespace pslang::jit::aarch64
builder.cbz(0, 0); builder.cbz(0, 0);
} }
void apply(ir::node_ref, ir::jump_if_nonzero const & node, types::type_ptr const &)
{
load(node.condition, 0);
lcontext.cbranch_resolve.emplace_back(pcontext.code.size(), node.target);
builder.cbnz(0, 0);
}
void apply(ir::node_ref it, ir::call const & node, types::type_ptr const & type) void apply(ir::node_ref it, ir::call const & node, types::type_ptr const & type)
{ {
// TODO: struct/array arguments? // TODO: struct/array arguments?