Add short-circuiting versions of and/or (&& and ||) and implement them for bools and integers in interpreter and aarch64 compiler

This commit is contained in:
Nikita Lisitsa 2026-03-12 21:37:15 +03:00
parent 76f80f8135
commit 927368aa80
8 changed files with 213 additions and 39 deletions

View file

@ -173,10 +173,9 @@ int main(int argc, char ** argv)
{ {
// TODO: remove, testing-only code; should execute entry point instead // TODO: remove, testing-only code; should execute entry point instead
auto offset = pcontext.symbols.at("add_or_sub"); auto offset = pcontext.symbols.at("test");
using type = std::uint32_t(*)(std::uint32_t); auto fptr = (std::uint32_t(*)())(executable.data.get() + offset);
auto fptr = (type(*)(bool))(executable.data.get() + offset); auto x = fptr();
auto x = fptr(false)(10u);
std::cout << "Result: " << std::boolalpha << x << std::endl; std::cout << "Result: " << std::boolalpha << x << std::endl;
} }
} }

View file

@ -0,0 +1,9 @@
func g() -> u32:
return g()
func h() -> bool:
return h()
func test() -> u32:
return (!1u) || g()

View file

@ -16,7 +16,9 @@ namespace pslang::ast
multiplication, multiplication,
division, division,
remainder, remainder,
binary_and,
logical_and, logical_and,
binary_or,
logical_or, logical_or,
logical_xor, logical_xor,
equals, equals,
@ -62,11 +64,17 @@ namespace pslang::ast
case binary_operation_type::remainder: case binary_operation_type::remainder:
out << "remainder"; out << "remainder";
break; break;
case binary_operation_type::binary_and:
out << "binary and";
break;
case binary_operation_type::logical_and: case binary_operation_type::logical_and:
out << "and"; out << "logical and";
break;
case binary_operation_type::binary_or:
out << "binary or";
break; break;
case binary_operation_type::logical_or: case binary_operation_type::logical_or:
out << "or"; out << "logical or";
break; break;
case binary_operation_type::logical_xor: case binary_operation_type::logical_xor:
out << "xor"; out << "xor";

View file

