Implement proper struct field assigment in IR compiler
This commit is contained in:
parent
79101ff3bd
commit
46c33b2acf
3 changed files with 100 additions and 41 deletions
|
|
@ -63,17 +63,21 @@ struct ray:
|
||||||
func intersect_plane(ray: ray, normal: vec3, value: f32) -> f32:
|
func intersect_plane(ray: ray, normal: vec3, value: f32) -> f32:
|
||||||
return (value - dot(ray.origin, normal)) / dot(ray.direction, normal)
|
return (value - dot(ray.origin, normal)) / dot(ray.direction, normal)
|
||||||
|
|
||||||
let t = intersect_plane(ray(vec3(0.0, 0.0, 5.0), vec3(0.0, 1.0, -1.0)), vec3(0.0, 0.0, 1.0), 0.0)
|
mut v = vec3(1.0, 2.0, 3.0)
|
||||||
|
mut u = vec3(1.0, 2.0, 3.0)
|
||||||
|
let p = &mut u
|
||||||
|
p[4].x = 4.0
|
||||||
|
print_vec3(v)
|
||||||
|
print('\n')
|
||||||
|
print_i32(((p + 5) - p) as i32)
|
||||||
|
print('\n')
|
||||||
|
|
||||||
let a = vec3(5.0, 1.0, 4.0)
|
// Assignment syntax:
|
||||||
let n = normalized(vec3(1.0, 2.0, 3.0))
|
// lhs = rhs
|
||||||
let b = add(a, mult(n, -dot(a, n)))
|
// lhs must be an lvalue
|
||||||
|
// An lvalue can be
|
||||||
|
// * An identifier
|
||||||
|
// * A pointer dereference
|
||||||
|
// * A field of another lvalue
|
||||||
|
// * Access by index to an lvalue array
|
||||||
|
|
||||||
print_vec3(a)
|
|
||||||
print('\n')
|
|
||||||
print_vec3(n)
|
|
||||||
print('\n')
|
|
||||||
print_vec3(b)
|
|
||||||
print('\n')
|
|
||||||
print_f32(dot(b, n))
|
|
||||||
print('\n')
|
|
||||||
|
|
|
||||||
10
examples/struct_test.psl
Normal file
10
examples/struct_test.psl
Normal file
|
|
@ -0,0 +1,10 @@
|
||||||
|
struct vec2f:
|
||||||
|
x: f32
|
||||||
|
y: f32
|
||||||
|
|
||||||
|
struct body:
|
||||||
|
position: vec2f
|
||||||
|
rotation: f32
|
||||||
|
|
||||||
|
func move_x(b: body mut*, delta: f32):
|
||||||
|
(*b).position.x = (*b).position.x + delta
|
||||||
|
|
@ -127,7 +127,7 @@ namespace pslang::ir
|
||||||
{
|
{
|
||||||
auto arg1 = apply(*node.arg1);
|
auto arg1 = apply(*node.arg1);
|
||||||
|
|
||||||
// Handle short-circuit operators
|
// Short-circuit operators
|
||||||
|
|
||||||
if (node.type == ast::binary_operation_type::logical_and)
|
if (node.type == ast::binary_operation_type::logical_and)
|
||||||
{
|
{
|
||||||
|
|
@ -154,7 +154,7 @@ namespace pslang::ir
|
||||||
return arg1;
|
return arg1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle pointer arithmetic
|
// Pointer arithmetic
|
||||||
|
|
||||||
auto arg2 = apply(*node.arg2);
|
auto arg2 = apply(*node.arg2);
|
||||||
|
|
||||||
|
|
@ -406,55 +406,100 @@ namespace pslang::ir
|
||||||
return apply(*node);
|
return apply(*node);
|
||||||
}
|
}
|
||||||
|
|
||||||
node_ref apply(ast::assignment const & node)
|
std::optional<node_ref> apply_field_chain_assignment(ast::expression_ptr const & lhs_node, node_ref rhs, std::vector<std::size_t> path)
|
||||||
{
|
{
|
||||||
auto rhs = apply(*node.rhs);
|
if (auto identifier = std::get_if<ast::identifier>(lhs_node.get()))
|
||||||
|
|
||||||
if (std::holds_alternative<ast::identifier>(*node.lhs))
|
|
||||||
{
|
{
|
||||||
auto lhs = apply(*node.lhs);
|
auto lhs = apply(*lhs_node);
|
||||||
mcontext.nodes->emplace_back(assignment{lhs, rhs}, ast::get_type(*node.rhs));
|
mcontext.nodes->emplace_back(assignment{lhs, rhs, std::move(path)}, identifier->inferred_type);
|
||||||
return last();
|
return last();
|
||||||
}
|
}
|
||||||
|
else if (auto field_access = std::get_if<ast::field_access>(lhs_node.get()))
|
||||||
if (auto array_access = std::get_if<ast::array_access>(node.lhs.get()))
|
|
||||||
{
|
|
||||||
auto array_type = get_type(*array_access->array);
|
|
||||||
if (types::is_pointer_type(*array_type))
|
|
||||||
{
|
|
||||||
apply(ast::binary_operation{ast::binary_operation_type::addition, array_access->array, array_access->index, {}, array_type});
|
|
||||||
mcontext.nodes->emplace_back(store{last(), rhs}, ast::get_type(*node.rhs));
|
|
||||||
return last();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto field_access = std::get_if<ast::field_access>(node.lhs.get()))
|
|
||||||
{
|
{
|
||||||
auto object_type = ast::get_type(*field_access->object);
|
auto object_type = ast::get_type(*field_access->object);
|
||||||
auto struct_node = std::get_if<types::struct_type>(object_type.get())->node;
|
auto struct_node = std::get_if<types::struct_type>(object_type.get())->node;
|
||||||
auto object = apply(field_access->object);
|
|
||||||
for (std::size_t i = 0; i < struct_node->fields.size(); ++i)
|
for (std::size_t i = 0; i < struct_node->fields.size(); ++i)
|
||||||
{
|
{
|
||||||
auto const & field = struct_node->fields[i];
|
auto const & field = struct_node->fields[i];
|
||||||
if (field.name == field_access->field_name)
|
if (field.name == field_access->field_name)
|
||||||
{
|
{
|
||||||
mcontext.nodes->emplace_back(assignment{object, rhs, {i}}, field.inferred_type);
|
path.push_back(i);
|
||||||
return last();
|
return apply_field_chain_assignment(field_access->object, rhs, std::move(path));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
throw std::runtime_error("Unknown field name");
|
}
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<node_ref> apply_get_address(ast::expression_ptr const & node)
|
||||||
|
{
|
||||||
|
auto result_type = std::make_shared<types::type>(types::pointer_type{ast::get_type(*node), true});
|
||||||
|
|
||||||
|
if (auto identifier = std::get_if<ast::identifier>(node.get()))
|
||||||
|
{
|
||||||
|
auto object = apply(*node);
|
||||||
|
mcontext.nodes->emplace_back(unary_operation{ast::unary_operation_type::address_of, object}, result_type);
|
||||||
|
return last();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto unary_operation = std::get_if<ast::unary_operation>(node.lhs.get()))
|
if (auto unary_operation = std::get_if<ast::unary_operation>(node.get()))
|
||||||
{
|
{
|
||||||
if (unary_operation->type == ast::unary_operation_type::dereference)
|
if (unary_operation->type == ast::unary_operation_type::dereference)
|
||||||
{
|
{
|
||||||
auto lhs = apply(*unary_operation->arg1);
|
return apply(*unary_operation->arg1);
|
||||||
mcontext.nodes->emplace_back(store{lhs, rhs}, ast::get_type(*node.rhs));
|
|
||||||
return last();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (auto array_access = std::get_if<ast::array_access>(node.get()))
|
||||||
|
{
|
||||||
|
auto array_type = get_type(*array_access->array);
|
||||||
|
if (types::is_pointer_type(*array_type))
|
||||||
|
{
|
||||||
|
return apply(ast::binary_operation{ast::binary_operation_type::addition, array_access->array, array_access->index, {}, array_type});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto field_access = std::get_if<ast::field_access>(node.get()))
|
||||||
|
{
|
||||||
|
auto object_type = ast::get_type(*field_access->object);
|
||||||
|
auto struct_node = std::get_if<types::struct_type>(object_type.get())->node;
|
||||||
|
if (auto object_ptr = apply_get_address(field_access->object))
|
||||||
|
{
|
||||||
|
for (std::size_t i = 0; i < struct_node->fields.size(); ++i)
|
||||||
|
{
|
||||||
|
auto const & field = struct_node->fields[i];
|
||||||
|
if (field.name == field_access->field_name)
|
||||||
|
{
|
||||||
|
mcontext.nodes->emplace_back(literal{ast::literal{ast::u64_literal{field.layout.offset}}},
|
||||||
|
std::make_shared<types::type>(types::primitive_type{types::u64_type{}}));
|
||||||
|
mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::addition, *object_ptr, last()},
|
||||||
|
result_type);
|
||||||
|
return last();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
throw std::runtime_error("Unknown field name");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
node_ref apply(ast::assignment const & node)
|
||||||
|
{
|
||||||
|
auto rhs = apply(*node.rhs);
|
||||||
|
|
||||||
|
// Detect compound field access (like a.b.c = 1) - a sequence of field_access nodes
|
||||||
|
// terminating in an identifier node.
|
||||||
|
if (auto result = apply_field_chain_assignment(node.lhs, rhs, {}))
|
||||||
|
return *result;
|
||||||
|
|
||||||
|
// Otherwise, compile into explicit memory store
|
||||||
|
if (auto lhs_ptr = apply_get_address(node.lhs))
|
||||||
|
{
|
||||||
|
mcontext.nodes->emplace_back(store{*lhs_ptr, rhs}, ast::get_type(*node.rhs));
|
||||||
|
return last();
|
||||||
|
}
|
||||||
|
|
||||||
throw std::runtime_error("Unknown assignment left-hand side");
|
throw std::runtime_error("Unknown assignment left-hand side");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue