Implement ternary ifs
This commit is contained in:
parent
b3fc435dd6
commit
8ba5345daf
13 changed files with 123 additions and 27 deletions
|
|
@ -478,17 +478,16 @@ func intersect_sphere(ray: ray, object: object*) -> intersection:
|
|||
let t2 = (- b + sqrt(D)) / (2.0 * a)
|
||||
if t2 < 0.0:
|
||||
return no_intersection
|
||||
mut t = t1
|
||||
if t1 < 0.0:
|
||||
t = t2
|
||||
let t = if t1 < 0.0 then t2 else t1
|
||||
let normal = normalized(add(delta, mults(ray.direction, t)))
|
||||
return intersection(true, t, normal, &(*object).material)
|
||||
|
||||
func intersect_box(ray: ray, object: object*) -> intersection:
|
||||
func swap(x: f32 mut*, y: f32 mut*):
|
||||
let temp = *x
|
||||
*x = *y
|
||||
*y = temp
|
||||
func sort(x: f32 mut*, y: f32 mut*):
|
||||
if *x > *y:
|
||||
let temp = *x
|
||||
*x = *y
|
||||
*y = temp
|
||||
|
||||
let shape = &(*object).shape.data as box*
|
||||
let inverse_rotation = inverse((*object).rotation)
|
||||
|
|
@ -503,12 +502,9 @@ func intersect_box(ray: ray, object: object*) -> intersection:
|
|||
mut tzmin = (- (*shape).extent.z - local_delta.z) / local_direction.z
|
||||
mut tzmax = ( (*shape).extent.z - local_delta.z) / local_direction.z
|
||||
|
||||
if txmin > txmax:
|
||||
swap(&mut txmin, &mut txmax)
|
||||
if tymin > tymax:
|
||||
swap(&mut tymin, &mut tymax)
|
||||
if tzmin > tzmax:
|
||||
swap(&mut tzmin, &mut tzmax)
|
||||
sort(&mut txmin, &mut txmax)
|
||||
sort(&mut tymin, &mut tymax)
|
||||
sort(&mut tzmin, &mut tzmax)
|
||||
|
||||
let tmin = max(txmin, max(tymin, tzmin))
|
||||
let tmax = min(txmax, min(tymax, tzmax))
|
||||
|
|
@ -517,14 +513,10 @@ func intersect_box(ray: ray, object: object*) -> intersection:
|
|||
return no_intersection
|
||||
|
||||
let inside = tmin < 0.0
|
||||
mut t = tmin
|
||||
mut tt = vec3(txmin, tymin, tzmin)
|
||||
// No ternary if yet...
|
||||
if inside:
|
||||
t = tmax
|
||||
tt = vec3(txmax, tymax, tzmax)
|
||||
let t = if inside then tmax else tmin
|
||||
let tt = if inside then vec3(txmax, tymax, tzmax) else vec3(txmin, tymin, tzmin)
|
||||
|
||||
mut normal = vec3(0.0, 0.0, 0.0)
|
||||
mut normal = vec3()
|
||||
|
||||
if t == tt.x:
|
||||
if local_direction.x > 0.0:
|
||||
|
|
@ -545,7 +537,7 @@ func intersect_box(ray: ray, object: object*) -> intersection:
|
|||
if inside:
|
||||
normal = mults(normal, -1.0)
|
||||
|
||||
return intersection(true, tmin, rotate((*object).rotation, normal), &(*object).material)
|
||||
return intersection(true, t, rotate((*object).rotation, normal), &(*object).material)
|
||||
|
||||
func intersect_scene(scene: scene*, ray: ray) -> intersection:
|
||||
mut intersection = no_intersection
|
||||
|
|
@ -776,8 +768,8 @@ func main():
|
|||
let scene = make_default_scene()
|
||||
|
||||
// let image = create_image(512ul, 512ul)
|
||||
let image = create_image(256ul, 256ul)
|
||||
// let image = create_image(128ul, 128ul)
|
||||
// let image = create_image(256ul, 256ul)
|
||||
let image = create_image(128ul, 128ul)
|
||||
|
||||
let aspect_ratio = (image.width as f32) / (image.height as f32)
|
||||
let camera = camera(vec3(0.0, 0.0, 15.0), identity, default_fovy, compute_fovx(default_fovy, aspect_ratio))
|
||||
|
|
|
|||
|
|
@ -3,10 +3,20 @@
|
|||
#include <pslang/ast/expression_fwd.hpp>
|
||||
#include <pslang/ast/statement_fwd.hpp>
|
||||
#include <pslang/ast/location.hpp>
|
||||
#include <pslang/types/type_fwd.hpp>
|
||||
|
||||
namespace pslang::ast
|
||||
{
|
||||
|
||||
struct if_expression
|
||||
{
|
||||
expression_ptr condition;
|
||||
expression_ptr if_true;
|
||||
expression_ptr if_false;
|
||||
ast::location location;
|
||||
types::type_ptr inferred_type = nullptr;
|
||||
};
|
||||
|
||||
struct if_block
|
||||
{
|
||||
expression_ptr condition;
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
#include <pslang/ast/function.hpp>
|
||||
#include <pslang/ast/array.hpp>
|
||||
#include <pslang/ast/struct.hpp>
|
||||
#include <pslang/ast/control.hpp>
|
||||
#include <pslang/ast/expression_fwd.hpp>
|
||||
#include <pslang/types/type_fwd.hpp>
|
||||
|
||||
|
|
@ -40,7 +41,8 @@ namespace pslang::ast
|
|||
function_call,
|
||||
array,
|
||||
array_access,
|
||||
field_access
|
||||
field_access,
|
||||
if_expression
|
||||
>;
|
||||
|
||||
struct expression
|
||||
|
|
|
|||
|
|
@ -58,6 +58,11 @@ namespace pslang::ast
|
|||
{
|
||||
return node.location;
|
||||
}
|
||||
|
||||
location apply(if_expression const & node)
|
||||
{
|
||||
return node.location;
|
||||
}
|
||||
};
|
||||
|
||||
struct get_type_visitor
|
||||
|
|
@ -110,6 +115,11 @@ namespace pslang::ast
|
|||
{
|
||||
return node.inferred_type;
|
||||
}
|
||||
|
||||
types::type_ptr apply(if_expression const & node)
|
||||
{
|
||||
return node.inferred_type;
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -358,6 +358,15 @@ namespace pslang::ast
|
|||
out << "field access { name = \"" << node.field_name << "\" }\n";
|
||||
child(*node.object);
|
||||
}
|
||||
|
||||
void apply(if_expression const & node)
|
||||
{
|
||||
put_indent(out, options);
|
||||
out << "if\n";
|
||||
child(*node.condition);
|
||||
child(*node.if_true);
|
||||
child(*node.if_false);
|
||||
}
|
||||
};
|
||||
|
||||
struct statement_print_visitor
|
||||
|
|
|
|||
|
|
@ -281,6 +281,13 @@ namespace pslang::ast
|
|||
apply(*field_access.object);
|
||||
}
|
||||
|
||||
void apply(if_expression const & if_expression)
|
||||
{
|
||||
apply(*if_expression.condition);
|
||||
apply(*if_expression.if_true);
|
||||
apply(*if_expression.if_false);
|
||||
}
|
||||
|
||||
void apply(expression_ptr const & expression_ptr)
|
||||
{
|
||||
apply(*expression_ptr);
|
||||
|
|
|
|||
|
|
@ -783,6 +783,37 @@ namespace pslang::ast
|
|||
throw type_error(std::format("Struct \"{}\" has no field named \"{}\"", struct_node.name, node.field_name), node.location);
|
||||
}
|
||||
|
||||
void apply(if_expression & node)
|
||||
{
|
||||
apply(*node.condition);
|
||||
apply(*node.if_true);
|
||||
apply(*node.if_false);
|
||||
|
||||
auto condition_type = get_type(*node.condition);
|
||||
if (!types::is_bool_type(*condition_type))
|
||||
{
|
||||
std::ostringstream os;
|
||||
os << "if condition expects a bool type, but got ";
|
||||
print(os, *condition_type);
|
||||
throw type_error(os.str(), get_location(*node.condition));
|
||||
}
|
||||
|
||||
auto true_type = get_type(*node.if_true);
|
||||
auto false_type = get_type(*node.if_false);
|
||||
|
||||
if (!types::equal(*true_type, *false_type))
|
||||
{
|
||||
std::ostringstream os;
|
||||
os << "Both ternary if cases must have the same type, but got ";
|
||||
print(os, *true_type);
|
||||
os << " and ";
|
||||
print(os, *false_type);
|
||||
throw type_error(os.str(), node.location);
|
||||
}
|
||||
|
||||
node.inferred_type = true_type;
|
||||
}
|
||||
|
||||
void apply(expression_ptr const & node)
|
||||
{
|
||||
apply(*node);
|
||||
|
|
|
|||
|
|
@ -686,6 +686,11 @@ namespace pslang::interpreter
|
|||
throw internal_error(os.str(), field_access.location);
|
||||
}
|
||||
|
||||
value eval_impl(context & context, ast::if_expression const & if_expression)
|
||||
{
|
||||
throw std::runtime_error("Not implemented");
|
||||
}
|
||||
|
||||
value eval_impl(context & context, ast::expression_ptr const & expression)
|
||||
{
|
||||
return std::visit([&](auto const & expression){ return eval_impl(context, expression); }, *expression);
|
||||
|
|
@ -774,6 +779,11 @@ namespace pslang::interpreter
|
|||
throw internal_error(os.str(), field_access.location);
|
||||
}
|
||||
|
||||
value * eval_ref_impl(context & context, ast::if_expression const & if_expression)
|
||||
{
|
||||
throw std::runtime_error("Not implemented");
|
||||
}
|
||||
|
||||
value * eval_ref_impl(context & context, ast::expression_ptr const & expression)
|
||||
{
|
||||
return std::visit([&](auto const & expression){ return eval_ref_impl(context, expression); }, *expression);
|
||||
|
|
|
|||
|
|
@ -473,7 +473,28 @@ namespace pslang::ir
|
|||
throw std::runtime_error("Unknown field name");
|
||||
}
|
||||
|
||||
// TODO: array, array_access, field_access
|
||||
node_ref apply(ast::if_expression const & node)
|
||||
{
|
||||
mcontext.nodes->emplace_back(alloc{}, node.inferred_type);
|
||||
auto result = last();
|
||||
auto condition = apply(*node.condition);
|
||||
mcontext.nodes->emplace_back(jump_if_zero{condition});
|
||||
auto jump_skip_if_true = last();
|
||||
auto true_result = apply(*node.if_true);
|
||||
mcontext.nodes->emplace_back(assignment{result, true_result}, node.inferred_type);
|
||||
mcontext.nodes->emplace_back(jump{});
|
||||
auto jump_end_if_true = last();
|
||||
mcontext.nodes->emplace_back(label{});
|
||||
auto after_if_true = last();
|
||||
auto false_result = apply(*node.if_false);
|
||||
mcontext.nodes->emplace_back(assignment{result, false_result}, node.inferred_type);
|
||||
mcontext.nodes->emplace_back(label{});
|
||||
auto end = last();
|
||||
|
||||
std::get<jump_if_zero>(jump_skip_if_true->instruction).target = after_if_true;
|
||||
std::get<jump>(jump_end_if_true->instruction).target = end;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Statements
|
||||
|
||||
|
|
|
|||
|
|
@ -312,7 +312,7 @@ namespace pslang::jit::aarch64
|
|||
|
||||
void apply(ir::node_ref it, ir::alloc const & node, types::type_ptr const &)
|
||||
{
|
||||
// Nothing to do: alloc just allocates a node of a struct type,
|
||||
// Nothing to do: alloc just allocates a node of some type,
|
||||
// but we already allocated stack space for it
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ bison_target(
|
|||
${PSLANG_PARSER_RULES_FILE}
|
||||
${PSLANG_PARSER_SOURCE_FILE}
|
||||
DEFINES_FILE ${PSLANG_PARSER_HEADER_FILE}
|
||||
# COMPILE_FLAGS -Wcounterexamples
|
||||
COMPILE_FLAGS -Wcounterexamples
|
||||
)
|
||||
|
||||
add_flex_bison_dependency(generate-pslang-lexer generate-pslang-parser)
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ let { return bp::make_let(ctx.location); }
|
|||
mut { return bp::make_mut(ctx.location); }
|
||||
if { return bp::make_if(ctx.location); }
|
||||
else { return bp::make_else(ctx.location); }
|
||||
then { return bp::make_then(ctx.location); }
|
||||
while { return bp::make_while(ctx.location); }
|
||||
break { return bp::make_break(ctx.location); }
|
||||
continue { return bp::make_continue(ctx.location); }
|
||||
|
|
|
|||
|
|
@ -139,6 +139,7 @@ template <typename T>
|
|||
%token mut
|
||||
%token if
|
||||
%token else
|
||||
%token then
|
||||
%token while
|
||||
%token break
|
||||
%token continue
|
||||
|
|
@ -177,6 +178,7 @@ template <typename T>
|
|||
%left asterisk slash percent
|
||||
%precedence NOT
|
||||
%precedence ADDRESSOF DEREFERENCE
|
||||
%precedence else
|
||||
%precedence lbracket
|
||||
|
||||
%type <indented_statement_list> indented_statement_list
|
||||
|
|
@ -352,6 +354,7 @@ expression
|
|||
| ampersand mut expression %prec ADDRESSOF { $$ = ast::unary_operation{ast::unary_operation_type::mutable_address_of, std::make_unique<ast::expression>($3), @$ }; }
|
||||
| asterisk expression %prec DEREFERENCE { $$ = ast::unary_operation{ast::unary_operation_type::dereference, std::make_unique<ast::expression>($2), @$ }; }
|
||||
| postfix_expression { $$ = $1; }
|
||||
| if expression then expression else expression { $$ = ast::if_expression{ std::make_unique<ast::expression>($2), std::make_unique<ast::expression>($4), std::make_unique<ast::expression>($6), @$ }; }
|
||||
;
|
||||
|
||||
postfix_expression
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue