diff --git a/apps/interpreter/source/main.cpp b/apps/interpreter/source/main.cpp index 4c8bec5..b8dd48b 100644 --- a/apps/interpreter/source/main.cpp +++ b/apps/interpreter/source/main.cpp @@ -173,10 +173,9 @@ int main(int argc, char ** argv) { // TODO: remove, testing-only code; should execute entry point instead - auto offset = pcontext.symbols.at("add_or_sub"); - using type = std::uint32_t(*)(std::uint32_t); - auto fptr = (type(*)(bool))(executable.data.get() + offset); - auto x = fptr(false)(10u); + auto offset = pcontext.symbols.at("test"); + auto fptr = (std::uint32_t(*)())(executable.data.get() + offset); + auto x = fptr(); std::cout << "Result: " << std::boolalpha << x << std::endl; } } diff --git a/examples/short-circuit.psl b/examples/short-circuit.psl new file mode 100644 index 0000000..bad56ff --- /dev/null +++ b/examples/short-circuit.psl @@ -0,0 +1,9 @@ +func g() -> u32: + return g() + +func h() -> bool: + return h() + +func test() -> u32: + return (!1u) || g() + diff --git a/libs/ast/include/pslang/ast/operation.hpp b/libs/ast/include/pslang/ast/operation.hpp index 709d281..4e6bee2 100644 --- a/libs/ast/include/pslang/ast/operation.hpp +++ b/libs/ast/include/pslang/ast/operation.hpp @@ -16,7 +16,9 @@ namespace pslang::ast multiplication, division, remainder, + binary_and, logical_and, + binary_or, logical_or, logical_xor, equals, @@ -62,11 +64,17 @@ namespace pslang::ast case binary_operation_type::remainder: out << "remainder"; break; + case binary_operation_type::binary_and: + out << "binary and"; + break; case binary_operation_type::logical_and: - out << "and"; + out << "logical and"; + break; + case binary_operation_type::binary_or: + out << "binary or"; break; case binary_operation_type::logical_or: - out << "or"; + out << "logical or"; break; case binary_operation_type::logical_xor: out << "xor"; diff --git a/libs/ast/source/type_check.cpp b/libs/ast/source/type_check.cpp index efb1592..bbabf92 100644 --- a/libs/ast/source/type_check.cpp +++ b/libs/ast/source/type_check.cpp @@ -295,6 +295,13 @@ namespace pslang::ast return; } break; + case binary_operation_type::binary_and: + if (equal && (types::is_bool_type(*arg1_type) || types::is_integer_type(*arg1_type))) + { + node.inferred_type = arg1_type; + return; + } + break; case binary_operation_type::logical_and: if (equal && (types::is_bool_type(*arg1_type) || types::is_integer_type(*arg1_type))) { @@ -302,6 +309,13 @@ namespace pslang::ast return; } break; + case binary_operation_type::binary_or: + if (equal && (types::is_bool_type(*arg1_type) || types::is_integer_type(*arg1_type))) + { + node.inferred_type = arg1_type; + return; + } + break; case binary_operation_type::logical_or: if (equal && (types::is_bool_type(*arg1_type) || types::is_integer_type(*arg1_type))) { diff --git a/libs/interpreter/source/eval.cpp b/libs/interpreter/source/eval.cpp index ec4d7c4..7fafe46 100644 --- a/libs/interpreter/source/eval.cpp +++ b/libs/interpreter/source/eval.cpp @@ -164,6 +164,18 @@ namespace pslang::interpreter return true; } + bool is_short_circuiting(ast::binary_operation_type type) + { + switch (type) + { + case ast::binary_operation_type::logical_and: + case ast::binary_operation_type::logical_or: + return true; + default: + return false; + } + } + template value binary_operation_impl_same_type(ast::binary_operation_type type, primitive_value_base const & arg1, value const & arg2_generic, ast::location const & location) { @@ -211,7 +223,7 @@ namespace pslang::interpreter return primitive_value(primitive_value_base{static_cast(arg1.value % arg2.value)}); } break; - case ast::binary_operation_type::logical_and: + case ast::binary_operation_type::binary_and: if constexpr (std::is_same_v) { return primitive_value(primitive_value_base{static_cast(arg1.value && arg2.value)}); @@ -221,7 +233,9 @@ namespace pslang::interpreter return primitive_value(primitive_value_base{static_cast(arg1.value & arg2.value)}); } break; - case ast::binary_operation_type::logical_or: + case ast::binary_operation_type::logical_and: + throw internal_error("logical_and must be handled separately", location); + case ast::binary_operation_type::binary_or: if constexpr (std::is_same_v) { return primitive_value(primitive_value_base{static_cast(arg1.value || arg2.value)}); @@ -231,6 +245,8 @@ namespace pslang::interpreter return primitive_value(primitive_value_base{static_cast(arg1.value | arg2.value)}); } break; + case ast::binary_operation_type::logical_or: + throw internal_error("logical_or must be handled separately", location); case ast::binary_operation_type::logical_xor: if constexpr (std::is_same_v) { @@ -279,26 +295,108 @@ namespace pslang::interpreter return std::visit([&](auto const & value){ return binary_operation_impl_same_type(type, value, arg2, location); }, arg1); } + template + value short_circuiting_impl_same_type(ast::binary_operation_type type, primitive_value_base const & arg1, LazyArg2 const & lazy_arg2, ast::location const & location) + { + switch (type) + { + case ast::binary_operation_type::logical_and: + if constexpr (std::is_same_v) + { + if (!arg1.value) + return primitive_value(primitive_value_base{false}); + + value const & arg2_generic = lazy_arg2(); + primitive_value_base const & arg2 = std::get>(std::get(arg2_generic)); + return primitive_value(primitive_value_base{static_cast(arg1.value && arg2.value)}); + } + else if constexpr (std::is_integral_v) + { + if (arg1.value == T{}) + return primitive_value(primitive_value_base{arg1.value}); + + value const & arg2_generic = lazy_arg2(); + primitive_value_base const & arg2 = std::get>(std::get(arg2_generic)); + return primitive_value(primitive_value_base{static_cast(arg1.value && arg2.value)}); + } + break; + case ast::binary_operation_type::logical_or: + if constexpr (std::is_same_v) + { + if (arg1.value) + return primitive_value(primitive_value_base{true}); + + value const & arg2_generic = lazy_arg2(); + primitive_value_base const & arg2 = std::get>(std::get(arg2_generic)); + return primitive_value(primitive_value_base{static_cast(arg1.value || arg2.value)}); + } + else if constexpr (std::is_integral_v) + { + if (arg1.value == ~T{}) + return primitive_value(primitive_value_base{arg1.value}); + + value const & arg2_generic = lazy_arg2(); + primitive_value_base const & arg2 = std::get>(std::get(arg2_generic)); + return primitive_value(primitive_value_base{static_cast(arg1.value && arg2.value)}); + } + break; + default: + throw internal_error("invalid operator type in short-circuiting branch", location); + } + + std::ostringstream os; + os << "Cannot apply " << type << " to values of type "; + types::print(os, type_of(primitive_value(arg1))); + os << " and "; + types::print(os, type_of(lazy_arg2())); + throw internal_error(os.str(), location); + } + + template + value short_circuiting_impl_same_type(ast::binary_operation_type type, Value const & arg1, LazyArg2 const & lazy_arg2, ast::location const & location) + { + std::ostringstream os; + os << "Cannot apply " << type << " to values of type "; + types::print(os, type_of(arg1)); + os << " and "; + types::print(os, type_of(lazy_arg2())); + throw internal_error(os.str(), location); + } + + template + value short_circuiting_impl_same_type(ast::binary_operation_type type, primitive_value const & arg1, LazyArg2 const & lazy_arg2, ast::location const & location) + { + return std::visit([&](auto const & value){ return short_circuiting_impl_same_type(type, value, lazy_arg2, location); }, arg1); + } + value eval_impl(context & context, ast::binary_operation const & binary_operation) { - auto arg1 = eval_impl(context, binary_operation.arg1); - auto arg2 = eval_impl(context, binary_operation.arg2); + auto type1 = ast::get_type(*binary_operation.arg1); + auto type2 = ast::get_type(*binary_operation.arg2); - if (requires_same_argument_type(binary_operation.type)) + bool const is_same_type = types::equal(*type1, *type2); + + if (requires_same_argument_type(binary_operation.type) && !is_same_type) { - auto type1 = type_of(arg1); - auto type2 = type_of(arg2); + std::ostringstream os; + os << "Cannot apply " << binary_operation.type << " to values of type "; + types::print(os, *type1); + os << " and "; + types::print(os, *type2); + throw internal_error(os.str(), binary_operation.location); + } - if (!types::equal(type1, type2)) - { - std::ostringstream os; - os << "Cannot apply " << binary_operation.type << " to values of type "; - types::print(os, type1); - os << " and "; - types::print(os, type2); - throw internal_error(os.str(), binary_operation.location); - } + auto arg1 = eval_impl(context, binary_operation.arg1); + if (is_short_circuiting(binary_operation.type)) + { + auto lazy_arg2 = [&]{ return eval_impl(context, binary_operation.arg2); }; + return std::visit([&](auto const & value){ return short_circuiting_impl_same_type(binary_operation.type, value, lazy_arg2, binary_operation.location); }, arg1); + } + + if (is_same_type) + { + auto arg2 = eval_impl(context, binary_operation.arg2); return std::visit([&](auto const & value){ return binary_operation_impl_same_type(binary_operation.type, value, arg2, binary_operation.location); }, arg1); } diff --git a/libs/jit/source/arch/aarch64/compiler.cpp b/libs/jit/source/arch/aarch64/compiler.cpp index 92a59bc..8165df4 100644 --- a/libs/jit/source/arch/aarch64/compiler.cpp +++ b/libs/jit/source/arch/aarch64/compiler.cpp @@ -43,6 +43,18 @@ namespace pslang::jit::aarch64 return 3; } + bool is_short_circuiting(ast::binary_operation_type type) + { + switch (type) + { + case ast::binary_operation_type::logical_and: + case ast::binary_operation_type::logical_or: + return true; + default: + return false; + } + } + struct populate_constants_visitor : ast::const_expression_visitor , ast::const_statement_visitor @@ -386,19 +398,22 @@ namespace pslang::jit::aarch64 bool const is_fp = types::is_floating_point_type(*arg1_type); std::uint8_t const fp_mode = fp_mode_for(*arg1_type); - if (is_fp) + apply(*node.arg1); + + if (!is_short_circuiting(node.type)) { - apply(*node.arg1); - push_fp(0, fp_mode); - apply(*node.arg2); - pop_fp(1, fp_mode); - } - else - { - apply(*node.arg1); - push(0); - apply(*node.arg2); - pop(1); + if (is_fp) + { + push_fp(0, fp_mode); + apply(*node.arg2); + pop_fp(1, fp_mode); + } + else + { + push(0); + apply(*node.arg2); + pop(1); + } } switch (node.type) @@ -445,12 +460,37 @@ namespace pslang::jit::aarch64 case ast::binary_operation_type::remainder: // TODO: implement via div & mul & sub throw std::runtime_error("Not implemented"); - case ast::binary_operation_type::logical_and: + case ast::binary_operation_type::binary_and: builder.and_reg(1, 0, 0); break; - case ast::binary_operation_type::logical_or: + case ast::binary_operation_type::logical_and: + { + std::int32_t start = pcontext.code.size(); + builder.cbz(0, 0); + push(0); + apply(*node.arg2); + pop(1); + builder.and_reg(1, 0, 0); + std::int32_t end = pcontext.code.size(); + builder.cb_inject(pcontext.code.data() + start, (end - start) / 4); + } + break; + case ast::binary_operation_type::binary_or: builder.or_reg(1, 0, 0); break; + case ast::binary_operation_type::logical_or: + { + builder.or_not_reg(31, 0, 1); + std::int32_t start = pcontext.code.size(); + builder.cbz(1, 0); + push(0); + apply(*node.arg2); + pop(1); + builder.or_reg(1, 0, 0); + std::int32_t end = pcontext.code.size(); + builder.cb_inject(pcontext.code.data() + start, (end - start) / 4); + } + break; case ast::binary_operation_type::logical_xor: builder.xor_reg(1, 0, 0); break; diff --git a/libs/parser/rules/pslang.l b/libs/parser/rules/pslang.l index b6b3cbc..796fb4a 100644 --- a/libs/parser/rules/pslang.l +++ b/libs/parser/rules/pslang.l @@ -67,7 +67,9 @@ f64 { return bp::make_f64(ctx.location); } "*" { return bp::make_asterisk(ctx.location); } "/" { return bp::make_slash(ctx.location); } "%" { return bp::make_percent(ctx.location); } +"&&" { return bp::make_double_ampersand(ctx.location); } "&" { return bp::make_ampersand(ctx.location); } +"||" { return bp::make_double_vertical_bar(ctx.location); } "|" { return bp::make_vertical_bar(ctx.location); } "^" { return bp::make_circumflex(ctx.location); } "!" { return bp::make_exclamation(ctx.location); } diff --git a/libs/parser/rules/pslang.y b/libs/parser/rules/pslang.y index a4148b6..a384611 100644 --- a/libs/parser/rules/pslang.y +++ b/libs/parser/rules/pslang.y @@ -91,7 +91,9 @@ template %token slash "/" %token percent "%" %token ampersand "&" +%token double_ampersand "&&" %token vertical_bar "|" +%token double_vertical_bar "||" %token circumflex "^" %token exclamation "!" %token equals "==" @@ -150,7 +152,7 @@ template %token end 0 %right arrow -%left ampersand vertical_bar circumflex +%left ampersand double_ampersand vertical_bar double_vertical_bar circumflex %left equals not_equals less greater less_equals greater_equals %nonassoc as %left plus minus @@ -282,8 +284,10 @@ two_or_more_type_list ; expression -: expression ampersand expression { $$ = ast::binary_operation{ast::binary_operation_type::logical_and, std::make_unique($1), std::make_unique($3), @$ }; } -| expression vertical_bar expression { $$ = ast::binary_operation{ast::binary_operation_type::logical_or, std::make_unique($1), std::make_unique($3), @$ }; } +: expression ampersand expression { $$ = ast::binary_operation{ast::binary_operation_type::binary_and, std::make_unique($1), std::make_unique($3), @$ }; } +| expression double_ampersand expression { $$ = ast::binary_operation{ast::binary_operation_type::logical_and, std::make_unique($1), std::make_unique($3), @$ }; } +| expression vertical_bar expression { $$ = ast::binary_operation{ast::binary_operation_type::binary_or, std::make_unique($1), std::make_unique($3), @$ }; } +| expression double_vertical_bar expression { $$ = ast::binary_operation{ast::binary_operation_type::logical_or, std::make_unique($1), std::make_unique($3), @$ }; } | expression circumflex expression { $$ = ast::binary_operation{ast::binary_operation_type::logical_xor, std::make_unique($1), std::make_unique($3), @$ }; } | expression equals expression { $$ = ast::binary_operation{ast::binary_operation_type::equals, std::make_unique($1), std::make_unique($3), @$ }; } | expression not_equals expression { $$ = ast::binary_operation{ast::binary_operation_type::not_equals, std::make_unique($1), std::make_unique($3), @$ }; }