Add break & continue statements

This commit is contained in:
Nikita Lisitsa 2026-04-02 17:10:00 +03:00
parent 356eea4fb8
commit 06d509bcab
13 changed files with 246 additions and 104 deletions

View file

@ -573,78 +573,76 @@ func raytrace(scene: scene*, camera_ray: ray, rng: rng mut*) -> vec3:
mut factor = vec3(1.0, 1.0, 1.0)
let termination_probability = 0.25
mut running = true
while running:
while true:
let intersection = intersect_scene(scene, current_ray)
if intersection.intersected:
// Uncomment to debug normals
// return mults(add(intersection.normal, vec3(1.0, 1.0, 1.0)), 0.5)
let cosine = - dot(intersection.normal, current_ray.direction)
let inside = cosine < 0.0
result = add(result, multv(factor, (*intersection.material).emission))
// Russian roulette ray termination
if next_f32(rng) < termination_probability:
// No break operator yet...
running = false
else:
mut new_direction = vec3(0.0, 0.0, 0.0)
if (*intersection.material).type == diffuse_tag:
// NB: albedo is assumed to be premultiplied by pi to be in [0..1] range
// This should also contain multiplication by cos(new_dir, normal), division by direction pdf (cos / pi)
// and division by pi (because of albedo normalization), but these all cancel out
factor = multv(factor, (*intersection.material).color)
// Cosine-weighted hemisphere direction
new_direction = normalized(add(next_vec3(rng), intersection.normal))
else if (*intersection.material).type == metallic_tag:
// This should also contain multiplication by brdf and division by direction pdf,
// but we'll just pretend that the random reflected ray pdf coincides with brdf and thus cancels out
factor = multv(factor, (*intersection.material).color)
// Compute perfect-mirror reflected direction
new_direction = add(current_ray.direction, mults(intersection.normal, 2.0 * cosine))
// Alter the direction based on roughness
new_direction = normalized(add(new_direction, mults(next_vec3_normal(rng), cosine * (*intersection.material).roughness)))
else if (*intersection.material).type == glass_tag:
// This should also contain multiplication by brdf and division by direction pdf,
// but we'll just pretend that the random refracted ray pdf coincides with brdf and thus cancels out
factor = multv(factor, (*intersection.material).color)
mut ior = (*intersection.material).ior
if inside:
ior = 1.0 / ior
// Compute perfect refracted ray
let k = 1.0 - ior * ior * (1.0 - cosine * cosine)
if k >= 0.0:
new_direction = add(mults(current_ray.direction, ior), mults(intersection.normal, ior * abs(cosine) - sqrt(k)))
else:
// Total internal reflection
new_direction = add(current_ray.direction, mults(intersection.normal, 2.0 * cosine))
// Alter the direction based on roughness
new_direction = normalized(add(new_direction, mults(next_vec3_normal(rng), (*intersection.material).roughness)))
// Compute the new ray, and offset its origin a bit along intersection normal
let position = add(current_ray.origin, mults(current_ray.direction, intersection.distance))
mut position_shift = 0.001
if dot(new_direction, intersection.normal) < 0.0:
position_shift = -position_shift
current_ray = ray(add(position, mults(intersection.normal, position_shift)), new_direction)
// Account for Russian roulette
factor = mults(factor, 1.0 / (1.0 - termination_probability))
else:
if !intersection.intersected:
result = add(result, multv(factor, (*scene).background))
running = false
break
// Uncomment to debug normals
// return mults(add(intersection.normal, vec3(1.0, 1.0, 1.0)), 0.5)
let cosine = - dot(intersection.normal, current_ray.direction)
let inside = cosine < 0.0
result = add(result, multv(factor, (*intersection.material).emission))
// Russian roulette ray termination
if next_f32(rng) < termination_probability:
break
mut new_direction = vec3(0.0, 0.0, 0.0)
if (*intersection.material).type == diffuse_tag:
// NB: albedo is assumed to be premultiplied by pi to be in [0..1] range
// This should also contain multiplication by cos(new_dir, normal), division by direction pdf (cos / pi)
// and division by pi (because of albedo normalization), but these all cancel out
factor = multv(factor, (*intersection.material).color)
// Cosine-weighted hemisphere direction
new_direction = normalized(add(next_vec3(rng), intersection.normal))
else if (*intersection.material).type == metallic_tag:
// This should also contain multiplication by brdf and division by direction pdf,
// but we'll just pretend that the random reflected ray pdf coincides with brdf and thus cancels out
factor = multv(factor, (*intersection.material).color)
// Compute perfect-mirror reflected direction
new_direction = add(current_ray.direction, mults(intersection.normal, 2.0 * cosine))
// Alter the direction based on roughness
new_direction = normalized(add(new_direction, mults(next_vec3_normal(rng), cosine * (*intersection.material).roughness)))
else if (*intersection.material).type == glass_tag:
// This should also contain multiplication by brdf and division by direction pdf,
// but we'll just pretend that the random refracted ray pdf coincides with brdf and thus cancels out
factor = multv(factor, (*intersection.material).color)
mut ior = (*intersection.material).ior
if inside:
ior = 1.0 / ior
// Compute perfect refracted ray
let k = 1.0 - ior * ior * (1.0 - cosine * cosine)
if k >= 0.0:
new_direction = add(mults(current_ray.direction, ior), mults(intersection.normal, ior * abs(cosine) - sqrt(k)))
else:
// Total internal reflection
new_direction = add(current_ray.direction, mults(intersection.normal, 2.0 * cosine))
// Alter the direction based on roughness
new_direction = normalized(add(new_direction, mults(next_vec3_normal(rng), (*intersection.material).roughness)))
// Compute the new ray, and offset its origin a bit along intersection normal
let position = add(current_ray.origin, mults(current_ray.direction, intersection.distance))
mut position_shift = 0.001
if dot(new_direction, intersection.normal) < 0.0:
position_shift = -position_shift
current_ray = ray(add(position, mults(intersection.normal, position_shift)), new_direction)
// Account for Russian roulette
factor = mults(factor, 1.0 / (1.0 - termination_probability))
return result
@ -777,8 +775,8 @@ const default_fovy = 2.0 * atan(0.5)
func main():
let scene = make_default_scene()
let image = create_image(512ul, 512ul)
// let image = create_image(256ul, 256ul)
// let image = create_image(512ul, 512ul)
let image = create_image(256ul, 256ul)
// let image = create_image(128ul, 128ul)
let aspect_ratio = (image.width as f32) / (image.height as f32)
@ -865,40 +863,44 @@ func main():
sleep(timespec(0l, 125000000l))
let time = get_time()
let delta = time_delta(time, last_report_time)
if delta > 0.125 || last_report_progress == 0.0:
done = 0ul
mut th = 0ul
while th < thread_count:
lock_mutex(threads_data[th].done_mutex)
done += threads_data[th].done
unlock_mutex(threads_data[th].done_mutex)
th += 1ul
if !(delta > 0.125 || last_report_progress == 0.0):
continue
if done > 0ul:
let progress = (done as f32) / (total as f32)
done = 0ul
mut th = 0ul
while th < thread_count:
lock_mutex(threads_data[th].done_mutex)
done += threads_data[th].done
unlock_mutex(threads_data[th].done_mutex)
th += 1ul
// Running exponential-weighted average speed
// for +/- accurate remaining time estimate
let speed_estimate = (progress - last_report_progress) / delta
average_speed = average_speed + (speed_estimate - average_speed) / (min_u32(report_count + 1u, 16u) as f32)
let time_remaining = (1.0 - progress) / average_speed
if done == 0ul:
continue
let str1 = ['%', ' ', 'd', 'o', 'n', 'e', ',', ' ', '\0']
let str2 = [' ', 'p', 'a', 's', 's', 'e', 'd', ',', ' ', '\0']
let str3 = [' ', 'l', 'e', 'f', 't', ' ', '\0']
let progress = (done as f32) / (total as f32)
clear_line(50u)
print_f32(100.0 * progress)
print_str(str1 as u8*)
print_time(time_delta(time, start_time))
print_str(str2 as u8*)
print_time(time_remaining)
print_str(str3 as u8*)
flush()
last_report_time = time
last_report_progress = progress
report_count += 1u
// Running exponential-weighted average speed
// for +/- accurate remaining time estimate
let speed_estimate = (progress - last_report_progress) / delta
average_speed = average_speed + (speed_estimate - average_speed) / (min_u32(report_count + 1u, 16u) as f32)
let time_remaining = (1.0 - progress) / average_speed
let str1 = ['%', ' ', 'd', 'o', 'n', 'e', ',', ' ', '\0']
let str2 = [' ', 'p', 'a', 's', 's', 'e', 'd', ',', ' ', '\0']
let str3 = [' ', 'l', 'e', 'f', 't', ' ', '\0']
clear_line(50u)
print_f32(100.0 * progress)
print_str(str1 as u8*)
print_time(time_delta(time, start_time))
print_str(str2 as u8*)
print_time(time_remaining)
print_str(str3 as u8*)
flush()
last_report_time = time
last_report_progress = progress
report_count += 1u
// No for loops yet...
th = 0ul

