Implement ternary ifs
This commit is contained in:
parent
b3fc435dd6
commit
8ba5345daf
13 changed files with 123 additions and 27 deletions
|
|
@ -478,17 +478,16 @@ func intersect_sphere(ray: ray, object: object*) -> intersection:
|
||||||
let t2 = (- b + sqrt(D)) / (2.0 * a)
|
let t2 = (- b + sqrt(D)) / (2.0 * a)
|
||||||
if t2 < 0.0:
|
if t2 < 0.0:
|
||||||
return no_intersection
|
return no_intersection
|
||||||
mut t = t1
|
let t = if t1 < 0.0 then t2 else t1
|
||||||
if t1 < 0.0:
|
|
||||||
t = t2
|
|
||||||
let normal = normalized(add(delta, mults(ray.direction, t)))
|
let normal = normalized(add(delta, mults(ray.direction, t)))
|
||||||
return intersection(true, t, normal, &(*object).material)
|
return intersection(true, t, normal, &(*object).material)
|
||||||
|
|
||||||
func intersect_box(ray: ray, object: object*) -> intersection:
|
func intersect_box(ray: ray, object: object*) -> intersection:
|
||||||
func swap(x: f32 mut*, y: f32 mut*):
|
func sort(x: f32 mut*, y: f32 mut*):
|
||||||
let temp = *x
|
if *x > *y:
|
||||||
*x = *y
|
let temp = *x
|
||||||
*y = temp
|
*x = *y
|
||||||
|
*y = temp
|
||||||
|
|
||||||
let shape = &(*object).shape.data as box*
|
let shape = &(*object).shape.data as box*
|
||||||
let inverse_rotation = inverse((*object).rotation)
|
let inverse_rotation = inverse((*object).rotation)
|
||||||
|
|
@ -503,12 +502,9 @@ func intersect_box(ray: ray, object: object*) -> intersection:
|
||||||
mut tzmin = (- (*shape).extent.z - local_delta.z) / local_direction.z
|
mut tzmin = (- (*shape).extent.z - local_delta.z) / local_direction.z
|
||||||
mut tzmax = ( (*shape).extent.z - local_delta.z) / local_direction.z
|
mut tzmax = ( (*shape).extent.z - local_delta.z) / local_direction.z
|
||||||
|
|
||||||
if txmin > txmax:
|
sort(&mut txmin, &mut txmax)
|
||||||
swap(&mut txmin, &mut txmax)
|
sort(&mut tymin, &mut tymax)
|
||||||
if tymin > tymax:
|
sort(&mut tzmin, &mut tzmax)
|
||||||
swap(&mut tymin, &mut tymax)
|
|
||||||
if tzmin > tzmax:
|
|
||||||
swap(&mut tzmin, &mut tzmax)
|
|
||||||
|
|
||||||
let tmin = max(txmin, max(tymin, tzmin))
|
let tmin = max(txmin, max(tymin, tzmin))
|
||||||
let tmax = min(txmax, min(tymax, tzmax))
|
let tmax = min(txmax, min(tymax, tzmax))
|
||||||
|
|
@ -517,14 +513,10 @@ func intersect_box(ray: ray, object: object*) -> intersection:
|
||||||
return no_intersection
|
return no_intersection
|
||||||
|
|
||||||
let inside = tmin < 0.0
|
let inside = tmin < 0.0
|
||||||
mut t = tmin
|
let t = if inside then tmax else tmin
|
||||||
mut tt = vec3(txmin, tymin, tzmin)
|
let tt = if inside then vec3(txmax, tymax, tzmax) else vec3(txmin, tymin, tzmin)
|
||||||
// No ternary if yet...
|
|
||||||
if inside:
|
|
||||||
t = tmax
|
|
||||||
tt = vec3(txmax, tymax, tzmax)
|
|
||||||
|
|
||||||
mut normal = vec3(0.0, 0.0, 0.0)
|
mut normal = vec3()
|
||||||
|
|
||||||
if t == tt.x:
|
if t == tt.x:
|
||||||
if local_direction.x > 0.0:
|
if local_direction.x > 0.0:
|
||||||
|
|
@ -545,7 +537,7 @@ func intersect_box(ray: ray, object: object*) -> intersection:
|
||||||
if inside:
|
if inside:
|
||||||
normal = mults(normal, -1.0)
|
normal = mults(normal, -1.0)
|
||||||
|
|
||||||
return intersection(true, tmin, rotate((*object).rotation, normal), &(*object).material)
|
return intersection(true, t, rotate((*object).rotation, normal), &(*object).material)
|
||||||
|
|
||||||
func intersect_scene(scene: scene*, ray: ray) -> intersection:
|
func intersect_scene(scene: scene*, ray: ray) -> intersection:
|
||||||
mut intersection = no_intersection
|
mut intersection = no_intersection
|
||||||
|
|
@ -776,8 +768,8 @@ func main():
|
||||||
let scene = make_default_scene()
|
let scene = make_default_scene()
|
||||||
|
|
||||||
// let image = create_image(512ul, 512ul)
|
// let image = create_image(512ul, 512ul)
|
||||||
let image = create_image(256ul, 256ul)
|
// let image = create_image(256ul, 256ul)
|
||||||
// let image = create_image(128ul, 128ul)
|
let image = create_image(128ul, 128ul)
|
||||||
|
|
||||||
let aspect_ratio = (image.width as f32) / (image.height as f32)
|
let aspect_ratio = (image.width as f32) / (image.height as f32)
|
||||||
let camera = camera(vec3(0.0, 0.0, 15.0), identity, default_fovy, compute_fovx(default_fovy, aspect_ratio))
|
let camera = camera(vec3(0.0, 0.0, 15.0), identity, default_fovy, compute_fovx(default_fovy, aspect_ratio))
|
||||||
|
|
|
||||||
|
|
@ -3,10 +3,20 @@
|
||||||
#include <pslang/ast/expression_fwd.hpp>
|
#include <pslang/ast/expression_fwd.hpp>
|
||||||
#include <pslang/ast/statement_fwd.hpp>
|
#include <pslang/ast/statement_fwd.hpp>
|
||||||
#include <pslang/ast/location.hpp>
|
#include <pslang/ast/location.hpp>
|
||||||
|
#include <pslang/types/type_fwd.hpp>
|
||||||
|
|
||||||
namespace pslang::ast
|
namespace pslang::ast
|
||||||
{
|
{
|
||||||
|
|
||||||
|
struct if_expression
|
||||||
|
{
|
||||||
|
expression_ptr condition;
|
||||||
|
expression_ptr if_true;
|
||||||
|
expression_ptr if_false;
|
||||||
|
ast::location location;
|
||||||
|
types::type_ptr inferred_type = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
struct if_block
|
struct if_block
|
||||||
{
|
{
|
||||||
expression_ptr condition;
|
expression_ptr condition;
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@
|
||||||
#include <pslang/ast/function.hpp>
|
#include <pslang/ast/function.hpp>
|
||||||
#include <pslang/ast/array.hpp>
|
#include <pslang/ast/array.hpp>
|
||||||
#include <pslang/ast/struct.hpp>
|
#include <pslang/ast/struct.hpp>
|
||||||
|
#include <pslang/ast/control.hpp>
|
||||||
#include <pslang/ast/expression_fwd.hpp>
|
#include <pslang/ast/expression_fwd.hpp>
|
||||||
#include <pslang/types/type_fwd.hpp>
|
#include <pslang/types/type_fwd.hpp>
|
||||||
|
|
||||||
|
|
@ -40,7 +41,8 @@ namespace pslang::ast
|
||||||
function_call,
|
function_call,
|
||||||
array,
|
array,
|
||||||
array_access,
|
array_access,
|
||||||
field_access
|
field_access,
|
||||||
|
if_expression
|
||||||
>;
|
>;
|
||||||
|
|
||||||
struct expression
|
struct expression
|
||||||
|
|
|
||||||
|
|
@ -58,6 +58,11 @@ namespace pslang::ast
|
||||||
{
|
{
|
||||||
return node.location;
|
return node.location;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
location apply(if_expression const & node)
|
||||||
|
{
|
||||||
|
return node.location;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct get_type_visitor
|
struct get_type_visitor
|
||||||
|
|
@ -110,6 +115,11 @@ namespace pslang::ast
|
||||||
{
|
{
|
||||||
return node.inferred_type;
|
return node.inferred_type;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
types::type_ptr apply(if_expression const & node)
|
||||||
|
{
|
||||||
|
return node.inferred_type;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -358,6 +358,15 @@ namespace pslang::ast
|
||||||
out << "field access { name = \"" << node.field_name << "\" }\n";
|
out << "field access { name = \"" << node.field_name << "\" }\n";
|
||||||
child(*node.object);
|
child(*node.object);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void apply(if_expression const & node)
|
||||||
|
{
|
||||||
|
put_indent(out, options);
|
||||||
|
out << "if\n";
|
||||||
|
child(*node.condition);
|
||||||
|
child(*node.if_true);
|
||||||
|
child(*node.if_false);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct statement_print_visitor
|
struct statement_print_visitor
|
||||||
|
|
|
||||||
|
|
@ -281,6 +281,13 @@ namespace pslang::ast
|
||||||
apply(*field_access.object);
|
apply(*field_access.object);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void apply(if_expression const & if_expression)
|
||||||
|
{
|
||||||
|
apply(*if_expression.condition);
|
||||||
|
apply(*if_expression.if_true);
|
||||||
|
apply(*if_expression.if_false);
|
||||||
|
}
|
||||||
|
|
||||||
void apply(expression_ptr const & expression_ptr)
|
void apply(expression_ptr const & expression_ptr)
|
||||||
{
|
{
|
||||||
apply(*expression_ptr);
|
apply(*expression_ptr);
|
||||||
|
|
|
||||||
|
|
@ -783,6 +783,37 @@ namespace pslang::ast
|
||||||
throw type_error(std::format("Struct \"{}\" has no field named \"{}\"", struct_node.name, node.field_name), node.location);
|
throw type_error(std::format("Struct \"{}\" has no field named \"{}\"", struct_node.name, node.field_name), node.location);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void apply(if_expression & node)
|
||||||
|
{
|
||||||
|
apply(*node.condition);
|
||||||
|
apply(*node.if_true);
|
||||||
|
apply(*node.if_false);
|
||||||
|
|
||||||
|
auto condition_type = get_type(*node.condition);
|
||||||
|
if (!types::is_bool_type(*condition_type))
|
||||||
|
{
|
||||||
|
std::ostringstream os;
|
||||||
|
os << "if condition expects a bool type, but got ";
|
||||||
|
print(os, *condition_type);
|
||||||
|
throw type_error(os.str(), get_location(*node.condition));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto true_type = get_type(*node.if_true);
|
||||||
|
auto false_type = get_type(*node.if_false);
|
||||||
|
|
||||||
|
if (!types::equal(*true_type, *false_type))
|
||||||
|
{
|
||||||
|
std::ostringstream os;
|
||||||
|
os << "Both ternary if cases must have the same type, but got ";
|
||||||
|
print(os, *true_type);
|
||||||
|
os << " and ";
|
||||||
|
print(os, *false_type);
|
||||||
|
throw type_error(os.str(), node.location);
|
||||||
|
}
|
||||||
|
|
||||||
|
node.inferred_type = true_type;
|
||||||
|
}
|
||||||
|
|
||||||
void apply(expression_ptr const & node)
|
void apply(expression_ptr const & node)
|
||||||
{
|
{
|
||||||
apply(*node);
|
apply(*node);
|
||||||
|
|
|
||||||
|
|
@ -686,6 +686,11 @@ namespace pslang::interpreter
|
||||||
throw internal_error(os.str(), field_access.location);
|
throw internal_error(os.str(), field_access.location);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
value eval_impl(context & context, ast::if_expression const & if_expression)
|
||||||
|
{
|
||||||
|
throw std::runtime_error("Not implemented");
|
||||||
|
}
|
||||||
|
|
||||||
value eval_impl(context & context, ast::expression_ptr const & expression)
|
value eval_impl(context & context, ast::expression_ptr const & expression)
|
||||||
{
|
{
|
||||||
return std::visit([&](auto const & expression){ return eval_impl(context, expression); }, *expression);
|
return std::visit([&](auto const & expression){ return eval_impl(context, expression); }, *expression);
|
||||||
|
|
@ -774,6 +779,11 @@ namespace pslang::interpreter
|
||||||
throw internal_error(os.str(), field_access.location);
|
throw internal_error(os.str(), field_access.location);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
value * eval_ref_impl(context & context, ast::if_expression const & if_expression)
|
||||||
|
{
|
||||||
|
throw std::runtime_error("Not implemented");
|
||||||
|
}
|
||||||
|
|
||||||
value * eval_ref_impl(context & context, ast::expression_ptr const & expression)
|
value * eval_ref_impl(context & context, ast::expression_ptr const & expression)
|
||||||
{
|
{
|
||||||
return std::visit([&](auto const & expression){ return eval_ref_impl(context, expression); }, *expression);
|
return std::visit([&](auto const & expression){ return eval_ref_impl(context, expression); }, *expression);
|
||||||
|
|
|
||||||
|
|
@ -473,7 +473,28 @@ namespace pslang::ir
|
||||||
throw std::runtime_error("Unknown field name");
|
throw std::runtime_error("Unknown field name");
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: array, array_access, field_access
|
node_ref apply(ast::if_expression const & node)
|
||||||
|
{
|
||||||
|
mcontext.nodes->emplace_back(alloc{}, node.inferred_type);
|
||||||
|
auto result = last();
|
||||||
|
auto condition = apply(*node.condition);
|
||||||
|
mcontext.nodes->emplace_back(jump_if_zero{condition});
|
||||||
|
auto jump_skip_if_true = last();
|
||||||
|
auto true_result = apply(*node.if_true);
|
||||||
|
mcontext.nodes->emplace_back(assignment{result, true_result}, node.inferred_type);
|
||||||
|
mcontext.nodes->emplace_back(jump{});
|
||||||
|
auto jump_end_if_true = last();
|
||||||
|
mcontext.nodes->emplace_back(label{});
|
||||||
|
auto after_if_true = last();
|
||||||
|
auto false_result = apply(*node.if_false);
|
||||||
|
mcontext.nodes->emplace_back(assignment{result, false_result}, node.inferred_type);
|
||||||
|
mcontext.nodes->emplace_back(label{});
|
||||||
|
auto end = last();
|
||||||
|
|
||||||
|
std::get<jump_if_zero>(jump_skip_if_true->instruction).target = after_if_true;
|
||||||
|
std::get<jump>(jump_end_if_true->instruction).target = end;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
// Statements
|
// Statements
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -312,7 +312,7 @@ namespace pslang::jit::aarch64
|
||||||
|
|
||||||
void apply(ir::node_ref it, ir::alloc const & node, types::type_ptr const &)
|
void apply(ir::node_ref it, ir::alloc const & node, types::type_ptr const &)
|
||||||
{
|
{
|
||||||
// Nothing to do: alloc just allocates a node of a struct type,
|
// Nothing to do: alloc just allocates a node of some type,
|
||||||
// but we already allocated stack space for it
|
// but we already allocated stack space for it
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ bison_target(
|
||||||
${PSLANG_PARSER_RULES_FILE}
|
${PSLANG_PARSER_RULES_FILE}
|
||||||
${PSLANG_PARSER_SOURCE_FILE}
|
${PSLANG_PARSER_SOURCE_FILE}
|
||||||
DEFINES_FILE ${PSLANG_PARSER_HEADER_FILE}
|
DEFINES_FILE ${PSLANG_PARSER_HEADER_FILE}
|
||||||
# COMPILE_FLAGS -Wcounterexamples
|
COMPILE_FLAGS -Wcounterexamples
|
||||||
)
|
)
|
||||||
|
|
||||||
add_flex_bison_dependency(generate-pslang-lexer generate-pslang-parser)
|
add_flex_bison_dependency(generate-pslang-lexer generate-pslang-parser)
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,7 @@ let { return bp::make_let(ctx.location); }
|
||||||
mut { return bp::make_mut(ctx.location); }
|
mut { return bp::make_mut(ctx.location); }
|
||||||
if { return bp::make_if(ctx.location); }
|
if { return bp::make_if(ctx.location); }
|
||||||
else { return bp::make_else(ctx.location); }
|
else { return bp::make_else(ctx.location); }
|
||||||
|
then { return bp::make_then(ctx.location); }
|
||||||
while { return bp::make_while(ctx.location); }
|
while { return bp::make_while(ctx.location); }
|
||||||
break { return bp::make_break(ctx.location); }
|
break { return bp::make_break(ctx.location); }
|
||||||
continue { return bp::make_continue(ctx.location); }
|
continue { return bp::make_continue(ctx.location); }
|
||||||
|
|
|
||||||
|
|
@ -139,6 +139,7 @@ template <typename T>
|
||||||
%token mut
|
%token mut
|
||||||
%token if
|
%token if
|
||||||
%token else
|
%token else
|
||||||
|
%token then
|
||||||
%token while
|
%token while
|
||||||
%token break
|
%token break
|
||||||
%token continue
|
%token continue
|
||||||
|
|
@ -177,6 +178,7 @@ template <typename T>
|
||||||
%left asterisk slash percent
|
%left asterisk slash percent
|
||||||
%precedence NOT
|
%precedence NOT
|
||||||
%precedence ADDRESSOF DEREFERENCE
|
%precedence ADDRESSOF DEREFERENCE
|
||||||
|
%precedence else
|
||||||
%precedence lbracket
|
%precedence lbracket
|
||||||
|
|
||||||
%type <indented_statement_list> indented_statement_list
|
%type <indented_statement_list> indented_statement_list
|
||||||
|
|
@ -352,6 +354,7 @@ expression
|
||||||
| ampersand mut expression %prec ADDRESSOF { $$ = ast::unary_operation{ast::unary_operation_type::mutable_address_of, std::make_unique<ast::expression>($3), @$ }; }
|
| ampersand mut expression %prec ADDRESSOF { $$ = ast::unary_operation{ast::unary_operation_type::mutable_address_of, std::make_unique<ast::expression>($3), @$ }; }
|
||||||
| asterisk expression %prec DEREFERENCE { $$ = ast::unary_operation{ast::unary_operation_type::dereference, std::make_unique<ast::expression>($2), @$ }; }
|
| asterisk expression %prec DEREFERENCE { $$ = ast::unary_operation{ast::unary_operation_type::dereference, std::make_unique<ast::expression>($2), @$ }; }
|
||||||
| postfix_expression { $$ = $1; }
|
| postfix_expression { $$ = $1; }
|
||||||
|
| if expression then expression else expression { $$ = ast::if_expression{ std::make_unique<ast::expression>($2), std::make_unique<ast::expression>($4), std::make_unique<ast::expression>($6), @$ }; }
|
||||||
;
|
;
|
||||||
|
|
||||||
postfix_expression
|
postfix_expression
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue