Support comparison operators for integers of different types
This commit is contained in:
parent
81825aee7f
commit
14cb393076
6 changed files with 179 additions and 6 deletions
|
|
@ -32,6 +32,22 @@ namespace pslang::ast
|
|||
greater_equals,
|
||||
};
|
||||
|
||||
inline bool is_comparison(binary_operation_type type)
|
||||
{
|
||||
switch (type)
|
||||
{
|
||||
case binary_operation_type::equals:
|
||||
case binary_operation_type::not_equals:
|
||||
case binary_operation_type::less:
|
||||
case binary_operation_type::greater:
|
||||
case binary_operation_type::less_equals:
|
||||
case binary_operation_type::greater_equals:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Ostream>
|
||||
Ostream & operator << (Ostream & out, unary_operation_type type)
|
||||
{
|
||||
|
|
|
|||
|
|
@ -340,6 +340,7 @@ namespace pslang::ast
|
|||
auto arg2_type = get_type(*node.arg2);
|
||||
|
||||
bool equal = types::equal(*arg1_type, *arg2_type);
|
||||
bool both_integers = types::is_integer_type(*arg1_type) && types::is_integer_type(*arg2_type);
|
||||
|
||||
switch (node.type)
|
||||
{
|
||||
|
|
@ -442,42 +443,42 @@ namespace pslang::ast
|
|||
}
|
||||
break;
|
||||
case binary_operation_type::equals:
|
||||
if (equal)
|
||||
if (equal || both_integers)
|
||||
{
|
||||
node.inferred_type = std::make_unique<types::type>(types::bool_type{});
|
||||
return;
|
||||
}
|
||||
break;
|
||||
case binary_operation_type::not_equals:
|
||||
if (equal)
|
||||
if (equal || both_integers)
|
||||
{
|
||||
node.inferred_type = std::make_unique<types::type>(types::bool_type{});
|
||||
return;
|
||||
}
|
||||
break;
|
||||
case binary_operation_type::less:
|
||||
if (equal)
|
||||
if (equal || both_integers)
|
||||
{
|
||||
node.inferred_type = std::make_unique<types::type>(types::bool_type{});
|
||||
return;
|
||||
}
|
||||
break;
|
||||
case binary_operation_type::greater:
|
||||
if (equal)
|
||||
if (equal || both_integers)
|
||||
{
|
||||
node.inferred_type = std::make_unique<types::type>(types::bool_type{});
|
||||
return;
|
||||
}
|
||||
break;
|
||||
case binary_operation_type::less_equals:
|
||||
if (equal)
|
||||
if (equal || both_integers)
|
||||
{
|
||||
node.inferred_type = std::make_unique<types::type>(types::bool_type{});
|
||||
return;
|
||||
}
|
||||
break;
|
||||
case binary_operation_type::greater_equals:
|
||||
if (equal)
|
||||
if (equal || both_integers)
|
||||
{
|
||||
node.inferred_type = std::make_unique<types::type>(types::bool_type{});
|
||||
return;
|
||||
|
|
|
|||
|
|
@ -89,6 +89,13 @@ namespace pslang::ir
|
|||
node_ref target;
|
||||
};
|
||||
|
||||
struct jump_if_nonzero
|
||||
{
|
||||
// Condition must be a bool-valued node
|
||||
node_ref condition;
|
||||
node_ref target;
|
||||
};
|
||||
|
||||
// Call to a specific node
|
||||
struct call
|
||||
{
|
||||
|
|
@ -123,6 +130,7 @@ namespace pslang::ir
|
|||
assignment,
|
||||
jump,
|
||||
jump_if_zero,
|
||||
jump_if_nonzero,
|
||||
call,
|
||||
call_pointer,
|
||||
return_value
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
#include <pslang/ir/node.hpp>
|
||||
#include <pslang/ast/statement_visitor.hpp>
|
||||
#include <pslang/ast/expression_visitor.hpp>
|
||||
#include <pslang/types/type_visitor.hpp>
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
|
|
@ -39,6 +40,27 @@ namespace pslang::ir
|
|||
std::vector<resolve_call_data> resolve_call;
|
||||
};
|
||||
|
||||
struct zero_literal_visitor
|
||||
: types::const_visitor<zero_literal_visitor>
|
||||
{
|
||||
module_context & mcontext;
|
||||
|
||||
using const_visitor::apply;
|
||||
|
||||
template <typename T>
|
||||
void apply(types::primitive_type_base<T>)
|
||||
{
|
||||
mcontext.nodes->emplace_back(literal{ast::literal{ast::primitive_literal_base<T>{.value = {}}}},
|
||||
std::make_shared<types::type>(types::primitive_type{types::primitive_type_base<T>{}}));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void apply(T const &)
|
||||
{
|
||||
throw std::runtime_error("Invalid type for zero_literal_visitor");
|
||||
}
|
||||
};
|
||||
|
||||
// Compile a single function and store the entry point node_ref
|
||||
// in local_context
|
||||
struct compile_function_visitor
|
||||
|
|
@ -181,6 +203,120 @@ namespace pslang::ir
|
|||
return last();
|
||||
}
|
||||
|
||||
// Different-type integer comparison
|
||||
|
||||
if (!types::equal(*arg1_type, *arg2_type)
|
||||
&& types::is_integer_type(*arg1_type)
|
||||
&& types::is_integer_type(*arg2_type)
|
||||
&& ast::is_comparison(node.type))
|
||||
{
|
||||
bool arg1_unsigned = types::is_unsigned_integer_type(*arg1_type);
|
||||
bool arg2_unsigned = types::is_unsigned_integer_type(*arg2_type);
|
||||
std::size_t arg1_size = types::type_size(*arg1_type);
|
||||
std::size_t arg2_size = types::type_size(*arg2_type);
|
||||
std::size_t max_size = std::max(arg1_size, arg2_size);
|
||||
|
||||
types::type_ptr max_type = (arg1_size > arg2_size) ? arg1_type : arg2_type;
|
||||
|
||||
if ((arg1_unsigned && arg2_unsigned) || (!arg1_unsigned && !arg2_unsigned))
|
||||
{
|
||||
// Both signed or both unsigned: just cast the smaller one to the larger type
|
||||
if (arg1_size < arg2_size)
|
||||
{
|
||||
mcontext.nodes->emplace_back(cast_operation{arg1, max_type}, max_type);
|
||||
mcontext.nodes->emplace_back(binary_operation{node.type, last(), arg2}, node.inferred_type);
|
||||
return last();
|
||||
}
|
||||
else
|
||||
{
|
||||
mcontext.nodes->emplace_back(cast_operation{arg2, max_type}, max_type);
|
||||
mcontext.nodes->emplace_back(binary_operation{node.type, arg1, last()}, node.inferred_type);
|
||||
return last();
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Different signedness
|
||||
|
||||
// Swap arg1 and arg2 if arg1 is unsigned, and reverse the operation
|
||||
auto type = node.type;
|
||||
if (arg1_unsigned)
|
||||
{
|
||||
std::swap(arg1, arg2);
|
||||
std::swap(arg1_type, arg2_type);
|
||||
std::swap(arg1_size, arg2_size);
|
||||
std::swap(arg1_is_pointer, arg2_is_pointer);
|
||||
|
||||
if (type == ast::binary_operation_type::less)
|
||||
type = ast::binary_operation_type::greater;
|
||||
else if (type == ast::binary_operation_type::greater)
|
||||
type = ast::binary_operation_type::less;
|
||||
else if (type == ast::binary_operation_type::less_equals)
|
||||
type = ast::binary_operation_type::greater_equals;
|
||||
else if (type == ast::binary_operation_type::greater_equals)
|
||||
type = ast::binary_operation_type::less_equals;
|
||||
}
|
||||
|
||||
// Compare with zero first
|
||||
zero_literal_visitor{{}, mcontext}.apply(*arg1_type);
|
||||
switch (type)
|
||||
{
|
||||
case ast::binary_operation_type::equals:
|
||||
mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::greater_equals, arg1, last()}, node.inferred_type);
|
||||
mcontext.nodes->emplace_back(jump_if_zero{last(), {}});
|
||||
break;
|
||||
case ast::binary_operation_type::not_equals:
|
||||
mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::less, arg1, last()}, node.inferred_type);
|
||||
mcontext.nodes->emplace_back(jump_if_nonzero{last(), {}});
|
||||
break;
|
||||
case ast::binary_operation_type::less:
|
||||
mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::less, arg1, last()}, node.inferred_type);
|
||||
mcontext.nodes->emplace_back(jump_if_nonzero{last(), {}});
|
||||
break;
|
||||
case ast::binary_operation_type::greater:
|
||||
mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::greater, arg1, last()}, node.inferred_type);
|
||||
mcontext.nodes->emplace_back(jump_if_zero{last(), {}});
|
||||
break;
|
||||
case ast::binary_operation_type::less_equals:
|
||||
mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::less_equals, arg1, last()}, node.inferred_type);
|
||||
mcontext.nodes->emplace_back(jump_if_nonzero{last(), {}});
|
||||
break;
|
||||
case ast::binary_operation_type::greater_equals:
|
||||
mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::greater_equals, arg1, last()}, node.inferred_type);
|
||||
mcontext.nodes->emplace_back(jump_if_zero{last(), {}});
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
auto result = std::prev(last());
|
||||
auto jump_node = last();
|
||||
|
||||
// Less-than-zero case handled, here arg1 is nonnegative - cast both to the largest unsigned type
|
||||
types::type_ptr max_unsigned_type;
|
||||
if (max_size == 1)
|
||||
max_unsigned_type = std::make_unique<types::type>(types::primitive_type{types::u8_type{}});
|
||||
else if (max_size == 2)
|
||||
max_unsigned_type = std::make_unique<types::type>(types::primitive_type{types::u16_type{}});
|
||||
else if (max_size == 4)
|
||||
max_unsigned_type = std::make_unique<types::type>(types::primitive_type{types::u32_type{}});
|
||||
else if (max_size == 8)
|
||||
max_unsigned_type = std::make_unique<types::type>(types::primitive_type{types::u64_type{}});
|
||||
|
||||
mcontext.nodes->emplace_back(cast_operation{arg1, max_unsigned_type}, max_unsigned_type);
|
||||
auto new_arg1 = last();
|
||||
mcontext.nodes->emplace_back(cast_operation{arg2, max_unsigned_type}, max_unsigned_type);
|
||||
auto new_arg2 = last();
|
||||
mcontext.nodes->emplace_back(binary_operation{type, new_arg1, new_arg2}, node.inferred_type);
|
||||
mcontext.nodes->emplace_back(assignment{result, last()});
|
||||
mcontext.nodes->emplace_back(label{});
|
||||
if (auto jump_if_zero = std::get_if<ir::jump_if_zero>(&jump_node->instruction))
|
||||
jump_if_zero->target = last();
|
||||
else if (auto jump_if_nonzero = std::get_if<ir::jump_if_nonzero>(&jump_node->instruction))
|
||||
jump_if_nonzero->target = last();
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
// General case
|
||||
|
||||
mcontext.nodes->emplace_back(binary_operation{node.type, arg1, arg2}, node.inferred_type);
|
||||
|
|
|
|||
|
|
@ -261,6 +261,11 @@ namespace pslang::ir
|
|||
out << "jz $" << get_index(instruction.condition) << " $" << get_index(instruction.target);
|
||||
}
|
||||
|
||||
void operator()(jump_if_nonzero const & instruction)
|
||||
{
|
||||
out << "jnz $" << get_index(instruction.condition) << " $" << get_index(instruction.target);
|
||||
}
|
||||
|
||||
void operator()(call const & instruction)
|
||||
{
|
||||
out << "call $" << get_index(instruction.target);
|
||||
|
|
|
|||
|
|
@ -620,6 +620,13 @@ namespace pslang::jit::aarch64
|
|||
builder.cbz(0, 0);
|
||||
}
|
||||
|
||||
void apply(ir::node_ref, ir::jump_if_nonzero const & node, types::type_ptr const &)
|
||||
{
|
||||
load(node.condition, 0);
|
||||
lcontext.cbranch_resolve.emplace_back(pcontext.code.size(), node.target);
|
||||
builder.cbnz(0, 0);
|
||||
}
|
||||
|
||||
void apply(ir::node_ref it, ir::call const & node, types::type_ptr const & type)
|
||||
{
|
||||
// TODO: struct/array arguments?
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue