Implement proper struct field assigment in IR compiler

This commit is contained in:
Nikita Lisitsa 2026-03-29 12:20:49 +03:00
parent 79101ff3bd
commit 46c33b2acf
3 changed files with 100 additions and 41 deletions

View file

@ -63,17 +63,21 @@ struct ray:
func intersect_plane(ray: ray, normal: vec3, value: f32) -> f32:
return (value - dot(ray.origin, normal)) / dot(ray.direction, normal)
let t = intersect_plane(ray(vec3(0.0, 0.0, 5.0), vec3(0.0, 1.0, -1.0)), vec3(0.0, 0.0, 1.0), 0.0)
mut v = vec3(1.0, 2.0, 3.0)
mut u = vec3(1.0, 2.0, 3.0)
let p = &mut u
p[4].x = 4.0
print_vec3(v)
print('\n')
print_i32(((p + 5) - p) as i32)
print('\n')
let a = vec3(5.0, 1.0, 4.0)
let n = normalized(vec3(1.0, 2.0, 3.0))
let b = add(a, mult(n, -dot(a, n)))
// Assignment syntax:
// lhs = rhs
// lhs must be an lvalue
// An lvalue can be
// * An identifier
// * A pointer dereference
// * A field of another lvalue
// * Access by index to an lvalue array
print_vec3(a)
print('\n')
print_vec3(n)
print('\n')
print_vec3(b)
print('\n')
print_f32(dot(b, n))
print('\n')

10
examples/struct_test.psl Normal file
View file

@ -0,0 +1,10 @@
struct vec2f:
x: f32
y: f32
struct body:
position: vec2f
rotation: f32
func move_x(b: body mut*, delta: f32):
(*b).position.x = (*b).position.x + delta

View file

@ -127,7 +127,7 @@ namespace pslang::ir
{
auto arg1 = apply(*node.arg1);
// Handle short-circuit operators
// Short-circuit operators
if (node.type == ast::binary_operation_type::logical_and)
{
@ -154,7 +154,7 @@ namespace pslang::ir
return arg1;
}
// Handle pointer arithmetic
// Pointer arithmetic
auto arg2 = apply(*node.arg2);
@ -406,55 +406,100 @@ namespace pslang::ir
return apply(*node);
}
node_ref apply(ast::assignment const & node)
std::optional<node_ref> apply_field_chain_assignment(ast::expression_ptr const & lhs_node, node_ref rhs, std::vector<std::size_t> path)
{
auto rhs = apply(*node.rhs);
if (std::holds_alternative<ast::identifier>(*node.lhs))
if (auto identifier = std::get_if<ast::identifier>(lhs_node.get()))
{
auto lhs = apply(*node.lhs);
mcontext.nodes->emplace_back(assignment{lhs, rhs}, ast::get_type(*node.rhs));
auto lhs = apply(*lhs_node);
mcontext.nodes->emplace_back(assignment{lhs, rhs, std::move(path)}, identifier->inferred_type);
return last();
}
if (auto array_access = std::get_if<ast::array_access>(node.lhs.get()))
{
auto array_type = get_type(*array_access->array);
if (types::is_pointer_type(*array_type))
{
apply(ast::binary_operation{ast::binary_operation_type::addition, array_access->array, array_access->index, {}, array_type});
mcontext.nodes->emplace_back(store{last(), rhs}, ast::get_type(*node.rhs));
return last();
}
}
if (auto field_access = std::get_if<ast::field_access>(node.lhs.get()))
else if (auto field_access = std::get_if<ast::field_access>(lhs_node.get()))
{
auto object_type = ast::get_type(*field_access->object);
auto struct_node = std::get_if<types::struct_type>(object_type.get())->node;
auto object = apply(field_access->object);
for (std::size_t i = 0; i < struct_node->fields.size(); ++i)
{
auto const & field = struct_node->fields[i];
if (field.name == field_access->field_name)
{
mcontext.nodes->emplace_back(assignment{object, rhs, {i}}, field.inferred_type);
return last();
path.push_back(i);
return apply_field_chain_assignment(field_access->object, rhs, std::move(path));
}
}
throw std::runtime_error("Unknown field name");
}
return std::nullopt;
}
std::optional<node_ref> apply_get_address(ast::expression_ptr const & node)
{
auto result_type = std::make_shared<types::type>(types::pointer_type{ast::get_type(*node), true});
if (auto identifier = std::get_if<ast::identifier>(node.get()))
{
auto object = apply(*node);
mcontext.nodes->emplace_back(unary_operation{ast::unary_operation_type::address_of, object}, result_type);
return last();
}
if (auto unary_operation = std::get_if<ast::unary_operation>(node.lhs.get()))
if (auto unary_operation = std::get_if<ast::unary_operation>(node.get()))
{
if (unary_operation->type == ast::unary_operation_type::dereference)
{
auto lhs = apply(*unary_operation->arg1);
mcontext.nodes->emplace_back(store{lhs, rhs}, ast::get_type(*node.rhs));
return last();
return apply(*unary_operation->arg1);
}
}
if (auto array_access = std::get_if<ast::array_access>(node.get()))
{
auto array_type = get_type(*array_access->array);
if (types::is_pointer_type(*array_type))
{
return apply(ast::binary_operation{ast::binary_operation_type::addition, array_access->array, array_access->index, {}, array_type});
}
}
if (auto field_access = std::get_if<ast::field_access>(node.get()))
{
auto object_type = ast::get_type(*field_access->object);
auto struct_node = std::get_if<types::struct_type>(object_type.get())->node;
if (auto object_ptr = apply_get_address(field_access->object))
{
for (std::size_t i = 0; i < struct_node->fields.size(); ++i)
{
auto const & field = struct_node->fields[i];
if (field.name == field_access->field_name)
{
mcontext.nodes->emplace_back(literal{ast::literal{ast::u64_literal{field.layout.offset}}},
std::make_shared<types::type>(types::primitive_type{types::u64_type{}}));
mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::addition, *object_ptr, last()},
result_type);
return last();
}
}
throw std::runtime_error("Unknown field name");
}
}
return std::nullopt;
}
node_ref apply(ast::assignment const & node)
{
auto rhs = apply(*node.rhs);
// Detect compound field access (like a.b.c = 1) - a sequence of field_access nodes
// terminating in an identifier node.
if (auto result = apply_field_chain_assignment(node.lhs, rhs, {}))
return *result;
// Otherwise, compile into explicit memory store
if (auto lhs_ptr = apply_get_address(node.lhs))
{
mcontext.nodes->emplace_back(store{*lhs_ptr, rhs}, ast::get_type(*node.rhs));
return last();
}
throw std::runtime_error("Unknown assignment left-hand side");
}