Implement ternary ifs

This commit is contained in:
Nikita Lisitsa 2026-04-02 20:19:55 +03:00
parent b3fc435dd6
commit 8ba5345daf
13 changed files with 123 additions and 27 deletions

View file

@ -478,17 +478,16 @@ func intersect_sphere(ray: ray, object: object*) -> intersection:
let t2 = (- b + sqrt(D)) / (2.0 * a)
if t2 < 0.0:
return no_intersection
mut t = t1
if t1 < 0.0:
t = t2
let t = if t1 < 0.0 then t2 else t1
let normal = normalized(add(delta, mults(ray.direction, t)))
return intersection(true, t, normal, &(*object).material)
func intersect_box(ray: ray, object: object*) -> intersection:
func swap(x: f32 mut*, y: f32 mut*):
let temp = *x
*x = *y
*y = temp
func sort(x: f32 mut*, y: f32 mut*):
if *x > *y:
let temp = *x
*x = *y
*y = temp
let shape = &(*object).shape.data as box*
let inverse_rotation = inverse((*object).rotation)
@ -503,12 +502,9 @@ func intersect_box(ray: ray, object: object*) -> intersection:
mut tzmin = (- (*shape).extent.z - local_delta.z) / local_direction.z
mut tzmax = ( (*shape).extent.z - local_delta.z) / local_direction.z
if txmin > txmax:
swap(&mut txmin, &mut txmax)
if tymin > tymax:
swap(&mut tymin, &mut tymax)
if tzmin > tzmax:
swap(&mut tzmin, &mut tzmax)
sort(&mut txmin, &mut txmax)
sort(&mut tymin, &mut tymax)
sort(&mut tzmin, &mut tzmax)
let tmin = max(txmin, max(tymin, tzmin))
let tmax = min(txmax, min(tymax, tzmax))
@ -517,14 +513,10 @@ func intersect_box(ray: ray, object: object*) -> intersection:
return no_intersection
let inside = tmin < 0.0
mut t = tmin
mut tt = vec3(txmin, tymin, tzmin)
// No ternary if yet...
if inside:
t = tmax
tt = vec3(txmax, tymax, tzmax)
let t = if inside then tmax else tmin
let tt = if inside then vec3(txmax, tymax, tzmax) else vec3(txmin, tymin, tzmin)
mut normal = vec3(0.0, 0.0, 0.0)
mut normal = vec3()
if t == tt.x:
if local_direction.x > 0.0:
@ -545,7 +537,7 @@ func intersect_box(ray: ray, object: object*) -> intersection:
if inside:
normal = mults(normal, -1.0)
return intersection(true, tmin, rotate((*object).rotation, normal), &(*object).material)
return intersection(true, t, rotate((*object).rotation, normal), &(*object).material)
func intersect_scene(scene: scene*, ray: ray) -> intersection:
mut intersection = no_intersection
@ -776,8 +768,8 @@ func main():
let scene = make_default_scene()
// let image = create_image(512ul, 512ul)
let image = create_image(256ul, 256ul)
// let image = create_image(128ul, 128ul)
// let image = create_image(256ul, 256ul)
let image = create_image(128ul, 128ul)
let aspect_ratio = (image.width as f32) / (image.height as f32)
let camera = camera(vec3(0.0, 0.0, 15.0), identity, default_fovy, compute_fovx(default_fovy, aspect_ratio))

View file

@ -3,10 +3,20 @@
#include <pslang/ast/expression_fwd.hpp>
#include <pslang/ast/statement_fwd.hpp>
#include <pslang/ast/location.hpp>
#include <pslang/types/type_fwd.hpp>
namespace pslang::ast
{
struct if_expression
{
expression_ptr condition;
expression_ptr if_true;
expression_ptr if_false;
ast::location location;
types::type_ptr inferred_type = nullptr;
};
struct if_block
{
expression_ptr condition;

View file

@ -8,6 +8,7 @@
#include <pslang/ast/function.hpp>
#include <pslang/ast/array.hpp>
#include <pslang/ast/struct.hpp>
#include <pslang/ast/control.hpp>
#include <pslang/ast/expression_fwd.hpp>
#include <pslang/types/type_fwd.hpp>
@ -40,7 +41,8 @@ namespace pslang::ast
function_call,
array,
array_access,
field_access
field_access,
if_expression
>;
struct expression

View file

@ -58,6 +58,11 @@ namespace pslang::ast
{
return node.location;
}
location apply(if_expression const & node)
{
return node.location;
}
};
struct get_type_visitor
@ -110,6 +115,11 @@ namespace pslang::ast
{
return node.inferred_type;
}
types::type_ptr apply(if_expression const & node)
{
return node.inferred_type;
}
};
}

View file

@ -358,6 +358,15 @@ namespace pslang::ast
out << "field access { name = \"" << node.field_name << "\" }\n";
child(*node.object);
}
void apply(if_expression const & node)
{
put_indent(out, options);
out << "if\n";
child(*node.condition);
child(*node.if_true);
child(*node.if_false);
}
};
struct statement_print_visitor

View file

@ -281,6 +281,13 @@ namespace pslang::ast
apply(*field_access.object);
}
void apply(if_expression const & if_expression)
{
apply(*if_expression.condition);
apply(*if_expression.if_true);
apply(*if_expression.if_false);
}
void apply(expression_ptr const & expression_ptr)
{
apply(*expression_ptr);

View file

@ -783,6 +783,37 @@ namespace pslang::ast
throw type_error(std::format("Struct \"{}\" has no field named \"{}\"", struct_node.name, node.field_name), node.location);
}
void apply(if_expression & node)
{
apply(*node.condition);
apply(*node.if_true);
apply(*node.if_false);
auto condition_type = get_type(*node.condition);
if (!types::is_bool_type(*condition_type))
{
std::ostringstream os;
os << "if condition expects a bool type, but got ";
print(os, *condition_type);
throw type_error(os.str(), get_location(*node.condition));
}
auto true_type = get_type(*node.if_true);
auto false_type = get_type(*node.if_false);
if (!types::equal(*true_type, *false_type))
{
std::ostringstream os;
os << "Both ternary if cases must have the same type, but got ";
print(os, *true_type);
os << " and ";
print(os, *false_type);
throw type_error(os.str(), node.location);
}
node.inferred_type = true_type;
}
void apply(expression_ptr const & node)
{
apply(*node);

View file

@ -686,6 +686,11 @@ namespace pslang::interpreter
throw internal_error(os.str(), field_access.location);
}
value eval_impl(context & context, ast::if_expression const & if_expression)
{
throw std::runtime_error("Not implemented");
}
value eval_impl(context & context, ast::expression_ptr const & expression)
{
return std::visit([&](auto const & expression){ return eval_impl(context, expression); }, *expression);
@ -774,6 +779,11 @@ namespace pslang::interpreter
throw internal_error(os.str(), field_access.location);
}
value * eval_ref_impl(context & context, ast::if_expression const & if_expression)
{
throw std::runtime_error("Not implemented");
}
value * eval_ref_impl(context & context, ast::expression_ptr const & expression)
{
return std::visit([&](auto const & expression){ return eval_ref_impl(context, expression); }, *expression);

View file

@ -473,7 +473,28 @@ namespace pslang::ir
throw std::runtime_error("Unknown field name");
}
// TODO: array, array_access, field_access
node_ref apply(ast::if_expression const & node)
{
mcontext.nodes->emplace_back(alloc{}, node.inferred_type);
auto result = last();
auto condition = apply(*node.condition);
mcontext.nodes->emplace_back(jump_if_zero{condition});
auto jump_skip_if_true = last();
auto true_result = apply(*node.if_true);
mcontext.nodes->emplace_back(assignment{result, true_result}, node.inferred_type);
mcontext.nodes->emplace_back(jump{});
auto jump_end_if_true = last();
mcontext.nodes->emplace_back(label{});
auto after_if_true = last();
auto false_result = apply(*node.if_false);
mcontext.nodes->emplace_back(assignment{result, false_result}, node.inferred_type);
mcontext.nodes->emplace_back(label{});
auto end = last();
std::get<jump_if_zero>(jump_skip_if_true->instruction).target = after_if_true;
std::get<jump>(jump_end_if_true->instruction).target = end;
return result;
}
// Statements

View file

@ -312,7 +312,7 @@ namespace pslang::jit::aarch64
void apply(ir::node_ref it, ir::alloc const & node, types::type_ptr const &)
{
// Nothing to do: alloc just allocates a node of a struct type,
// Nothing to do: alloc just allocates a node of some type,
// but we already allocated stack space for it
}

View file

@ -24,7 +24,7 @@ bison_target(
${PSLANG_PARSER_RULES_FILE}
${PSLANG_PARSER_SOURCE_FILE}
DEFINES_FILE ${PSLANG_PARSER_HEADER_FILE}
# COMPILE_FLAGS -Wcounterexamples
COMPILE_FLAGS -Wcounterexamples
)
add_flex_bison_dependency(generate-pslang-lexer generate-pslang-parser)

View file

@ -28,6 +28,7 @@ let { return bp::make_let(ctx.location); }
mut { return bp::make_mut(ctx.location); }
if { return bp::make_if(ctx.location); }
else { return bp::make_else(ctx.location); }
then { return bp::make_then(ctx.location); }
while { return bp::make_while(ctx.location); }
break { return bp::make_break(ctx.location); }
continue { return bp::make_continue(ctx.location); }

View file

@ -139,6 +139,7 @@ template <typename T>
%token mut
%token if
%token else
%token then
%token while
%token break
%token continue
@ -177,6 +178,7 @@ template <typename T>
%left asterisk slash percent
%precedence NOT
%precedence ADDRESSOF DEREFERENCE
%precedence else
%precedence lbracket
%type <indented_statement_list> indented_statement_list
@ -352,6 +354,7 @@ expression
| ampersand mut expression %prec ADDRESSOF { $$ = ast::unary_operation{ast::unary_operation_type::mutable_address_of, std::make_unique<ast::expression>($3), @$ }; }
| asterisk expression %prec DEREFERENCE { $$ = ast::unary_operation{ast::unary_operation_type::dereference, std::make_unique<ast::expression>($2), @$ }; }
| postfix_expression { $$ = $1; }
| if expression then expression else expression { $$ = ast::if_expression{ std::make_unique<ast::expression>($2), std::make_unique<ast::expression>($4), std::make_unique<ast::expression>($6), @$ }; }
;
postfix_expression