Add short-circuiting versions of and/or (&& and ||) and implement them for bools and integers in interpreter and aarch64 compiler

This commit is contained in:
Nikita Lisitsa 2026-03-12 21:37:15 +03:00
parent 76f80f8135
commit 927368aa80
8 changed files with 213 additions and 39 deletions

View file

@ -173,10 +173,9 @@ int main(int argc, char ** argv)
{
// TODO: remove, testing-only code; should execute entry point instead
auto offset = pcontext.symbols.at("add_or_sub");
using type = std::uint32_t(*)(std::uint32_t);
auto fptr = (type(*)(bool))(executable.data.get() + offset);
auto x = fptr(false)(10u);
auto offset = pcontext.symbols.at("test");
auto fptr = (std::uint32_t(*)())(executable.data.get() + offset);
auto x = fptr();
std::cout << "Result: " << std::boolalpha << x << std::endl;
}
}

View file

@ -0,0 +1,9 @@
func g() -> u32:
return g()
func h() -> bool:
return h()
func test() -> u32:
return (!1u) || g()

View file

@ -16,7 +16,9 @@ namespace pslang::ast
multiplication,
division,
remainder,
binary_and,
logical_and,
binary_or,
logical_or,
logical_xor,
equals,
@ -62,11 +64,17 @@ namespace pslang::ast
case binary_operation_type::remainder:
out << "remainder";
break;
case binary_operation_type::binary_and:
out << "binary and";
break;
case binary_operation_type::logical_and:
out << "and";
out << "logical and";
break;
case binary_operation_type::binary_or:
out << "binary or";
break;
case binary_operation_type::logical_or:
out << "or";
out << "logical or";
break;
case binary_operation_type::logical_xor:
out << "xor";

View file

@ -295,6 +295,13 @@ namespace pslang::ast
return;
}
break;
case binary_operation_type::binary_and:
if (equal && (types::is_bool_type(*arg1_type) || types::is_integer_type(*arg1_type)))
{
node.inferred_type = arg1_type;
return;
}
break;
case binary_operation_type::logical_and:
if (equal && (types::is_bool_type(*arg1_type) || types::is_integer_type(*arg1_type)))
{
@ -302,6 +309,13 @@ namespace pslang::ast
return;
}
break;
case binary_operation_type::binary_or:
if (equal && (types::is_bool_type(*arg1_type) || types::is_integer_type(*arg1_type)))
{
node.inferred_type = arg1_type;
return;
}
break;
case binary_operation_type::logical_or:
if (equal && (types::is_bool_type(*arg1_type) || types::is_integer_type(*arg1_type)))
{

View file

@ -164,6 +164,18 @@ namespace pslang::interpreter
return true;
}
bool is_short_circuiting(ast::binary_operation_type type)
{
switch (type)
{
case ast::binary_operation_type::logical_and:
case ast::binary_operation_type::logical_or:
return true;
default:
return false;
}
}
template <typename T>
value binary_operation_impl_same_type(ast::binary_operation_type type, primitive_value_base<T> const & arg1, value const & arg2_generic, ast::location const & location)
{
@ -211,7 +223,7 @@ namespace pslang::interpreter
return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value % arg2.value)});
}
break;
case ast::binary_operation_type::logical_and:
case ast::binary_operation_type::binary_and:
if constexpr (std::is_same_v<T, bool>)
{
return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value && arg2.value)});
@ -221,7 +233,9 @@ namespace pslang::interpreter
return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value & arg2.value)});
}
break;
case ast::binary_operation_type::logical_or:
case ast::binary_operation_type::logical_and:
throw internal_error("logical_and must be handled separately", location);
case ast::binary_operation_type::binary_or:
if constexpr (std::is_same_v<T, bool>)
{
return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value || arg2.value)});
@ -231,6 +245,8 @@ namespace pslang::interpreter
return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value | arg2.value)});
}
break;
case ast::binary_operation_type::logical_or:
throw internal_error("logical_or must be handled separately", location);
case ast::binary_operation_type::logical_xor:
if constexpr (std::is_same_v<T, bool>)
{
@ -279,26 +295,108 @@ namespace pslang::interpreter
return std::visit([&](auto const & value){ return binary_operation_impl_same_type(type, value, arg2, location); }, arg1);
}
template <typename T, typename LazyArg2>
value short_circuiting_impl_same_type(ast::binary_operation_type type, primitive_value_base<T> const & arg1, LazyArg2 const & lazy_arg2, ast::location const & location)
{
switch (type)
{
case ast::binary_operation_type::logical_and:
if constexpr (std::is_same_v<T, bool>)
{
if (!arg1.value)
return primitive_value(primitive_value_base<T>{false});
value const & arg2_generic = lazy_arg2();
primitive_value_base<T> const & arg2 = std::get<primitive_value_base<T>>(std::get<primitive_value>(arg2_generic));
return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value && arg2.value)});
}
else if constexpr (std::is_integral_v<T>)
{
if (arg1.value == T{})
return primitive_value(primitive_value_base<T>{arg1.value});
value const & arg2_generic = lazy_arg2();
primitive_value_base<T> const & arg2 = std::get<primitive_value_base<T>>(std::get<primitive_value>(arg2_generic));
return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value && arg2.value)});
}
break;
case ast::binary_operation_type::logical_or:
if constexpr (std::is_same_v<T, bool>)
{
if (arg1.value)
return primitive_value(primitive_value_base<T>{true});
value const & arg2_generic = lazy_arg2();
primitive_value_base<T> const & arg2 = std::get<primitive_value_base<T>>(std::get<primitive_value>(arg2_generic));
return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value || arg2.value)});
}
else if constexpr (std::is_integral_v<T>)
{
if (arg1.value == ~T{})
return primitive_value(primitive_value_base<T>{arg1.value});
value const & arg2_generic = lazy_arg2();
primitive_value_base<T> const & arg2 = std::get<primitive_value_base<T>>(std::get<primitive_value>(arg2_generic));
return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value && arg2.value)});
}
break;
default:
throw internal_error("invalid operator type in short-circuiting branch", location);
}
std::ostringstream os;
os << "Cannot apply " << type << " to values of type ";
types::print(os, type_of(primitive_value(arg1)));
os << " and ";
types::print(os, type_of(lazy_arg2()));
throw internal_error(os.str(), location);
}
template <typename Value, typename LazyArg2>
value short_circuiting_impl_same_type(ast::binary_operation_type type, Value const & arg1, LazyArg2 const & lazy_arg2, ast::location const & location)
{
std::ostringstream os;
os << "Cannot apply " << type << " to values of type ";
types::print(os, type_of(arg1));
os << " and ";
types::print(os, type_of(lazy_arg2()));
throw internal_error(os.str(), location);
}
template <typename LazyArg2>
value short_circuiting_impl_same_type(ast::binary_operation_type type, primitive_value const & arg1, LazyArg2 const & lazy_arg2, ast::location const & location)
{
return std::visit([&](auto const & value){ return short_circuiting_impl_same_type(type, value, lazy_arg2, location); }, arg1);
}
value eval_impl(context & context, ast::binary_operation const & binary_operation)
{
auto arg1 = eval_impl(context, binary_operation.arg1);
auto arg2 = eval_impl(context, binary_operation.arg2);
auto type1 = ast::get_type(*binary_operation.arg1);
auto type2 = ast::get_type(*binary_operation.arg2);
if (requires_same_argument_type(binary_operation.type))
{
auto type1 = type_of(arg1);
auto type2 = type_of(arg2);
bool const is_same_type = types::equal(*type1, *type2);
if (!types::equal(type1, type2))
if (requires_same_argument_type(binary_operation.type) && !is_same_type)
{
std::ostringstream os;
os << "Cannot apply " << binary_operation.type << " to values of type ";
types::print(os, type1);
types::print(os, *type1);
os << " and ";
types::print(os, type2);
types::print(os, *type2);
throw internal_error(os.str(), binary_operation.location);
}
auto arg1 = eval_impl(context, binary_operation.arg1);
if (is_short_circuiting(binary_operation.type))
{
auto lazy_arg2 = [&]{ return eval_impl(context, binary_operation.arg2); };
return std::visit([&](auto const & value){ return short_circuiting_impl_same_type(binary_operation.type, value, lazy_arg2, binary_operation.location); }, arg1);
}
if (is_same_type)
{
auto arg2 = eval_impl(context, binary_operation.arg2);
return std::visit([&](auto const & value){ return binary_operation_impl_same_type(binary_operation.type, value, arg2, binary_operation.location); }, arg1);
}

View file

@ -43,6 +43,18 @@ namespace pslang::jit::aarch64
return 3;
}
bool is_short_circuiting(ast::binary_operation_type type)
{
switch (type)
{
case ast::binary_operation_type::logical_and:
case ast::binary_operation_type::logical_or:
return true;
default:
return false;
}
}
struct populate_constants_visitor
: ast::const_expression_visitor<populate_constants_visitor>
, ast::const_statement_visitor<populate_constants_visitor>
@ -386,20 +398,23 @@ namespace pslang::jit::aarch64
bool const is_fp = types::is_floating_point_type(*arg1_type);
std::uint8_t const fp_mode = fp_mode_for(*arg1_type);
apply(*node.arg1);
if (!is_short_circuiting(node.type))
{
if (is_fp)
{
apply(*node.arg1);
push_fp(0, fp_mode);
apply(*node.arg2);
pop_fp(1, fp_mode);
}
else
{
apply(*node.arg1);
push(0);
apply(*node.arg2);
pop(1);
}
}
switch (node.type)
{
@ -445,12 +460,37 @@ namespace pslang::jit::aarch64
case ast::binary_operation_type::remainder:
// TODO: implement via div & mul & sub
throw std::runtime_error("Not implemented");
case ast::binary_operation_type::logical_and:
case ast::binary_operation_type::binary_and:
builder.and_reg(1, 0, 0);
break;
case ast::binary_operation_type::logical_or:
case ast::binary_operation_type::logical_and:
{
std::int32_t start = pcontext.code.size();
builder.cbz(0, 0);
push(0);
apply(*node.arg2);
pop(1);
builder.and_reg(1, 0, 0);
std::int32_t end = pcontext.code.size();
builder.cb_inject(pcontext.code.data() + start, (end - start) / 4);
}
break;
case ast::binary_operation_type::binary_or:
builder.or_reg(1, 0, 0);
break;
case ast::binary_operation_type::logical_or:
{
builder.or_not_reg(31, 0, 1);
std::int32_t start = pcontext.code.size();
builder.cbz(1, 0);
push(0);
apply(*node.arg2);
pop(1);
builder.or_reg(1, 0, 0);
std::int32_t end = pcontext.code.size();
builder.cb_inject(pcontext.code.data() + start, (end - start) / 4);
}
break;
case ast::binary_operation_type::logical_xor:
builder.xor_reg(1, 0, 0);
break;

View file

@ -67,7 +67,9 @@ f64 { return bp::make_f64(ctx.location); }
"*" { return bp::make_asterisk(ctx.location); }
"/" { return bp::make_slash(ctx.location); }
"%" { return bp::make_percent(ctx.location); }
"&&" { return bp::make_double_ampersand(ctx.location); }
"&" { return bp::make_ampersand(ctx.location); }
"||" { return bp::make_double_vertical_bar(ctx.location); }
"|" { return bp::make_vertical_bar(ctx.location); }
"^" { return bp::make_circumflex(ctx.location); }
"!" { return bp::make_exclamation(ctx.location); }

View file

@ -91,7 +91,9 @@ template <typename T>
%token slash "/"
%token percent "%"
%token ampersand "&"
%token double_ampersand "&&"
%token vertical_bar "|"
%token double_vertical_bar "||"
%token circumflex "^"
%token exclamation "!"
%token equals "=="
@ -150,7 +152,7 @@ template <typename T>
%token end 0
%right arrow
%left ampersand vertical_bar circumflex
%left ampersand double_ampersand vertical_bar double_vertical_bar circumflex
%left equals not_equals less greater less_equals greater_equals
%nonassoc as
%left plus minus
@ -282,8 +284,10 @@ two_or_more_type_list
;
expression
: expression ampersand expression { $$ = ast::binary_operation{ast::binary_operation_type::logical_and, std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3), @$ }; }
| expression vertical_bar expression { $$ = ast::binary_operation{ast::binary_operation_type::logical_or, std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3), @$ }; }
: expression ampersand expression { $$ = ast::binary_operation{ast::binary_operation_type::binary_and, std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3), @$ }; }
| expression double_ampersand expression { $$ = ast::binary_operation{ast::binary_operation_type::logical_and, std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3), @$ }; }
| expression vertical_bar expression { $$ = ast::binary_operation{ast::binary_operation_type::binary_or, std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3), @$ }; }
| expression double_vertical_bar expression { $$ = ast::binary_operation{ast::binary_operation_type::logical_or, std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3), @$ }; }
| expression circumflex expression { $$ = ast::binary_operation{ast::binary_operation_type::logical_xor, std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3), @$ }; }
| expression equals expression { $$ = ast::binary_operation{ast::binary_operation_type::equals, std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3), @$ }; }
| expression not_equals expression { $$ = ast::binary_operation{ast::binary_operation_type::not_equals, std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3), @$ }; }