View file

@ -49,4 +49,14 @@ namespace pslang::ast
location location;
};
struct break_statement
{
location location;
};
struct continue_statement
{
location location;
};
}

View file

@ -28,6 +28,8 @@ namespace pslang::ast
else_block,
else_if_block,
while_block,
break_statement,
continue_statement,
function_definition,
foreign_function_declaration,
return_statement,
@ -47,6 +49,8 @@ namespace pslang::ast
variable_declaration,
if_chain,
while_block,
break_statement,
continue_statement,
function_definition,
foreign_function_declaration,
return_statement,

View file

@ -434,6 +434,18 @@ namespace pslang::ast
child(*node.statements);
}
void apply(break_statement const & node)
{
put_indent(out, options);
out << "break\n";
}
void apply(continue_statement const & node)
{
put_indent(out, options);
out << "continue\n";
}
void apply(function_definition const & node)
{
put_indent(out, options);

View file

@ -84,6 +84,12 @@ namespace pslang::ast
void apply(while_block const &)
{}
void apply(break_statement const &)
{}
void apply(continue_statement const &)
{}
void apply(function_definition & function_definition)
{
if (scopes.back().contains(function_definition.name))
@ -321,6 +327,12 @@ namespace pslang::ast
scopes.pop_back();
}
void apply(break_statement const & break_statement)
{}
void apply(continue_statement const & continue_statement)
{}
void apply(function_definition & function_definition)
{
// Already added to scope by populate_globals_visitor

View file

@ -180,6 +180,12 @@ namespace pslang::ast
void apply(while_block const &)
{}
void apply(break_statement const &)
{}
void apply(continue_statement const &)
{}
void apply(function_definition & node)
{
types::function_type function_type;
@ -866,6 +872,12 @@ namespace pslang::ast
apply(*node.statements);
}
void apply(break_statement const &)
{}
void apply(continue_statement const &)
{}
void apply(function_definition & node)
{
apply(*node.statements);

View file

@ -31,6 +31,12 @@ namespace pslang::ast
apply(*node.statements);
}
void apply(break_statement const &)
{}
void apply(continue_statement const &)
{}
void apply(function_definition const & node)
{
apply(*node.statements);

View file

@ -139,6 +139,16 @@ namespace pslang::interpreter
}
}
void exec_impl(context & context, ast::break_statement const &)
{
throw std::runtime_error("Not implemented");
}
void exec_impl(context & context, ast::continue_statement const &)
{
throw std::runtime_error("Not implemented");
}
void exec_impl(context & context, ast::function_definition const & function_definition)
{
auto & frame = context.frame_stack.back();

View file

@ -24,6 +24,15 @@ namespace pslang::ir
std::vector<scope> scopes;
struct loop_scope
{
// These must point to jump nodes
std::vector<node_ref> resolve_break;
std::vector<node_ref> resolve_continue;
};
std::vector<loop_scope> loop_scopes;
struct resolve_address_data
{
node_ref address;
@ -655,15 +664,39 @@ namespace pslang::ir
auto begin = last();
auto condition = apply(*node.condition);
mcontext.nodes->emplace_back(jump_if_zero{condition, {}});
auto jump1 = last();
auto jump_to_end = last();
lcontext.loop_scopes.emplace_back();
lcontext.scopes.emplace_back();
apply(*node.statements);
lcontext.scopes.pop_back();
mcontext.nodes->emplace_back(jump{begin});
mcontext.nodes->emplace_back(label{});
std::get<jump_if_zero>(jump1->instruction).target = last();
auto end = last();
std::get<jump_if_zero>(jump_to_end->instruction).target = end;
for (auto node : lcontext.loop_scopes.back().resolve_continue)
std::get<jump>(node->instruction).target = begin;
for (auto node : lcontext.loop_scopes.back().resolve_break)
std::get<jump>(node->instruction).target = end;
lcontext.loop_scopes.pop_back();
return last();
}
node_ref apply(ast::break_statement const &)
{
mcontext.nodes->emplace_back(jump{});
lcontext.loop_scopes.back().resolve_break.push_back(last());
return last();
}
node_ref apply(ast::continue_statement const &)
{
mcontext.nodes->emplace_back(jump{});
lcontext.loop_scopes.back().resolve_continue.push_back(last());
return last();
}
@ -759,6 +792,12 @@ namespace pslang::ir
lcontext.scopes.pop_back();
}
void apply(ast::break_statement const &)
{}
void apply(ast::continue_statement const &)
{}
void apply(ast::function_definition const & node)
{
compile_function_visitor{{}, {}, mcontext, lcontext}.do_apply(node);

View file

@ -29,6 +29,8 @@ mut { return bp::make_mut(ctx.location); }
if { return bp::make_if(ctx.location); }
else { return bp::make_else(ctx.location); }
while { return bp::make_while(ctx.location); }
break { return bp::make_break(ctx.location); }
continue { return bp::make_continue(ctx.location); }
as { return bp::make_as(ctx.location); }
func { return bp::make_func(ctx.location); }
foreign { return bp::make_foreign(ctx.location); }

View file

@ -140,6 +140,8 @@ template <typename T>
%token if
%token else
%token while
%token break
%token continue
%token as
%token func
%token foreign
@ -247,6 +249,8 @@ statement
| else colon { $$ = ast::else_block{@$}; }
| else if expression colon { $$ = ast::else_if_block{std::make_unique<ast::expression>($3), @$}; }
| while expression colon { $$ = ast::while_block{std::make_unique<ast::expression>($2), {}, @$, @$}; }
| break { $$ = ast::break_statement{@$}; }
| continue { $$ = ast::continue_statement{@$}; }
| func name lparen function_declaration_argument_list rparen function_return_type colon { $$ = ast::function_definition{{$2, $4, $6, @$}, {}}; }
| foreign func name lparen function_declaration_argument_list rparen function_return_type { $$ = ast::foreign_function_declaration{{$3, $5, $7, @$}}; }
| return expression { $$ = ast::return_statement{std::make_unique<ast::expression>($2), @$}; }

View file

@ -51,6 +51,16 @@ namespace pslang::parser
return node.location = ast::merge(node.prelude_location, apply(*node.statements));
}
ast::location apply(ast::break_statement & node)
{
return node.location;
}
ast::location apply(ast::continue_statement & node)
{
return node.location;
}
ast::location apply(ast::function_definition & node)
{
return node.location = ast::merge(node.prelude_location, apply(*node.statements));
@ -103,6 +113,7 @@ namespace pslang::parser
stack.push_back(result.get());
std::size_t current_indent = 0;
std::vector<ast::function_definition *> function_stack;
std::vector<ast::statement_list *> loop_stack;
auto current_statement_list = [&](ast::location const & location) -> ast::statement_list *
{
@ -139,6 +150,8 @@ namespace pslang::parser
throw ast::invalid_ast_error("Unexpected empty indent stack", ast::get_location(*statement.statement));
if (!function_stack.empty() && std::holds_alternative<ast::statement_list *>(stack.back()) && function_stack.back()->statements.get() == std::get<ast::statement_list *>(stack.back()))
function_stack.pop_back();
if (!loop_stack.empty() && std::holds_alternative<ast::statement_list *>(stack.back()) && loop_stack.back() == std::get<ast::statement_list *>(stack.back()))
loop_stack.pop_back();
stack.pop_back();
--current_indent;
}
@ -182,6 +195,7 @@ namespace pslang::parser
while_block->statements = std::make_unique<ast::statement_list>();
list = while_block->statements.get();
current_statement_list(location)->statements.push_back(std::make_unique<ast::statement>(std::move(*while_block)));
loop_stack.push_back(list);
}
else if (auto function_definition = std::get_if<ast::function_definition>(statement.statement.get()))
{
@ -229,6 +243,18 @@ namespace pslang::parser
{
current_statement_list(location)->statements.push_back(std::make_unique<ast::statement>(std::move(*foreign_function_declaration)));
}
else if (auto break_statement = std::get_if<ast::break_statement>(statement.statement.get()))
{
if (loop_stack.empty())
throw parse_error("Break without an enclosing loop", break_statement->location);
current_statement_list(location)->statements.push_back(std::make_unique<ast::statement>(std::move(*break_statement)));
}
else if (auto continue_statement = std::get_if<ast::continue_statement>(statement.statement.get()))
{
if (loop_stack.empty())
throw parse_error("Continue without an enclosing loop", continue_statement->location);
current_statement_list(location)->statements.push_back(std::make_unique<ast::statement>(std::move(*continue_statement)));
}
else
{
throw ast::invalid_ast_error(std::format("Unknown pre-statement \"{}\"", std::visit([](auto const & statement){ return typeid(statement).name(); }, *statement.statement)), location);

View file

@ -148,8 +148,11 @@ Flow control:
while condition:
statements
if x:
break
if y:
continue
TODO: break/continue?
TODO: for loops? iterator/range interface?
======== STRUCTS ========