From 06d509bcab5ed0b67e113f2bf7970ca747d5802d Mon Sep 17 00:00:00 2001 From: lisyarus Date: Thu, 2 Apr 2026 17:10:00 +0300 Subject: [PATCH] Add break & continue statements --- examples/raytracer.psl | 204 +++++++++++----------- libs/ast/include/pslang/ast/control.hpp | 10 ++ libs/ast/include/pslang/ast/statement.hpp | 4 + libs/ast/source/print.cpp | 12 ++ libs/ast/source/resolve_identifiers.cpp | 12 ++ libs/ast/source/type_check.cpp | 12 ++ libs/ast/source/validate.cpp | 6 + libs/interpreter/source/exec.cpp | 10 ++ libs/ir/source/compiler.cpp | 43 ++++- libs/parser/rules/pslang.l | 2 + libs/parser/rules/pslang.y | 4 + libs/parser/source/finalize.cpp | 26 +++ spec.txt | 5 +- 13 files changed, 246 insertions(+), 104 deletions(-) diff --git a/examples/raytracer.psl b/examples/raytracer.psl index e05c491..4ae9416 100644 --- a/examples/raytracer.psl +++ b/examples/raytracer.psl @@ -573,78 +573,76 @@ func raytrace(scene: scene*, camera_ray: ray, rng: rng mut*) -> vec3: mut factor = vec3(1.0, 1.0, 1.0) let termination_probability = 0.25 - mut running = true - while running: + while true: let intersection = intersect_scene(scene, current_ray) - if intersection.intersected: - // Uncomment to debug normals - // return mults(add(intersection.normal, vec3(1.0, 1.0, 1.0)), 0.5) - - let cosine = - dot(intersection.normal, current_ray.direction) - let inside = cosine < 0.0 - - result = add(result, multv(factor, (*intersection.material).emission)) - - // Russian roulette ray termination - if next_f32(rng) < termination_probability: - // No break operator yet... - running = false - else: - mut new_direction = vec3(0.0, 0.0, 0.0) - - if (*intersection.material).type == diffuse_tag: - // NB: albedo is assumed to be premultiplied by pi to be in [0..1] range - // This should also contain multiplication by cos(new_dir, normal), division by direction pdf (cos / pi) - // and division by pi (because of albedo normalization), but these all cancel out - factor = multv(factor, (*intersection.material).color) - - // Cosine-weighted hemisphere direction - new_direction = normalized(add(next_vec3(rng), intersection.normal)) - - else if (*intersection.material).type == metallic_tag: - // This should also contain multiplication by brdf and division by direction pdf, - // but we'll just pretend that the random reflected ray pdf coincides with brdf and thus cancels out - factor = multv(factor, (*intersection.material).color) - - // Compute perfect-mirror reflected direction - new_direction = add(current_ray.direction, mults(intersection.normal, 2.0 * cosine)) - - // Alter the direction based on roughness - new_direction = normalized(add(new_direction, mults(next_vec3_normal(rng), cosine * (*intersection.material).roughness))) - - else if (*intersection.material).type == glass_tag: - // This should also contain multiplication by brdf and division by direction pdf, - // but we'll just pretend that the random refracted ray pdf coincides with brdf and thus cancels out - factor = multv(factor, (*intersection.material).color) - - mut ior = (*intersection.material).ior - if inside: - ior = 1.0 / ior - - // Compute perfect refracted ray - let k = 1.0 - ior * ior * (1.0 - cosine * cosine) - if k >= 0.0: - new_direction = add(mults(current_ray.direction, ior), mults(intersection.normal, ior * abs(cosine) - sqrt(k))) - else: - // Total internal reflection - new_direction = add(current_ray.direction, mults(intersection.normal, 2.0 * cosine)) - - // Alter the direction based on roughness - new_direction = normalized(add(new_direction, mults(next_vec3_normal(rng), (*intersection.material).roughness))) - - // Compute the new ray, and offset its origin a bit along intersection normal - let position = add(current_ray.origin, mults(current_ray.direction, intersection.distance)) - mut position_shift = 0.001 - if dot(new_direction, intersection.normal) < 0.0: - position_shift = -position_shift - current_ray = ray(add(position, mults(intersection.normal, position_shift)), new_direction) - - // Account for Russian roulette - factor = mults(factor, 1.0 / (1.0 - termination_probability)) - else: + if !intersection.intersected: result = add(result, multv(factor, (*scene).background)) - running = false + break + + // Uncomment to debug normals + // return mults(add(intersection.normal, vec3(1.0, 1.0, 1.0)), 0.5) + + let cosine = - dot(intersection.normal, current_ray.direction) + let inside = cosine < 0.0 + + result = add(result, multv(factor, (*intersection.material).emission)) + + // Russian roulette ray termination + if next_f32(rng) < termination_probability: + break + + mut new_direction = vec3(0.0, 0.0, 0.0) + + if (*intersection.material).type == diffuse_tag: + // NB: albedo is assumed to be premultiplied by pi to be in [0..1] range + // This should also contain multiplication by cos(new_dir, normal), division by direction pdf (cos / pi) + // and division by pi (because of albedo normalization), but these all cancel out + factor = multv(factor, (*intersection.material).color) + + // Cosine-weighted hemisphere direction + new_direction = normalized(add(next_vec3(rng), intersection.normal)) + + else if (*intersection.material).type == metallic_tag: + // This should also contain multiplication by brdf and division by direction pdf, + // but we'll just pretend that the random reflected ray pdf coincides with brdf and thus cancels out + factor = multv(factor, (*intersection.material).color) + + // Compute perfect-mirror reflected direction + new_direction = add(current_ray.direction, mults(intersection.normal, 2.0 * cosine)) + + // Alter the direction based on roughness + new_direction = normalized(add(new_direction, mults(next_vec3_normal(rng), cosine * (*intersection.material).roughness))) + + else if (*intersection.material).type == glass_tag: + // This should also contain multiplication by brdf and division by direction pdf, + // but we'll just pretend that the random refracted ray pdf coincides with brdf and thus cancels out + factor = multv(factor, (*intersection.material).color) + + mut ior = (*intersection.material).ior + if inside: + ior = 1.0 / ior + + // Compute perfect refracted ray + let k = 1.0 - ior * ior * (1.0 - cosine * cosine) + if k >= 0.0: + new_direction = add(mults(current_ray.direction, ior), mults(intersection.normal, ior * abs(cosine) - sqrt(k))) + else: + // Total internal reflection + new_direction = add(current_ray.direction, mults(intersection.normal, 2.0 * cosine)) + + // Alter the direction based on roughness + new_direction = normalized(add(new_direction, mults(next_vec3_normal(rng), (*intersection.material).roughness))) + + // Compute the new ray, and offset its origin a bit along intersection normal + let position = add(current_ray.origin, mults(current_ray.direction, intersection.distance)) + mut position_shift = 0.001 + if dot(new_direction, intersection.normal) < 0.0: + position_shift = -position_shift + current_ray = ray(add(position, mults(intersection.normal, position_shift)), new_direction) + + // Account for Russian roulette + factor = mults(factor, 1.0 / (1.0 - termination_probability)) return result @@ -777,8 +775,8 @@ const default_fovy = 2.0 * atan(0.5) func main(): let scene = make_default_scene() - let image = create_image(512ul, 512ul) - // let image = create_image(256ul, 256ul) + // let image = create_image(512ul, 512ul) + let image = create_image(256ul, 256ul) // let image = create_image(128ul, 128ul) let aspect_ratio = (image.width as f32) / (image.height as f32) @@ -865,40 +863,44 @@ func main(): sleep(timespec(0l, 125000000l)) let time = get_time() let delta = time_delta(time, last_report_time) - if delta > 0.125 || last_report_progress == 0.0: - done = 0ul - - mut th = 0ul - while th < thread_count: - lock_mutex(threads_data[th].done_mutex) - done += threads_data[th].done - unlock_mutex(threads_data[th].done_mutex) - th += 1ul + if !(delta > 0.125 || last_report_progress == 0.0): + continue - if done > 0ul: - let progress = (done as f32) / (total as f32) + done = 0ul + + mut th = 0ul + while th < thread_count: + lock_mutex(threads_data[th].done_mutex) + done += threads_data[th].done + unlock_mutex(threads_data[th].done_mutex) + th += 1ul - // Running exponential-weighted average speed - // for +/- accurate remaining time estimate - let speed_estimate = (progress - last_report_progress) / delta - average_speed = average_speed + (speed_estimate - average_speed) / (min_u32(report_count + 1u, 16u) as f32) - let time_remaining = (1.0 - progress) / average_speed + if done == 0ul: + continue - let str1 = ['%', ' ', 'd', 'o', 'n', 'e', ',', ' ', '\0'] - let str2 = [' ', 'p', 'a', 's', 's', 'e', 'd', ',', ' ', '\0'] - let str3 = [' ', 'l', 'e', 'f', 't', ' ', '\0'] + let progress = (done as f32) / (total as f32) - clear_line(50u) - print_f32(100.0 * progress) - print_str(str1 as u8*) - print_time(time_delta(time, start_time)) - print_str(str2 as u8*) - print_time(time_remaining) - print_str(str3 as u8*) - flush() - last_report_time = time - last_report_progress = progress - report_count += 1u + // Running exponential-weighted average speed + // for +/- accurate remaining time estimate + let speed_estimate = (progress - last_report_progress) / delta + average_speed = average_speed + (speed_estimate - average_speed) / (min_u32(report_count + 1u, 16u) as f32) + let time_remaining = (1.0 - progress) / average_speed + + let str1 = ['%', ' ', 'd', 'o', 'n', 'e', ',', ' ', '\0'] + let str2 = [' ', 'p', 'a', 's', 's', 'e', 'd', ',', ' ', '\0'] + let str3 = [' ', 'l', 'e', 'f', 't', ' ', '\0'] + + clear_line(50u) + print_f32(100.0 * progress) + print_str(str1 as u8*) + print_time(time_delta(time, start_time)) + print_str(str2 as u8*) + print_time(time_remaining) + print_str(str3 as u8*) + flush() + last_report_time = time + last_report_progress = progress + report_count += 1u // No for loops yet... th = 0ul diff --git a/libs/ast/include/pslang/ast/control.hpp b/libs/ast/include/pslang/ast/control.hpp index b460880..b167bdf 100644 --- a/libs/ast/include/pslang/ast/control.hpp +++ b/libs/ast/include/pslang/ast/control.hpp @@ -49,4 +49,14 @@ namespace pslang::ast location location; }; + struct break_statement + { + location location; + }; + + struct continue_statement + { + location location; + }; + } diff --git a/libs/ast/include/pslang/ast/statement.hpp b/libs/ast/include/pslang/ast/statement.hpp index 1bdad3e..224a896 100644 --- a/libs/ast/include/pslang/ast/statement.hpp +++ b/libs/ast/include/pslang/ast/statement.hpp @@ -28,6 +28,8 @@ namespace pslang::ast else_block, else_if_block, while_block, + break_statement, + continue_statement, function_definition, foreign_function_declaration, return_statement, @@ -47,6 +49,8 @@ namespace pslang::ast variable_declaration, if_chain, while_block, + break_statement, + continue_statement, function_definition, foreign_function_declaration, return_statement, diff --git a/libs/ast/source/print.cpp b/libs/ast/source/print.cpp index 8207f44..86713ef 100644 --- a/libs/ast/source/print.cpp +++ b/libs/ast/source/print.cpp @@ -434,6 +434,18 @@ namespace pslang::ast child(*node.statements); } + void apply(break_statement const & node) + { + put_indent(out, options); + out << "break\n"; + } + + void apply(continue_statement const & node) + { + put_indent(out, options); + out << "continue\n"; + } + void apply(function_definition const & node) { put_indent(out, options); diff --git a/libs/ast/source/resolve_identifiers.cpp b/libs/ast/source/resolve_identifiers.cpp index 4e96102..661e251 100644 --- a/libs/ast/source/resolve_identifiers.cpp +++ b/libs/ast/source/resolve_identifiers.cpp @@ -84,6 +84,12 @@ namespace pslang::ast void apply(while_block const &) {} + void apply(break_statement const &) + {} + + void apply(continue_statement const &) + {} + void apply(function_definition & function_definition) { if (scopes.back().contains(function_definition.name)) @@ -321,6 +327,12 @@ namespace pslang::ast scopes.pop_back(); } + void apply(break_statement const & break_statement) + {} + + void apply(continue_statement const & continue_statement) + {} + void apply(function_definition & function_definition) { // Already added to scope by populate_globals_visitor diff --git a/libs/ast/source/type_check.cpp b/libs/ast/source/type_check.cpp index f3ed367..873c476 100644 --- a/libs/ast/source/type_check.cpp +++ b/libs/ast/source/type_check.cpp @@ -180,6 +180,12 @@ namespace pslang::ast void apply(while_block const &) {} + void apply(break_statement const &) + {} + + void apply(continue_statement const &) + {} + void apply(function_definition & node) { types::function_type function_type; @@ -866,6 +872,12 @@ namespace pslang::ast apply(*node.statements); } + void apply(break_statement const &) + {} + + void apply(continue_statement const &) + {} + void apply(function_definition & node) { apply(*node.statements); diff --git a/libs/ast/source/validate.cpp b/libs/ast/source/validate.cpp index 90b8675..bd56ef3 100644 --- a/libs/ast/source/validate.cpp +++ b/libs/ast/source/validate.cpp @@ -31,6 +31,12 @@ namespace pslang::ast apply(*node.statements); } + void apply(break_statement const &) + {} + + void apply(continue_statement const &) + {} + void apply(function_definition const & node) { apply(*node.statements); diff --git a/libs/interpreter/source/exec.cpp b/libs/interpreter/source/exec.cpp index c35088a..b37cfd8 100644 --- a/libs/interpreter/source/exec.cpp +++ b/libs/interpreter/source/exec.cpp @@ -139,6 +139,16 @@ namespace pslang::interpreter } } + void exec_impl(context & context, ast::break_statement const &) + { + throw std::runtime_error("Not implemented"); + } + + void exec_impl(context & context, ast::continue_statement const &) + { + throw std::runtime_error("Not implemented"); + } + void exec_impl(context & context, ast::function_definition const & function_definition) { auto & frame = context.frame_stack.back(); diff --git a/libs/ir/source/compiler.cpp b/libs/ir/source/compiler.cpp index e016e7e..c367b50 100644 --- a/libs/ir/source/compiler.cpp +++ b/libs/ir/source/compiler.cpp @@ -24,6 +24,15 @@ namespace pslang::ir std::vector scopes; + struct loop_scope + { + // These must point to jump nodes + std::vector resolve_break; + std::vector resolve_continue; + }; + + std::vector loop_scopes; + struct resolve_address_data { node_ref address; @@ -655,15 +664,39 @@ namespace pslang::ir auto begin = last(); auto condition = apply(*node.condition); mcontext.nodes->emplace_back(jump_if_zero{condition, {}}); - auto jump1 = last(); + auto jump_to_end = last(); + lcontext.loop_scopes.emplace_back(); lcontext.scopes.emplace_back(); apply(*node.statements); lcontext.scopes.pop_back(); mcontext.nodes->emplace_back(jump{begin}); mcontext.nodes->emplace_back(label{}); - std::get(jump1->instruction).target = last(); + + auto end = last(); + std::get(jump_to_end->instruction).target = end; + + for (auto node : lcontext.loop_scopes.back().resolve_continue) + std::get(node->instruction).target = begin; + for (auto node : lcontext.loop_scopes.back().resolve_break) + std::get(node->instruction).target = end; + lcontext.loop_scopes.pop_back(); + + return last(); + } + + node_ref apply(ast::break_statement const &) + { + mcontext.nodes->emplace_back(jump{}); + lcontext.loop_scopes.back().resolve_break.push_back(last()); + return last(); + } + + node_ref apply(ast::continue_statement const &) + { + mcontext.nodes->emplace_back(jump{}); + lcontext.loop_scopes.back().resolve_continue.push_back(last()); return last(); } @@ -759,6 +792,12 @@ namespace pslang::ir lcontext.scopes.pop_back(); } + void apply(ast::break_statement const &) + {} + + void apply(ast::continue_statement const &) + {} + void apply(ast::function_definition const & node) { compile_function_visitor{{}, {}, mcontext, lcontext}.do_apply(node); diff --git a/libs/parser/rules/pslang.l b/libs/parser/rules/pslang.l index 474c597..fe37b80 100644 --- a/libs/parser/rules/pslang.l +++ b/libs/parser/rules/pslang.l @@ -29,6 +29,8 @@ mut { return bp::make_mut(ctx.location); } if { return bp::make_if(ctx.location); } else { return bp::make_else(ctx.location); } while { return bp::make_while(ctx.location); } +break { return bp::make_break(ctx.location); } +continue { return bp::make_continue(ctx.location); } as { return bp::make_as(ctx.location); } func { return bp::make_func(ctx.location); } foreign { return bp::make_foreign(ctx.location); } diff --git a/libs/parser/rules/pslang.y b/libs/parser/rules/pslang.y index dd06d2c..aebf5d7 100644 --- a/libs/parser/rules/pslang.y +++ b/libs/parser/rules/pslang.y @@ -140,6 +140,8 @@ template %token if %token else %token while +%token break +%token continue %token as %token func %token foreign @@ -247,6 +249,8 @@ statement | else colon { $$ = ast::else_block{@$}; } | else if expression colon { $$ = ast::else_if_block{std::make_unique($3), @$}; } | while expression colon { $$ = ast::while_block{std::make_unique($2), {}, @$, @$}; } +| break { $$ = ast::break_statement{@$}; } +| continue { $$ = ast::continue_statement{@$}; } | func name lparen function_declaration_argument_list rparen function_return_type colon { $$ = ast::function_definition{{$2, $4, $6, @$}, {}}; } | foreign func name lparen function_declaration_argument_list rparen function_return_type { $$ = ast::foreign_function_declaration{{$3, $5, $7, @$}}; } | return expression { $$ = ast::return_statement{std::make_unique($2), @$}; } diff --git a/libs/parser/source/finalize.cpp b/libs/parser/source/finalize.cpp index c7e4a53..1e9ace4 100644 --- a/libs/parser/source/finalize.cpp +++ b/libs/parser/source/finalize.cpp @@ -51,6 +51,16 @@ namespace pslang::parser return node.location = ast::merge(node.prelude_location, apply(*node.statements)); } + ast::location apply(ast::break_statement & node) + { + return node.location; + } + + ast::location apply(ast::continue_statement & node) + { + return node.location; + } + ast::location apply(ast::function_definition & node) { return node.location = ast::merge(node.prelude_location, apply(*node.statements)); @@ -103,6 +113,7 @@ namespace pslang::parser stack.push_back(result.get()); std::size_t current_indent = 0; std::vector function_stack; + std::vector loop_stack; auto current_statement_list = [&](ast::location const & location) -> ast::statement_list * { @@ -139,6 +150,8 @@ namespace pslang::parser throw ast::invalid_ast_error("Unexpected empty indent stack", ast::get_location(*statement.statement)); if (!function_stack.empty() && std::holds_alternative(stack.back()) && function_stack.back()->statements.get() == std::get(stack.back())) function_stack.pop_back(); + if (!loop_stack.empty() && std::holds_alternative(stack.back()) && loop_stack.back() == std::get(stack.back())) + loop_stack.pop_back(); stack.pop_back(); --current_indent; } @@ -182,6 +195,7 @@ namespace pslang::parser while_block->statements = std::make_unique(); list = while_block->statements.get(); current_statement_list(location)->statements.push_back(std::make_unique(std::move(*while_block))); + loop_stack.push_back(list); } else if (auto function_definition = std::get_if(statement.statement.get())) { @@ -229,6 +243,18 @@ namespace pslang::parser { current_statement_list(location)->statements.push_back(std::make_unique(std::move(*foreign_function_declaration))); } + else if (auto break_statement = std::get_if(statement.statement.get())) + { + if (loop_stack.empty()) + throw parse_error("Break without an enclosing loop", break_statement->location); + current_statement_list(location)->statements.push_back(std::make_unique(std::move(*break_statement))); + } + else if (auto continue_statement = std::get_if(statement.statement.get())) + { + if (loop_stack.empty()) + throw parse_error("Continue without an enclosing loop", continue_statement->location); + current_statement_list(location)->statements.push_back(std::make_unique(std::move(*continue_statement))); + } else { throw ast::invalid_ast_error(std::format("Unknown pre-statement \"{}\"", std::visit([](auto const & statement){ return typeid(statement).name(); }, *statement.statement)), location); diff --git a/spec.txt b/spec.txt index ae1c1e1..2157717 100644 --- a/spec.txt +++ b/spec.txt @@ -148,8 +148,11 @@ Flow control: while condition: statements + if x: + break + if y: + continue - TODO: break/continue? TODO: for loops? iterator/range interface? ======== STRUCTS ========