Implement proper struct field assigment in IR compiler
This commit is contained in:
parent
79101ff3bd
commit
46c33b2acf
3 changed files with 100 additions and 41 deletions
|
|
@ -63,17 +63,21 @@ struct ray:
|
|||
func intersect_plane(ray: ray, normal: vec3, value: f32) -> f32:
|
||||
return (value - dot(ray.origin, normal)) / dot(ray.direction, normal)
|
||||
|
||||
let t = intersect_plane(ray(vec3(0.0, 0.0, 5.0), vec3(0.0, 1.0, -1.0)), vec3(0.0, 0.0, 1.0), 0.0)
|
||||
mut v = vec3(1.0, 2.0, 3.0)
|
||||
mut u = vec3(1.0, 2.0, 3.0)
|
||||
let p = &mut u
|
||||
p[4].x = 4.0
|
||||
print_vec3(v)
|
||||
print('\n')
|
||||
print_i32(((p + 5) - p) as i32)
|
||||
print('\n')
|
||||
|
||||
let a = vec3(5.0, 1.0, 4.0)
|
||||
let n = normalized(vec3(1.0, 2.0, 3.0))
|
||||
let b = add(a, mult(n, -dot(a, n)))
|
||||
// Assignment syntax:
|
||||
// lhs = rhs
|
||||
// lhs must be an lvalue
|
||||
// An lvalue can be
|
||||
// * An identifier
|
||||
// * A pointer dereference
|
||||
// * A field of another lvalue
|
||||
// * Access by index to an lvalue array
|
||||
|
||||
print_vec3(a)
|
||||
print('\n')
|
||||
print_vec3(n)
|
||||
print('\n')
|
||||
print_vec3(b)
|
||||
print('\n')
|
||||
print_f32(dot(b, n))
|
||||
print('\n')
|
||||
|
|
|
|||
10
examples/struct_test.psl
Normal file
10
examples/struct_test.psl
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
struct vec2f:
|
||||
x: f32
|
||||
y: f32
|
||||
|
||||
struct body:
|
||||
position: vec2f
|
||||
rotation: f32
|
||||
|
||||
func move_x(b: body mut*, delta: f32):
|
||||
(*b).position.x = (*b).position.x + delta
|
||||
|
|
@ -127,7 +127,7 @@ namespace pslang::ir
|
|||
{
|
||||
auto arg1 = apply(*node.arg1);
|
||||
|
||||
// Handle short-circuit operators
|
||||
// Short-circuit operators
|
||||
|
||||
if (node.type == ast::binary_operation_type::logical_and)
|
||||
{
|
||||
|
|
@ -154,7 +154,7 @@ namespace pslang::ir
|
|||
return arg1;
|
||||
}
|
||||
|
||||
// Handle pointer arithmetic
|
||||
// Pointer arithmetic
|
||||
|
||||
auto arg2 = apply(*node.arg2);
|
||||
|
||||
|
|
@ -406,55 +406,100 @@ namespace pslang::ir
|
|||
return apply(*node);
|
||||
}
|
||||
|
||||
node_ref apply(ast::assignment const & node)
|
||||
std::optional<node_ref> apply_field_chain_assignment(ast::expression_ptr const & lhs_node, node_ref rhs, std::vector<std::size_t> path)
|
||||
{
|
||||
auto rhs = apply(*node.rhs);
|
||||
|
||||
if (std::holds_alternative<ast::identifier>(*node.lhs))
|
||||
if (auto identifier = std::get_if<ast::identifier>(lhs_node.get()))
|
||||
{
|
||||
auto lhs = apply(*node.lhs);
|
||||
mcontext.nodes->emplace_back(assignment{lhs, rhs}, ast::get_type(*node.rhs));
|
||||
auto lhs = apply(*lhs_node);
|
||||
mcontext.nodes->emplace_back(assignment{lhs, rhs, std::move(path)}, identifier->inferred_type);
|
||||
return last();
|
||||
}
|
||||
|
||||
if (auto array_access = std::get_if<ast::array_access>(node.lhs.get()))
|
||||
{
|
||||
auto array_type = get_type(*array_access->array);
|
||||
if (types::is_pointer_type(*array_type))
|
||||
{
|
||||
apply(ast::binary_operation{ast::binary_operation_type::addition, array_access->array, array_access->index, {}, array_type});
|
||||
mcontext.nodes->emplace_back(store{last(), rhs}, ast::get_type(*node.rhs));
|
||||
return last();
|
||||
}
|
||||
}
|
||||
|
||||
if (auto field_access = std::get_if<ast::field_access>(node.lhs.get()))
|
||||
else if (auto field_access = std::get_if<ast::field_access>(lhs_node.get()))
|
||||
{
|
||||
auto object_type = ast::get_type(*field_access->object);
|
||||
auto struct_node = std::get_if<types::struct_type>(object_type.get())->node;
|
||||
auto object = apply(field_access->object);
|
||||
for (std::size_t i = 0; i < struct_node->fields.size(); ++i)
|
||||
{
|
||||
auto const & field = struct_node->fields[i];
|
||||
if (field.name == field_access->field_name)
|
||||
{
|
||||
mcontext.nodes->emplace_back(assignment{object, rhs, {i}}, field.inferred_type);
|
||||
return last();
|
||||
path.push_back(i);
|
||||
return apply_field_chain_assignment(field_access->object, rhs, std::move(path));
|
||||
}
|
||||
}
|
||||
throw std::runtime_error("Unknown field name");
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
std::optional<node_ref> apply_get_address(ast::expression_ptr const & node)
|
||||
{
|
||||
auto result_type = std::make_shared<types::type>(types::pointer_type{ast::get_type(*node), true});
|
||||
|
||||
if (auto identifier = std::get_if<ast::identifier>(node.get()))
|
||||
{
|
||||
auto object = apply(*node);
|
||||
mcontext.nodes->emplace_back(unary_operation{ast::unary_operation_type::address_of, object}, result_type);
|
||||
return last();
|
||||
}
|
||||
|
||||
if (auto unary_operation = std::get_if<ast::unary_operation>(node.lhs.get()))
|
||||
if (auto unary_operation = std::get_if<ast::unary_operation>(node.get()))
|
||||
{
|
||||
if (unary_operation->type == ast::unary_operation_type::dereference)
|
||||
{
|
||||
auto lhs = apply(*unary_operation->arg1);
|
||||
mcontext.nodes->emplace_back(store{lhs, rhs}, ast::get_type(*node.rhs));
|
||||
return last();
|
||||
return apply(*unary_operation->arg1);
|
||||
}
|
||||
}
|
||||
|
||||
if (auto array_access = std::get_if<ast::array_access>(node.get()))
|
||||
{
|
||||
auto array_type = get_type(*array_access->array);
|
||||
if (types::is_pointer_type(*array_type))
|
||||
{
|
||||
return apply(ast::binary_operation{ast::binary_operation_type::addition, array_access->array, array_access->index, {}, array_type});
|
||||
}
|
||||
}
|
||||
|
||||
if (auto field_access = std::get_if<ast::field_access>(node.get()))
|
||||
{
|
||||
auto object_type = ast::get_type(*field_access->object);
|
||||
auto struct_node = std::get_if<types::struct_type>(object_type.get())->node;
|
||||
if (auto object_ptr = apply_get_address(field_access->object))
|
||||
{
|
||||
for (std::size_t i = 0; i < struct_node->fields.size(); ++i)
|
||||
{
|
||||
auto const & field = struct_node->fields[i];
|
||||
if (field.name == field_access->field_name)
|
||||
{
|
||||
mcontext.nodes->emplace_back(literal{ast::literal{ast::u64_literal{field.layout.offset}}},
|
||||
std::make_shared<types::type>(types::primitive_type{types::u64_type{}}));
|
||||
mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::addition, *object_ptr, last()},
|
||||
result_type);
|
||||
return last();
|
||||
}
|
||||
}
|
||||
throw std::runtime_error("Unknown field name");
|
||||
}
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
node_ref apply(ast::assignment const & node)
|
||||
{
|
||||
auto rhs = apply(*node.rhs);
|
||||
|
||||
// Detect compound field access (like a.b.c = 1) - a sequence of field_access nodes
|
||||
// terminating in an identifier node.
|
||||
if (auto result = apply_field_chain_assignment(node.lhs, rhs, {}))
|
||||
return *result;
|
||||
|
||||
// Otherwise, compile into explicit memory store
|
||||
if (auto lhs_ptr = apply_get_address(node.lhs))
|
||||
{
|
||||
mcontext.nodes->emplace_back(store{*lhs_ptr, rhs}, ast::get_type(*node.rhs));
|
||||
return last();
|
||||
}
|
||||
|
||||
throw std::runtime_error("Unknown assignment left-hand side");
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue