From 8ba5345daf5d2e8db48d7f8586df7eb1af543f56 Mon Sep 17 00:00:00 2001 From: lisyarus Date: Thu, 2 Apr 2026 20:19:55 +0300 Subject: [PATCH] Implement ternary ifs --- examples/raytracer.psl | 38 +++++++++------------- libs/ast/include/pslang/ast/control.hpp | 10 ++++++ libs/ast/include/pslang/ast/expression.hpp | 4 ++- libs/ast/source/expression.cpp | 10 ++++++ libs/ast/source/print.cpp | 9 +++++ libs/ast/source/resolve_identifiers.cpp | 7 ++++ libs/ast/source/type_check.cpp | 31 ++++++++++++++++++ libs/interpreter/source/eval.cpp | 10 ++++++ libs/ir/source/compiler.cpp | 23 ++++++++++++- libs/jit/source/arch/aarch64/compiler.cpp | 2 +- libs/parser/CMakeLists.txt | 2 +- libs/parser/rules/pslang.l | 1 + libs/parser/rules/pslang.y | 3 ++ 13 files changed, 123 insertions(+), 27 deletions(-) diff --git a/examples/raytracer.psl b/examples/raytracer.psl index 09dcb53..433a09b 100644 --- a/examples/raytracer.psl +++ b/examples/raytracer.psl @@ -478,17 +478,16 @@ func intersect_sphere(ray: ray, object: object*) -> intersection: let t2 = (- b + sqrt(D)) / (2.0 * a) if t2 < 0.0: return no_intersection - mut t = t1 - if t1 < 0.0: - t = t2 + let t = if t1 < 0.0 then t2 else t1 let normal = normalized(add(delta, mults(ray.direction, t))) return intersection(true, t, normal, &(*object).material) func intersect_box(ray: ray, object: object*) -> intersection: - func swap(x: f32 mut*, y: f32 mut*): - let temp = *x - *x = *y - *y = temp + func sort(x: f32 mut*, y: f32 mut*): + if *x > *y: + let temp = *x + *x = *y + *y = temp let shape = &(*object).shape.data as box* let inverse_rotation = inverse((*object).rotation) @@ -503,12 +502,9 @@ func intersect_box(ray: ray, object: object*) -> intersection: mut tzmin = (- (*shape).extent.z - local_delta.z) / local_direction.z mut tzmax = ( (*shape).extent.z - local_delta.z) / local_direction.z - if txmin > txmax: - swap(&mut txmin, &mut txmax) - if tymin > tymax: - swap(&mut tymin, &mut tymax) - if tzmin > tzmax: - swap(&mut tzmin, &mut tzmax) + sort(&mut txmin, &mut txmax) + sort(&mut tymin, &mut tymax) + sort(&mut tzmin, &mut tzmax) let tmin = max(txmin, max(tymin, tzmin)) let tmax = min(txmax, min(tymax, tzmax)) @@ -517,14 +513,10 @@ func intersect_box(ray: ray, object: object*) -> intersection: return no_intersection let inside = tmin < 0.0 - mut t = tmin - mut tt = vec3(txmin, tymin, tzmin) - // No ternary if yet... - if inside: - t = tmax - tt = vec3(txmax, tymax, tzmax) + let t = if inside then tmax else tmin + let tt = if inside then vec3(txmax, tymax, tzmax) else vec3(txmin, tymin, tzmin) - mut normal = vec3(0.0, 0.0, 0.0) + mut normal = vec3() if t == tt.x: if local_direction.x > 0.0: @@ -545,7 +537,7 @@ func intersect_box(ray: ray, object: object*) -> intersection: if inside: normal = mults(normal, -1.0) - return intersection(true, tmin, rotate((*object).rotation, normal), &(*object).material) + return intersection(true, t, rotate((*object).rotation, normal), &(*object).material) func intersect_scene(scene: scene*, ray: ray) -> intersection: mut intersection = no_intersection @@ -776,8 +768,8 @@ func main(): let scene = make_default_scene() // let image = create_image(512ul, 512ul) - let image = create_image(256ul, 256ul) - // let image = create_image(128ul, 128ul) + // let image = create_image(256ul, 256ul) + let image = create_image(128ul, 128ul) let aspect_ratio = (image.width as f32) / (image.height as f32) let camera = camera(vec3(0.0, 0.0, 15.0), identity, default_fovy, compute_fovx(default_fovy, aspect_ratio)) diff --git a/libs/ast/include/pslang/ast/control.hpp b/libs/ast/include/pslang/ast/control.hpp index b167bdf..8583789 100644 --- a/libs/ast/include/pslang/ast/control.hpp +++ b/libs/ast/include/pslang/ast/control.hpp @@ -3,10 +3,20 @@ #include #include #include +#include namespace pslang::ast { + struct if_expression + { + expression_ptr condition; + expression_ptr if_true; + expression_ptr if_false; + ast::location location; + types::type_ptr inferred_type = nullptr; + }; + struct if_block { expression_ptr condition; diff --git a/libs/ast/include/pslang/ast/expression.hpp b/libs/ast/include/pslang/ast/expression.hpp index 07af6fc..b4081b4 100644 --- a/libs/ast/include/pslang/ast/expression.hpp +++ b/libs/ast/include/pslang/ast/expression.hpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -40,7 +41,8 @@ namespace pslang::ast function_call, array, array_access, - field_access + field_access, + if_expression >; struct expression diff --git a/libs/ast/source/expression.cpp b/libs/ast/source/expression.cpp index 990d868..a1ad6ec 100644 --- a/libs/ast/source/expression.cpp +++ b/libs/ast/source/expression.cpp @@ -58,6 +58,11 @@ namespace pslang::ast { return node.location; } + + location apply(if_expression const & node) + { + return node.location; + } }; struct get_type_visitor @@ -110,6 +115,11 @@ namespace pslang::ast { return node.inferred_type; } + + types::type_ptr apply(if_expression const & node) + { + return node.inferred_type; + } }; } diff --git a/libs/ast/source/print.cpp b/libs/ast/source/print.cpp index 86713ef..df198ac 100644 --- a/libs/ast/source/print.cpp +++ b/libs/ast/source/print.cpp @@ -358,6 +358,15 @@ namespace pslang::ast out << "field access { name = \"" << node.field_name << "\" }\n"; child(*node.object); } + + void apply(if_expression const & node) + { + put_indent(out, options); + out << "if\n"; + child(*node.condition); + child(*node.if_true); + child(*node.if_false); + } }; struct statement_print_visitor diff --git a/libs/ast/source/resolve_identifiers.cpp b/libs/ast/source/resolve_identifiers.cpp index 661e251..3844949 100644 --- a/libs/ast/source/resolve_identifiers.cpp +++ b/libs/ast/source/resolve_identifiers.cpp @@ -281,6 +281,13 @@ namespace pslang::ast apply(*field_access.object); } + void apply(if_expression const & if_expression) + { + apply(*if_expression.condition); + apply(*if_expression.if_true); + apply(*if_expression.if_false); + } + void apply(expression_ptr const & expression_ptr) { apply(*expression_ptr); diff --git a/libs/ast/source/type_check.cpp b/libs/ast/source/type_check.cpp index 873c476..db45260 100644 --- a/libs/ast/source/type_check.cpp +++ b/libs/ast/source/type_check.cpp @@ -783,6 +783,37 @@ namespace pslang::ast throw type_error(std::format("Struct \"{}\" has no field named \"{}\"", struct_node.name, node.field_name), node.location); } + void apply(if_expression & node) + { + apply(*node.condition); + apply(*node.if_true); + apply(*node.if_false); + + auto condition_type = get_type(*node.condition); + if (!types::is_bool_type(*condition_type)) + { + std::ostringstream os; + os << "if condition expects a bool type, but got "; + print(os, *condition_type); + throw type_error(os.str(), get_location(*node.condition)); + } + + auto true_type = get_type(*node.if_true); + auto false_type = get_type(*node.if_false); + + if (!types::equal(*true_type, *false_type)) + { + std::ostringstream os; + os << "Both ternary if cases must have the same type, but got "; + print(os, *true_type); + os << " and "; + print(os, *false_type); + throw type_error(os.str(), node.location); + } + + node.inferred_type = true_type; + } + void apply(expression_ptr const & node) { apply(*node); diff --git a/libs/interpreter/source/eval.cpp b/libs/interpreter/source/eval.cpp index af41637..f2e55d3 100644 --- a/libs/interpreter/source/eval.cpp +++ b/libs/interpreter/source/eval.cpp @@ -686,6 +686,11 @@ namespace pslang::interpreter throw internal_error(os.str(), field_access.location); } + value eval_impl(context & context, ast::if_expression const & if_expression) + { + throw std::runtime_error("Not implemented"); + } + value eval_impl(context & context, ast::expression_ptr const & expression) { return std::visit([&](auto const & expression){ return eval_impl(context, expression); }, *expression); @@ -774,6 +779,11 @@ namespace pslang::interpreter throw internal_error(os.str(), field_access.location); } + value * eval_ref_impl(context & context, ast::if_expression const & if_expression) + { + throw std::runtime_error("Not implemented"); + } + value * eval_ref_impl(context & context, ast::expression_ptr const & expression) { return std::visit([&](auto const & expression){ return eval_ref_impl(context, expression); }, *expression); diff --git a/libs/ir/source/compiler.cpp b/libs/ir/source/compiler.cpp index 9f2e34b..98df31e 100644 --- a/libs/ir/source/compiler.cpp +++ b/libs/ir/source/compiler.cpp @@ -473,7 +473,28 @@ namespace pslang::ir throw std::runtime_error("Unknown field name"); } - // TODO: array, array_access, field_access + node_ref apply(ast::if_expression const & node) + { + mcontext.nodes->emplace_back(alloc{}, node.inferred_type); + auto result = last(); + auto condition = apply(*node.condition); + mcontext.nodes->emplace_back(jump_if_zero{condition}); + auto jump_skip_if_true = last(); + auto true_result = apply(*node.if_true); + mcontext.nodes->emplace_back(assignment{result, true_result}, node.inferred_type); + mcontext.nodes->emplace_back(jump{}); + auto jump_end_if_true = last(); + mcontext.nodes->emplace_back(label{}); + auto after_if_true = last(); + auto false_result = apply(*node.if_false); + mcontext.nodes->emplace_back(assignment{result, false_result}, node.inferred_type); + mcontext.nodes->emplace_back(label{}); + auto end = last(); + + std::get(jump_skip_if_true->instruction).target = after_if_true; + std::get(jump_end_if_true->instruction).target = end; + return result; + } // Statements diff --git a/libs/jit/source/arch/aarch64/compiler.cpp b/libs/jit/source/arch/aarch64/compiler.cpp index 0d3602b..0d61320 100644 --- a/libs/jit/source/arch/aarch64/compiler.cpp +++ b/libs/jit/source/arch/aarch64/compiler.cpp @@ -312,7 +312,7 @@ namespace pslang::jit::aarch64 void apply(ir::node_ref it, ir::alloc const & node, types::type_ptr const &) { - // Nothing to do: alloc just allocates a node of a struct type, + // Nothing to do: alloc just allocates a node of some type, // but we already allocated stack space for it } diff --git a/libs/parser/CMakeLists.txt b/libs/parser/CMakeLists.txt index 4649615..d9c803f 100644 --- a/libs/parser/CMakeLists.txt +++ b/libs/parser/CMakeLists.txt @@ -24,7 +24,7 @@ bison_target( ${PSLANG_PARSER_RULES_FILE} ${PSLANG_PARSER_SOURCE_FILE} DEFINES_FILE ${PSLANG_PARSER_HEADER_FILE} - # COMPILE_FLAGS -Wcounterexamples + COMPILE_FLAGS -Wcounterexamples ) add_flex_bison_dependency(generate-pslang-lexer generate-pslang-parser) diff --git a/libs/parser/rules/pslang.l b/libs/parser/rules/pslang.l index fe37b80..ad99a32 100644 --- a/libs/parser/rules/pslang.l +++ b/libs/parser/rules/pslang.l @@ -28,6 +28,7 @@ let { return bp::make_let(ctx.location); } mut { return bp::make_mut(ctx.location); } if { return bp::make_if(ctx.location); } else { return bp::make_else(ctx.location); } +then { return bp::make_then(ctx.location); } while { return bp::make_while(ctx.location); } break { return bp::make_break(ctx.location); } continue { return bp::make_continue(ctx.location); } diff --git a/libs/parser/rules/pslang.y b/libs/parser/rules/pslang.y index aebf5d7..8c77327 100644 --- a/libs/parser/rules/pslang.y +++ b/libs/parser/rules/pslang.y @@ -139,6 +139,7 @@ template %token mut %token if %token else +%token then %token while %token break %token continue @@ -177,6 +178,7 @@ template %left asterisk slash percent %precedence NOT %precedence ADDRESSOF DEREFERENCE +%precedence else %precedence lbracket %type indented_statement_list @@ -352,6 +354,7 @@ expression | ampersand mut expression %prec ADDRESSOF { $$ = ast::unary_operation{ast::unary_operation_type::mutable_address_of, std::make_unique($3), @$ }; } | asterisk expression %prec DEREFERENCE { $$ = ast::unary_operation{ast::unary_operation_type::dereference, std::make_unique($2), @$ }; } | postfix_expression { $$ = $1; } +| if expression then expression else expression { $$ = ast::if_expression{ std::make_unique($2), std::make_unique($4), std::make_unique($6), @$ }; } ; postfix_expression