@ -295,6 +295,13 @@ namespace pslang::ast
return; return;
} }
break; break;
case binary_operation_type::binary_and:
if (equal && (types::is_bool_type(*arg1_type) || types::is_integer_type(*arg1_type)))
{
node.inferred_type = arg1_type;
return;
}
break;
case binary_operation_type::logical_and: case binary_operation_type::logical_and:
if (equal && (types::is_bool_type(*arg1_type) || types::is_integer_type(*arg1_type))) if (equal && (types::is_bool_type(*arg1_type) || types::is_integer_type(*arg1_type)))
{ {
@ -302,6 +309,13 @@ namespace pslang::ast
return; return;
} }
break; break;
case binary_operation_type::binary_or:
if (equal && (types::is_bool_type(*arg1_type) || types::is_integer_type(*arg1_type)))
{
node.inferred_type = arg1_type;
return;
}
break;
case binary_operation_type::logical_or: case binary_operation_type::logical_or:
if (equal && (types::is_bool_type(*arg1_type) || types::is_integer_type(*arg1_type))) if (equal && (types::is_bool_type(*arg1_type) || types::is_integer_type(*arg1_type)))
{ {

View file

@ -164,6 +164,18 @@ namespace pslang::interpreter
return true; return true;
} }
bool is_short_circuiting(ast::binary_operation_type type)
{
switch (type)
{
case ast::binary_operation_type::logical_and:
case ast::binary_operation_type::logical_or:
return true;
default:
return false;
}
}
template <typename T> template <typename T>
value binary_operation_impl_same_type(ast::binary_operation_type type, primitive_value_base<T> const & arg1, value const & arg2_generic, ast::location const & location) value binary_operation_impl_same_type(ast::binary_operation_type type, primitive_value_base<T> const & arg1, value const & arg2_generic, ast::location const & location)
{ {
@ -211,7 +223,7 @@ namespace pslang::interpreter
return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value % arg2.value)}); return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value % arg2.value)});
} }
break; break;
case ast::binary_operation_type::logical_and: case ast::binary_operation_type::binary_and:
if constexpr (std::is_same_v<T, bool>) if constexpr (std::is_same_v<T, bool>)
{ {
return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value && arg2.value)}); return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value && arg2.value)});
@ -221,7 +233,9 @@ namespace pslang::interpreter
return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value & arg2.value)}); return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value & arg2.value)});
} }
break; break;
case ast::binary_operation_type::logical_or: case ast::binary_operation_type::logical_and:
throw internal_error("logical_and must be handled separately", location);
case ast::binary_operation_type::binary_or:
if constexpr (std::is_same_v<T, bool>) if constexpr (std::is_same_v<T, bool>)
{ {
return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value || arg2.value)}); return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value || arg2.value)});
@ -231,6 +245,8 @@ namespace pslang::interpreter
return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value | arg2.value)}); return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value | arg2.value)});
} }
break; break;
case ast::binary_operation_type::logical_or:
throw internal_error("logical_or must be handled separately", location);
case ast::binary_operation_type::logical_xor: case ast::binary_operation_type::logical_xor:
if constexpr (std::is_same_v<T, bool>) if constexpr (std::is_same_v<T, bool>)
{ {
@ -279,26 +295,108 @@ namespace pslang::interpreter
return std::visit([&](auto const & value){ return binary_operation_impl_same_type(type, value, arg2, location); }, arg1); return std::visit([&](auto const & value){ return binary_operation_impl_same_type(type, value, arg2, location); }, arg1);
} }
template <typename T, typename LazyArg2>
value short_circuiting_impl_same_type(ast::binary_operation_type type, primitive_value_base<T> const & arg1, LazyArg2 const & lazy_arg2, ast::location const & location)
{
switch (type)
{
case ast::binary_operation_type::logical_and:
if constexpr (std::is_same_v<T, bool>)
{
if (!arg1.value)
return primitive_value(primitive_value_base<T>{false});
value const & arg2_generic = lazy_arg2();
primitive_value_base<T> const & arg2 = std::get<primitive_value_base<T>>(std::get<primitive_value>(arg2_generic));
return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value && arg2.value)});
}
else if constexpr (std::is_integral_v<T>)
{
if (arg1.value == T{})
return primitive_value(primitive_value_base<T>{arg1.value});
value const & arg2_generic = lazy_arg2();
primitive_value_base<T> const & arg2 = std::get<primitive_value_base<T>>(std::get<primitive_value>(arg2_generic));
return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value && arg2.value)});
}
break;
case ast::binary_operation_type::logical_or:
if constexpr (std::is_same_v<T, bool>)
{
if (arg1.value)
return primitive_value(primitive_value_base<T>{true});
value const & arg2_generic = lazy_arg2();
primitive_value_base<T> const & arg2 = std::get<primitive_value_base<T>>(std::get<primitive_value>(arg2_generic));
return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value || arg2.value)});
}
else if constexpr (std::is_integral_v<T>)
{
if (arg1.value == ~T{})
return primitive_value(primitive_value_base<T>{arg1.value});
value const & arg2_generic = lazy_arg2();
primitive_value_base<T> const & arg2 = std::get<primitive_value_base<T>>(std::get<primitive_value>(arg2_generic));
return primitive_value(primitive_value_base<T>{static_cast<T>(arg1.value && arg2.value)});
}
break;
default:
throw internal_error("invalid operator type in short-circuiting branch", location);
}
std::ostringstream os;
os << "Cannot apply " << type << " to values of type ";
types::print(os, type_of(primitive_value(arg1)));
os << " and ";
types::print(os, type_of(lazy_arg2()));
throw internal_error(os.str(), location);
}
template <typename Value, typename LazyArg2>
value short_circuiting_impl_same_type(ast::binary_operation_type type, Value const & arg1, LazyArg2 const & lazy_arg2, ast::location const & location)
{
std::ostringstream os;
os << "Cannot apply " << type << " to values of type ";
types::print(os, type_of(arg1));
os << " and ";
types::print(os, type_of(lazy_arg2()));
throw internal_error(os.str(), location);
}
template <typename LazyArg2>
value short_circuiting_impl_same_type(ast::binary_operation_type type, primitive_value const & arg1, LazyArg2 const & lazy_arg2, ast::location const & location)
{
return std::visit([&](auto const & value){ return short_circuiting_impl_same_type(type, value, lazy_arg2, location); }, arg1);
}
value eval_impl(context & context, ast::binary_operation const & binary_operation) value eval_impl(context & context, ast::binary_operation const & binary_operation)
{ {
auto arg1 = eval_impl(context, binary_operation.arg1); auto type1 = ast::get_type(*binary_operation.arg1);
auto arg2 = eval_impl(context, binary_operation.arg2); auto type2 = ast::get_type(*binary_operation.arg2);
if (requires_same_argument_type(binary_operation.type)) bool const is_same_type = types::equal(*type1, *type2);
if (requires_same_argument_type(binary_operation.type) && !is_same_type)
{ {
auto type1 = type_of(arg1); std::ostringstream os;
auto type2 = type_of(arg2); os << "Cannot apply " << binary_operation.type << " to values of type ";
types::print(os, *type1);
os << " and ";
types::print(os, *type2);
throw internal_error(os.str(), binary_operation.location);
}
if (!types::equal(type1, type2)) auto arg1 = eval_impl(context, binary_operation.arg1);
{
std::ostringstream os;
os << "Cannot apply " << binary_operation.type << " to values of type ";
types::print(os, type1);
os << " and ";
types::print(os, type2);
throw internal_error(os.str(), binary_operation.location);
}
if (is_short_circuiting(binary_operation.type))
{
auto lazy_arg2 = [&]{ return eval_impl(context, binary_operation.arg2); };
return std::visit([&](auto const & value){ return short_circuiting_impl_same_type(binary_operation.type, value, lazy_arg2, binary_operation.location); }, arg1);
}
if (is_same_type)
{
auto arg2 = eval_impl(context, binary_operation.arg2);
return std::visit([&](auto const & value){ return binary_operation_impl_same_type(binary_operation.type, value, arg2, binary_operation.location); }, arg1); return std::visit([&](auto const & value){ return binary_operation_impl_same_type(binary_operation.type, value, arg2, binary_operation.location); }, arg1);
} }

View file

@ -43,6 +43,18 @@ namespace pslang::jit::aarch64
return 3; return 3;
} }
bool is_short_circuiting(ast::binary_operation_type type)
{
switch (type)
{
case ast::binary_operation_type::logical_and:
case ast::binary_operation_type::logical_or:
return true;
default:
return false;
}
}
struct populate_constants_visitor struct populate_constants_visitor
: ast::const_expression_visitor<populate_constants_visitor> : ast::const_expression_visitor<populate_constants_visitor>
, ast::const_statement_visitor<populate_constants_visitor> , ast::const_statement_visitor<populate_constants_visitor>
@ -386,19 +398,22 @@ namespace pslang::jit::aarch64
bool const is_fp = types::is_floating_point_type(*arg1_type); bool const is_fp = types::is_floating_point_type(*arg1_type);
std::uint8_t const fp_mode = fp_mode_for(*arg1_type); std::uint8_t const fp_mode = fp_mode_for(*arg1_type);
if (is_fp) apply(*node.arg1);
if (!is_short_circuiting(node.type))
{ {
apply(*node.arg1); if (is_fp)
push_fp(0, fp_mode); {
apply(*node.arg2); push_fp(0, fp_mode);
pop_fp(1, fp_mode); apply(*node.arg2);
} pop_fp(1, fp_mode);
else }
{ else
apply(*node.arg1); {
push(0); push(0);
apply(*node.arg2); apply(*node.arg2);
pop(1); pop(1);
}
} }
switch (node.type) switch (node.type)
@ -445,12 +460,37 @@ namespace pslang::jit::aarch64
case ast::binary_operation_type::remainder: case ast::binary_operation_type::remainder:
// TODO: implement via div & mul & sub // TODO: implement via div & mul & sub
throw std::runtime_error("Not implemented"); throw std::runtime_error("Not implemented");
case ast::binary_operation_type::logical_and: case ast::binary_operation_type::binary_and:
builder.and_reg(1, 0, 0); builder.and_reg(1, 0, 0);
break; break;
case ast::binary_operation_type::logical_or: case ast::binary_operation_type::logical_and:
{
std::int32_t start = pcontext.code.size();
builder.cbz(0, 0);
push(0);
apply(*node.arg2);
pop(1);
builder.and_reg(1, 0, 0);
std::int32_t end = pcontext.code.size();
builder.cb_inject(pcontext.code.data() + start, (end - start) / 4);
}
break;
case ast::binary_operation_type::binary_or:
builder.or_reg(1, 0, 0); builder.or_reg(1, 0, 0);
break; break;
case ast::binary_operation_type::logical_or:
{
builder.or_not_reg(31, 0, 1);
std::int32_t start = pcontext.code.size();
builder.cbz(1, 0);
push(0);
apply(*node.arg2);
pop(1);
builder.or_reg(1, 0, 0);
std::int32_t end = pcontext.code.size();
builder.cb_inject(pcontext.code.data() + start, (end - start) / 4);
}
break;
case ast::binary_operation_type::logical_xor: case ast::binary_operation_type::logical_xor:
builder.xor_reg(1, 0, 0); builder.xor_reg(1, 0, 0);
break; break;

