diff --git a/examples/ir_test.psl b/examples/ir_test.psl index bf2efc7..c65b2dc 100644 --- a/examples/ir_test.psl +++ b/examples/ir_test.psl @@ -63,17 +63,21 @@ struct ray: func intersect_plane(ray: ray, normal: vec3, value: f32) -> f32: return (value - dot(ray.origin, normal)) / dot(ray.direction, normal) -let t = intersect_plane(ray(vec3(0.0, 0.0, 5.0), vec3(0.0, 1.0, -1.0)), vec3(0.0, 0.0, 1.0), 0.0) +mut v = vec3(1.0, 2.0, 3.0) +mut u = vec3(1.0, 2.0, 3.0) +let p = &mut u +p[4].x = 4.0 +print_vec3(v) +print('\n') +print_i32(((p + 5) - p) as i32) +print('\n') -let a = vec3(5.0, 1.0, 4.0) -let n = normalized(vec3(1.0, 2.0, 3.0)) -let b = add(a, mult(n, -dot(a, n))) +// Assignment syntax: +// lhs = rhs +// lhs must be an lvalue +// An lvalue can be +// * An identifier +// * A pointer dereference +// * A field of another lvalue +// * Access by index to an lvalue array -print_vec3(a) -print('\n') -print_vec3(n) -print('\n') -print_vec3(b) -print('\n') -print_f32(dot(b, n)) -print('\n') diff --git a/examples/struct_test.psl b/examples/struct_test.psl new file mode 100644 index 0000000..21fd512 --- /dev/null +++ b/examples/struct_test.psl @@ -0,0 +1,10 @@ +struct vec2f: + x: f32 + y: f32 + +struct body: + position: vec2f + rotation: f32 + +func move_x(b: body mut*, delta: f32): + (*b).position.x = (*b).position.x + delta diff --git a/libs/ir/source/compiler.cpp b/libs/ir/source/compiler.cpp index f23a454..edd2fee 100644 --- a/libs/ir/source/compiler.cpp +++ b/libs/ir/source/compiler.cpp @@ -127,7 +127,7 @@ namespace pslang::ir { auto arg1 = apply(*node.arg1); - // Handle short-circuit operators + // Short-circuit operators if (node.type == ast::binary_operation_type::logical_and) { @@ -154,7 +154,7 @@ namespace pslang::ir return arg1; } - // Handle pointer arithmetic + // Pointer arithmetic auto arg2 = apply(*node.arg2); @@ -406,55 +406,100 @@ namespace pslang::ir return apply(*node); } - node_ref apply(ast::assignment const & node) + std::optional apply_field_chain_assignment(ast::expression_ptr const & lhs_node, node_ref rhs, std::vector path) { - auto rhs = apply(*node.rhs); - - if (std::holds_alternative(*node.lhs)) + if (auto identifier = std::get_if(lhs_node.get())) { - auto lhs = apply(*node.lhs); - mcontext.nodes->emplace_back(assignment{lhs, rhs}, ast::get_type(*node.rhs)); + auto lhs = apply(*lhs_node); + mcontext.nodes->emplace_back(assignment{lhs, rhs, std::move(path)}, identifier->inferred_type); return last(); } - - if (auto array_access = std::get_if(node.lhs.get())) - { - auto array_type = get_type(*array_access->array); - if (types::is_pointer_type(*array_type)) - { - apply(ast::binary_operation{ast::binary_operation_type::addition, array_access->array, array_access->index, {}, array_type}); - mcontext.nodes->emplace_back(store{last(), rhs}, ast::get_type(*node.rhs)); - return last(); - } - } - - if (auto field_access = std::get_if(node.lhs.get())) + else if (auto field_access = std::get_if(lhs_node.get())) { auto object_type = ast::get_type(*field_access->object); auto struct_node = std::get_if(object_type.get())->node; - auto object = apply(field_access->object); for (std::size_t i = 0; i < struct_node->fields.size(); ++i) { auto const & field = struct_node->fields[i]; if (field.name == field_access->field_name) { - mcontext.nodes->emplace_back(assignment{object, rhs, {i}}, field.inferred_type); - return last(); + path.push_back(i); + return apply_field_chain_assignment(field_access->object, rhs, std::move(path)); } } - throw std::runtime_error("Unknown field name"); + } + return std::nullopt; + } + + std::optional apply_get_address(ast::expression_ptr const & node) + { + auto result_type = std::make_shared(types::pointer_type{ast::get_type(*node), true}); + + if (auto identifier = std::get_if(node.get())) + { + auto object = apply(*node); + mcontext.nodes->emplace_back(unary_operation{ast::unary_operation_type::address_of, object}, result_type); + return last(); } - if (auto unary_operation = std::get_if(node.lhs.get())) + if (auto unary_operation = std::get_if(node.get())) { if (unary_operation->type == ast::unary_operation_type::dereference) { - auto lhs = apply(*unary_operation->arg1); - mcontext.nodes->emplace_back(store{lhs, rhs}, ast::get_type(*node.rhs)); - return last(); + return apply(*unary_operation->arg1); } } + if (auto array_access = std::get_if(node.get())) + { + auto array_type = get_type(*array_access->array); + if (types::is_pointer_type(*array_type)) + { + return apply(ast::binary_operation{ast::binary_operation_type::addition, array_access->array, array_access->index, {}, array_type}); + } + } + + if (auto field_access = std::get_if(node.get())) + { + auto object_type = ast::get_type(*field_access->object); + auto struct_node = std::get_if(object_type.get())->node; + if (auto object_ptr = apply_get_address(field_access->object)) + { + for (std::size_t i = 0; i < struct_node->fields.size(); ++i) + { + auto const & field = struct_node->fields[i]; + if (field.name == field_access->field_name) + { + mcontext.nodes->emplace_back(literal{ast::literal{ast::u64_literal{field.layout.offset}}}, + std::make_shared(types::primitive_type{types::u64_type{}})); + mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::addition, *object_ptr, last()}, + result_type); + return last(); + } + } + throw std::runtime_error("Unknown field name"); + } + } + + return std::nullopt; + } + + node_ref apply(ast::assignment const & node) + { + auto rhs = apply(*node.rhs); + + // Detect compound field access (like a.b.c = 1) - a sequence of field_access nodes + // terminating in an identifier node. + if (auto result = apply_field_chain_assignment(node.lhs, rhs, {})) + return *result; + + // Otherwise, compile into explicit memory store + if (auto lhs_ptr = apply_get_address(node.lhs)) + { + mcontext.nodes->emplace_back(store{*lhs_ptr, rhs}, ast::get_type(*node.rhs)); + return last(); + } + throw std::runtime_error("Unknown assignment left-hand side"); }