Implement ternary ifs

This commit is contained in:
Nikita Lisitsa 2026-04-02 20:19:55 +03:00
parent b3fc435dd6
commit 8ba5345daf
13 changed files with 123 additions and 27 deletions

View file

@ -478,14 +478,13 @@ func intersect_sphere(ray: ray, object: object*) -> intersection:
let t2 = (- b + sqrt(D)) / (2.0 * a) let t2 = (- b + sqrt(D)) / (2.0 * a)
if t2 < 0.0: if t2 < 0.0:
return no_intersection return no_intersection
mut t = t1 let t = if t1 < 0.0 then t2 else t1
if t1 < 0.0:
t = t2
let normal = normalized(add(delta, mults(ray.direction, t))) let normal = normalized(add(delta, mults(ray.direction, t)))
return intersection(true, t, normal, &(*object).material) return intersection(true, t, normal, &(*object).material)
func intersect_box(ray: ray, object: object*) -> intersection: func intersect_box(ray: ray, object: object*) -> intersection:
func swap(x: f32 mut*, y: f32 mut*): func sort(x: f32 mut*, y: f32 mut*):
if *x > *y:
let temp = *x let temp = *x
*x = *y *x = *y
*y = temp *y = temp
@ -503,12 +502,9 @@ func intersect_box(ray: ray, object: object*) -> intersection:
mut tzmin = (- (*shape).extent.z - local_delta.z) / local_direction.z mut tzmin = (- (*shape).extent.z - local_delta.z) / local_direction.z
mut tzmax = ( (*shape).extent.z - local_delta.z) / local_direction.z mut tzmax = ( (*shape).extent.z - local_delta.z) / local_direction.z
if txmin > txmax: sort(&mut txmin, &mut txmax)
swap(&mut txmin, &mut txmax) sort(&mut tymin, &mut tymax)
if tymin > tymax: sort(&mut tzmin, &mut tzmax)
swap(&mut tymin, &mut tymax)
if tzmin > tzmax:
swap(&mut tzmin, &mut tzmax)
let tmin = max(txmin, max(tymin, tzmin)) let tmin = max(txmin, max(tymin, tzmin))
let tmax = min(txmax, min(tymax, tzmax)) let tmax = min(txmax, min(tymax, tzmax))
@ -517,14 +513,10 @@ func intersect_box(ray: ray, object: object*) -> intersection:
return no_intersection return no_intersection
let inside = tmin < 0.0 let inside = tmin < 0.0
mut t = tmin let t = if inside then tmax else tmin
mut tt = vec3(txmin, tymin, tzmin) let tt = if inside then vec3(txmax, tymax, tzmax) else vec3(txmin, tymin, tzmin)
// No ternary if yet...
if inside:
t = tmax
tt = vec3(txmax, tymax, tzmax)
mut normal = vec3(0.0, 0.0, 0.0) mut normal = vec3()
if t == tt.x: if t == tt.x:
if local_direction.x > 0.0: if local_direction.x > 0.0:
@ -545,7 +537,7 @@ func intersect_box(ray: ray, object: object*) -> intersection:
if inside: if inside:
normal = mults(normal, -1.0) normal = mults(normal, -1.0)
return intersection(true, tmin, rotate((*object).rotation, normal), &(*object).material) return intersection(true, t, rotate((*object).rotation, normal), &(*object).material)
func intersect_scene(scene: scene*, ray: ray) -> intersection: func intersect_scene(scene: scene*, ray: ray) -> intersection:
mut intersection = no_intersection mut intersection = no_intersection
@ -776,8 +768,8 @@ func main():
let scene = make_default_scene() let scene = make_default_scene()
// let image = create_image(512ul, 512ul) // let image = create_image(512ul, 512ul)
let image = create_image(256ul, 256ul) // let image = create_image(256ul, 256ul)
// let image = create_image(128ul, 128ul) let image = create_image(128ul, 128ul)
let aspect_ratio = (image.width as f32) / (image.height as f32) let aspect_ratio = (image.width as f32) / (image.height as f32)
let camera = camera(vec3(0.0, 0.0, 15.0), identity, default_fovy, compute_fovx(default_fovy, aspect_ratio)) let camera = camera(vec3(0.0, 0.0, 15.0), identity, default_fovy, compute_fovx(default_fovy, aspect_ratio))

View file

@ -3,10 +3,20 @@
#include <pslang/ast/expression_fwd.hpp> #include <pslang/ast/expression_fwd.hpp>
#include <pslang/ast/statement_fwd.hpp> #include <pslang/ast/statement_fwd.hpp>
#include <pslang/ast/location.hpp> #include <pslang/ast/location.hpp>
#include <pslang/types/type_fwd.hpp>
namespace pslang::ast namespace pslang::ast
{ {
struct if_expression
{
expression_ptr condition;
expression_ptr if_true;
expression_ptr if_false;
ast::location location;
types::type_ptr inferred_type = nullptr;
};
struct if_block struct if_block
{ {
expression_ptr condition; expression_ptr condition;

View file

@ -8,6 +8,7 @@
#include <pslang/ast/function.hpp> #include <pslang/ast/function.hpp>
#include <pslang/ast/array.hpp> #include <pslang/ast/array.hpp>
#include <pslang/ast/struct.hpp> #include <pslang/ast/struct.hpp>
#include <pslang/ast/control.hpp>
#include <pslang/ast/expression_fwd.hpp> #include <pslang/ast/expression_fwd.hpp>
#include <pslang/types/type_fwd.hpp> #include <pslang/types/type_fwd.hpp>
@ -40,7 +41,8 @@ namespace pslang::ast
function_call, function_call,
array, array,
array_access, array_access,
field_access field_access,
if_expression
>; >;
struct expression struct expression

View file

@ -58,6 +58,11 @@ namespace pslang::ast
{ {
return node.location; return node.location;
} }
location apply(if_expression const & node)
{
return node.location;
}
}; };
struct get_type_visitor struct get_type_visitor
@ -110,6 +115,11 @@ namespace pslang::ast
{ {
return node.inferred_type; return node.inferred_type;
} }
types::type_ptr apply(if_expression const & node)
{
return node.inferred_type;
}
}; };
} }

View file

@ -358,6 +358,15 @@ namespace pslang::ast
out << "field access { name = \"" << node.field_name << "\" }\n"; out << "field access { name = \"" << node.field_name << "\" }\n";
child(*node.object); child(*node.object);
} }
void apply(if_expression const & node)
{
put_indent(out, options);
out << "if\n";
child(*node.condition);
child(*node.if_true);
child(*node.if_false);
}
}; };
struct statement_print_visitor struct statement_print_visitor

View file

@ -281,6 +281,13 @@ namespace pslang::ast
apply(*field_access.object); apply(*field_access.object);
} }
void apply(if_expression const & if_expression)
{
apply(*if_expression.condition);
apply(*if_expression.if_true);
apply(*if_expression.if_false);
}
void apply(expression_ptr const & expression_ptr) void apply(expression_ptr const & expression_ptr)
{ {
apply(*expression_ptr); apply(*expression_ptr);

View file

@ -783,6 +783,37 @@ namespace pslang::ast
throw type_error(std::format("Struct \"{}\" has no field named \"{}\"", struct_node.name, node.field_name), node.location); throw type_error(std::format("Struct \"{}\" has no field named \"{}\"", struct_node.name, node.field_name), node.location);
} }
void apply(if_expression & node)
{
apply(*node.condition);
apply(*node.if_true);
apply(*node.if_false);
auto condition_type = get_type(*node.condition);
if (!types::is_bool_type(*condition_type))
{
std::ostringstream os;
os << "if condition expects a bool type, but got ";
print(os, *condition_type);
throw type_error(os.str(), get_location(*node.condition));
}
auto true_type = get_type(*node.if_true);
auto false_type = get_type(*node.if_false);
if (!types::equal(*true_type, *false_type))
{
std::ostringstream os;
os << "Both ternary if cases must have the same type, but got ";
print(os, *true_type);
os << " and ";
print(os, *false_type);
throw type_error(os.str(), node.location);
}
node.inferred_type = true_type;
}
void apply(expression_ptr const & node) void apply(expression_ptr const & node)
{ {
apply(*node); apply(*node);

View file

@ -686,6 +686,11 @@ namespace pslang::interpreter
throw internal_error(os.str(), field_access.location); throw internal_error(os.str(), field_access.location);
} }
value eval_impl(context & context, ast::if_expression const & if_expression)
{
throw std::runtime_error("Not implemented");
}
value eval_impl(context & context, ast::expression_ptr const & expression) value eval_impl(context & context, ast::expression_ptr const & expression)
{ {
return std::visit([&](auto const & expression){ return eval_impl(context, expression); }, *expression); return std::visit([&](auto const & expression){ return eval_impl(context, expression); }, *expression);
@ -774,6 +779,11 @@ namespace pslang::interpreter
throw internal_error(os.str(), field_access.location); throw internal_error(os.str(), field_access.location);
} }
value * eval_ref_impl(context & context, ast::if_expression const & if_expression)
{
throw std::runtime_error("Not implemented");
}
value * eval_ref_impl(context & context, ast::expression_ptr const & expression) value * eval_ref_impl(context & context, ast::expression_ptr const & expression)
{ {
return std::visit([&](auto const & expression){ return eval_ref_impl(context, expression); }, *expression); return std::visit([&](auto const & expression){ return eval_ref_impl(context, expression); }, *expression);

View file

@ -473,7 +473,28 @@ namespace pslang::ir
throw std::runtime_error("Unknown field name"); throw std::runtime_error("Unknown field name");
} }
// TODO: array, array_access, field_access node_ref apply(ast::if_expression const & node)
{
mcontext.nodes->emplace_back(alloc{}, node.inferred_type);
auto result = last();
auto condition = apply(*node.condition);
mcontext.nodes->emplace_back(jump_if_zero{condition});
auto jump_skip_if_true = last();
auto true_result = apply(*node.if_true);
mcontext.nodes->emplace_back(assignment{result, true_result}, node.inferred_type);
mcontext.nodes->emplace_back(jump{});
auto jump_end_if_true = last();
mcontext.nodes->emplace_back(label{});
auto after_if_true = last();
auto false_result = apply(*node.if_false);
mcontext.nodes->emplace_back(assignment{result, false_result}, node.inferred_type);
mcontext.nodes->emplace_back(label{});
auto end = last();
std::get<jump_if_zero>(jump_skip_if_true->instruction).target = after_if_true;
std::get<jump>(jump_end_if_true->instruction).target = end;
return result;
}
// Statements // Statements

View file

@ -312,7 +312,7 @@ namespace pslang::jit::aarch64
void apply(ir::node_ref it, ir::alloc const & node, types::type_ptr const &) void apply(ir::node_ref it, ir::alloc const & node, types::type_ptr const &)
{ {
// Nothing to do: alloc just allocates a node of a struct type, // Nothing to do: alloc just allocates a node of some type,
// but we already allocated stack space for it // but we already allocated stack space for it
} }

View file

@ -24,7 +24,7 @@ bison_target(
${PSLANG_PARSER_RULES_FILE} ${PSLANG_PARSER_RULES_FILE}
${PSLANG_PARSER_SOURCE_FILE} ${PSLANG_PARSER_SOURCE_FILE}
DEFINES_FILE ${PSLANG_PARSER_HEADER_FILE} DEFINES_FILE ${PSLANG_PARSER_HEADER_FILE}
# COMPILE_FLAGS -Wcounterexamples COMPILE_FLAGS -Wcounterexamples
) )
add_flex_bison_dependency(generate-pslang-lexer generate-pslang-parser) add_flex_bison_dependency(generate-pslang-lexer generate-pslang-parser)

View file

@ -28,6 +28,7 @@ let { return bp::make_let(ctx.location); }
mut { return bp::make_mut(ctx.location); } mut { return bp::make_mut(ctx.location); }
if { return bp::make_if(ctx.location); } if { return bp::make_if(ctx.location); }
else { return bp::make_else(ctx.location); } else { return bp::make_else(ctx.location); }
then { return bp::make_then(ctx.location); }
while { return bp::make_while(ctx.location); } while { return bp::make_while(ctx.location); }
break { return bp::make_break(ctx.location); } break { return bp::make_break(ctx.location); }
continue { return bp::make_continue(ctx.location); } continue { return bp::make_continue(ctx.location); }

View file

@ -139,6 +139,7 @@ template <typename T>
%token mut %token mut
%token if %token if
%token else %token else
%token then
%token while %token while
%token break %token break
%token continue %token continue
@ -177,6 +178,7 @@ template <typename T>
%left asterisk slash percent %left asterisk slash percent
%precedence NOT %precedence NOT
%precedence ADDRESSOF DEREFERENCE %precedence ADDRESSOF DEREFERENCE
%precedence else
%precedence lbracket %precedence lbracket
%type <indented_statement_list> indented_statement_list %type <indented_statement_list> indented_statement_list
@ -352,6 +354,7 @@ expression
| ampersand mut expression %prec ADDRESSOF { $$ = ast::unary_operation{ast::unary_operation_type::mutable_address_of, std::make_unique<ast::expression>($3), @$ }; } | ampersand mut expression %prec ADDRESSOF { $$ = ast::unary_operation{ast::unary_operation_type::mutable_address_of, std::make_unique<ast::expression>($3), @$ }; }
| asterisk expression %prec DEREFERENCE { $$ = ast::unary_operation{ast::unary_operation_type::dereference, std::make_unique<ast::expression>($2), @$ }; } | asterisk expression %prec DEREFERENCE { $$ = ast::unary_operation{ast::unary_operation_type::dereference, std::make_unique<ast::expression>($2), @$ }; }
| postfix_expression { $$ = $1; } | postfix_expression { $$ = $1; }
| if expression then expression else expression { $$ = ast::if_expression{ std::make_unique<ast::expression>($2), std::make_unique<ast::expression>($4), std::make_unique<ast::expression>($6), @$ }; }
; ;
postfix_expression postfix_expression