View file

@ -67,7 +67,9 @@ f64 { return bp::make_f64(ctx.location); }
"*" { return bp::make_asterisk(ctx.location); } "*" { return bp::make_asterisk(ctx.location); }
"/" { return bp::make_slash(ctx.location); } "/" { return bp::make_slash(ctx.location); }
"%" { return bp::make_percent(ctx.location); } "%" { return bp::make_percent(ctx.location); }
"&&" { return bp::make_double_ampersand(ctx.location); }
"&" { return bp::make_ampersand(ctx.location); } "&" { return bp::make_ampersand(ctx.location); }
"||" { return bp::make_double_vertical_bar(ctx.location); }
"|" { return bp::make_vertical_bar(ctx.location); } "|" { return bp::make_vertical_bar(ctx.location); }
"^" { return bp::make_circumflex(ctx.location); } "^" { return bp::make_circumflex(ctx.location); }
"!" { return bp::make_exclamation(ctx.location); } "!" { return bp::make_exclamation(ctx.location); }

View file

@ -91,7 +91,9 @@ template <typename T>
%token slash "/" %token slash "/"
%token percent "%" %token percent "%"
%token ampersand "&" %token ampersand "&"
%token double_ampersand "&&"
%token vertical_bar "|" %token vertical_bar "|"
%token double_vertical_bar "||"
%token circumflex "^" %token circumflex "^"
%token exclamation "!" %token exclamation "!"
%token equals "==" %token equals "=="
@ -150,7 +152,7 @@ template <typename T>
%token end 0 %token end 0
%right arrow %right arrow
%left ampersand vertical_bar circumflex %left ampersand double_ampersand vertical_bar double_vertical_bar circumflex
%left equals not_equals less greater less_equals greater_equals %left equals not_equals less greater less_equals greater_equals
%nonassoc as %nonassoc as
%left plus minus %left plus minus
@ -282,8 +284,10 @@ two_or_more_type_list
; ;
expression expression
: expression ampersand expression { $$ = ast::binary_operation{ast::binary_operation_type::logical_and, std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3), @$ }; } : expression ampersand expression { $$ = ast::binary_operation{ast::binary_operation_type::binary_and, std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3), @$ }; }
| expression vertical_bar expression { $$ = ast::binary_operation{ast::binary_operation_type::logical_or, std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3), @$ }; } | expression double_ampersand expression { $$ = ast::binary_operation{ast::binary_operation_type::logical_and, std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3), @$ }; }
| expression vertical_bar expression { $$ = ast::binary_operation{ast::binary_operation_type::binary_or, std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3), @$ }; }
| expression double_vertical_bar expression { $$ = ast::binary_operation{ast::binary_operation_type::logical_or, std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3), @$ }; }
| expression circumflex expression { $$ = ast::binary_operation{ast::binary_operation_type::logical_xor, std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3), @$ }; } | expression circumflex expression { $$ = ast::binary_operation{ast::binary_operation_type::logical_xor, std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3), @$ }; }
| expression equals expression { $$ = ast::binary_operation{ast::binary_operation_type::equals, std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3), @$ }; } | expression equals expression { $$ = ast::binary_operation{ast::binary_operation_type::equals, std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3), @$ }; }
| expression not_equals expression { $$ = ast::binary_operation{ast::binary_operation_type::not_equals, std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3), @$ }; } | expression not_equals expression { $$ = ast::binary_operation{ast::binary_operation_type::not_equals, std::make_unique<ast::expression>($1), std::make_unique<ast::expression>($3), @$ }; }