Implement field access for pointer-to-struct

This commit is contained in:
Nikita Lisitsa 2026-04-02 22:48:54 +03:00
parent 42defff7d6
commit 2a3e2171b0
3 changed files with 121 additions and 81 deletions

View file

@ -304,11 +304,10 @@ func splitmix64(x: u64) -> u64:
func next_u64(rng: rng mut*) -> u64:
func rotate_left(x: u64, k: u32) -> u64:
return (x << k) | (x >> (64u - k))
// No struct-field-by-pointer (rng->state) access yet
let result = rotate_left((*rng).state[0] * 5ul, 7u) * 9ul
(*rng).state[1] ^= (*rng).state[0]
(*rng).state[0] = rotate_left((*rng).state[0], 24u) ^ (*rng).state[1] ^ ((*rng).state[1] << 16u)
(*rng).state[1] = rotate_left((*rng).state[1], 37u)
let result = rotate_left(rng.state[0] * 5ul, 7u) * 9ul
rng.state[1] ^= rng.state[0]
rng.state[0] = rotate_left(rng.state[0], 24u) ^ rng.state[1] ^ (rng.state[1] << 16u)
rng.state[1] = rotate_left(rng.state[1], 37u)
return result
// Uniform in [0..1)
@ -395,15 +394,15 @@ struct object:
material: material
func set_plane(object: object mut*, shape: plane):
(*object).shape.tag = plane_tag
object.shape.tag = plane_tag
func set_sphere(object: object mut*, shape: sphere):
(*object).shape.tag = sphere_tag
*((*object).shape.data as f32 mut* as sphere mut*) = shape
object.shape.tag = sphere_tag
*(object.shape.data as f32 mut* as sphere mut*) = shape
func set_box(object: object mut*, shape: box):
(*object).shape.tag = box_tag
*((*object).shape.data as f32 mut* as box mut*) = shape
object.shape.tag = box_tag
*(object.shape.data as f32 mut* as box mut*) = shape
struct object_array:
size: u64
@ -447,22 +446,22 @@ const max_distance = 1000000.0
const no_intersection = intersection(false, max_distance, vec3(0.0, 0.0, 1.0), 0ul as material*)
func intersect_plane(ray: ray, object: object*) -> intersection:
let normal = rotate((*object).rotation, vec3(0.0, 0.0, 1.0))
let normal = rotate(object.rotation, vec3(0.0, 0.0, 1.0))
// dot(o + t * d - p, n) = 0
// dot(o - p, n) + t * dot(d, n) = 0
// t = - dot(o - p, n) / dot(d, n)
let t = - dot(sub(ray.origin, (*object).position), normal) / dot(ray.direction, normal)
return intersection(t > 0.0, t, normal, &(*object).material)
let t = - dot(sub(ray.origin, object.position), normal) / dot(ray.direction, normal)
return intersection(t > 0.0, t, normal, &object.material)
func intersect_sphere(ray: ray, object: object*) -> intersection:
let shape = &(*object).shape.data as sphere*
let shape = &object.shape.data as sphere*
// |o + t * d - p|^2 = r^2
// dot(o - p, o - p) + 2 * t * dot(o - p, d) + t^2 * dot(d, d) = r * r
let delta = sub(ray.origin, (*object).position)
let delta = sub(ray.origin, object.position)
// Solve quadratic
let a = 1.0 // assume ray.direction is normalized
let b = 2.0 * dot(delta, ray.direction)
let c = dot(delta, delta) - (*shape).radius * (*shape).radius
let c = dot(delta, delta) - shape.radius * shape.radius
let D = b * b - 4.0 * a * c
if D < 0.0:
return no_intersection
@ -472,7 +471,7 @@ func intersect_sphere(ray: ray, object: object*) -> intersection:
return no_intersection
let t = if t1 < 0.0 then t2 else t1
let normal = normalized(add(delta, mults(ray.direction, t)))
return intersection(true, t, normal, &(*object).material)
return intersection(true, t, normal, &object.material)
func intersect_box(ray: ray, object: object*) -> intersection:
func sort(x: f32 mut*, y: f32 mut*):
@ -481,18 +480,18 @@ func intersect_box(ray: ray, object: object*) -> intersection:
*x = *y
*y = temp
let shape = &(*object).shape.data as box*
let inverse_rotation = inverse((*object).rotation)
let local_delta = rotate(inverse_rotation, sub(ray.origin, (*object).position))
let shape = &object.shape.data as box*
let inverse_rotation = inverse(object.rotation)
let local_delta = rotate(inverse_rotation, sub(ray.origin, object.position))
let local_direction = rotate(inverse_rotation, ray.direction)
// (o + t * d).x = +/- e.x
mut txmin = (- (*shape).extent.x - local_delta.x) / local_direction.x
mut txmax = ( (*shape).extent.x - local_delta.x) / local_direction.x
mut tymin = (- (*shape).extent.y - local_delta.y) / local_direction.y
mut tymax = ( (*shape).extent.y - local_delta.y) / local_direction.y
mut tzmin = (- (*shape).extent.z - local_delta.z) / local_direction.z
mut tzmax = ( (*shape).extent.z - local_delta.z) / local_direction.z
mut txmin = (- shape.extent.x - local_delta.x) / local_direction.x
mut txmax = ( shape.extent.x - local_delta.x) / local_direction.x
mut tymin = (- shape.extent.y - local_delta.y) / local_direction.y
mut tymax = ( shape.extent.y - local_delta.y) / local_direction.y
mut tzmin = (- shape.extent.z - local_delta.z) / local_direction.z
mut tzmax = ( shape.extent.z - local_delta.z) / local_direction.z
sort(&mut txmin, &mut txmax)
sort(&mut tymin, &mut tymax)
@ -529,20 +528,20 @@ func intersect_box(ray: ray, object: object*) -> intersection:
if inside:
normal = mults(normal, -1.0)
return intersection(true, t, rotate((*object).rotation, normal), &(*object).material)
return intersection(true, t, rotate(object.rotation, normal), &object.material)
func intersect_scene(scene: scene*, ray: ray) -> intersection:
mut intersection = no_intersection
mut i = 0ul
while i < (*scene).objects.size:
while i < scene.objects.size:
mut current_intersection = no_intersection
if (*scene).objects.data[i].shape.tag == plane_tag:
current_intersection = intersect_plane(ray, &(*scene).objects.data[i])
else if (*scene).objects.data[i].shape.tag == sphere_tag:
current_intersection = intersect_sphere(ray, &(*scene).objects.data[i])
else if (*scene).objects.data[i].shape.tag == box_tag:
current_intersection = intersect_box(ray, &(*scene).objects.data[i])
if scene.objects.data[i].shape.tag == plane_tag:
current_intersection = intersect_plane(ray, &scene.objects.data[i])
else if scene.objects.data[i].shape.tag == sphere_tag:
current_intersection = intersect_sphere(ray, &scene.objects.data[i])
else if scene.objects.data[i].shape.tag == box_tag:
current_intersection = intersect_box(ray, &scene.objects.data[i])
if current_intersection.intersected && current_intersection.distance < intersection.distance:
intersection = current_intersection
@ -561,7 +560,7 @@ func raytrace(scene: scene*, camera_ray: ray, rng: rng mut*) -> vec3:
let intersection = intersect_scene(scene, current_ray)
if !intersection.intersected:
result = add(result, multv(factor, (*scene).background))
result = add(result, multv(factor, scene.background))
break
// Uncomment to debug normals
@ -570,7 +569,7 @@ func raytrace(scene: scene*, camera_ray: ray, rng: rng mut*) -> vec3:
let cosine = - dot(intersection.normal, current_ray.direction)
let inside = cosine < 0.0
result = add(result, multv(factor, (*intersection.material).emission))
result = add(result, multv(factor, intersection.material.emission))
// Russian roulette ray termination
if next_f32(rng) < termination_probability:
@ -578,32 +577,32 @@ func raytrace(scene: scene*, camera_ray: ray, rng: rng mut*) -> vec3:
mut new_direction = vec3(0.0, 0.0, 0.0)
if (*intersection.material).type == diffuse_tag:
if intersection.material.type == diffuse_tag:
// NB: albedo is assumed to be premultiplied by pi to be in [0..1] range
// This should also contain multiplication by cos(new_dir, normal), division by direction pdf (cos / pi)
// and division by pi (because of albedo normalization), but these all cancel out
factor = multv(factor, (*intersection.material).color)
factor = multv(factor, intersection.material.color)
// Cosine-weighted hemisphere direction
new_direction = normalized(add(next_vec3(rng), intersection.normal))
else if (*intersection.material).type == metallic_tag:
else if intersection.material.type == metallic_tag:
// This should also contain multiplication by brdf and division by direction pdf,
// but we'll just pretend that the random reflected ray pdf coincides with brdf and thus cancels out
factor = multv(factor, (*intersection.material).color)
factor = multv(factor, intersection.material.color)
// Compute perfect-mirror reflected direction
new_direction = add(current_ray.direction, mults(intersection.normal, 2.0 * cosine))
// Alter the direction based on roughness
new_direction = normalized(add(new_direction, mults(next_vec3_normal(rng), cosine * (*intersection.material).roughness)))
new_direction = normalized(add(new_direction, mults(next_vec3_normal(rng), cosine * intersection.material.roughness)))
else if (*intersection.material).type == glass_tag:
else if intersection.material.type == glass_tag:
// This should also contain multiplication by brdf and division by direction pdf,
// but we'll just pretend that the random refracted ray pdf coincides with brdf and thus cancels out
factor = multv(factor, (*intersection.material).color)
factor = multv(factor, intersection.material.color)
mut ior = (*intersection.material).ior
mut ior = intersection.material.ior
if inside:
ior = 1.0 / ior
@ -616,7 +615,7 @@ func raytrace(scene: scene*, camera_ray: ray, rng: rng mut*) -> vec3:
new_direction = add(current_ray.direction, mults(intersection.normal, 2.0 * cosine))
// Alter the direction based on roughness
new_direction = normalized(add(new_direction, mults(next_vec3_normal(rng), (*intersection.material).roughness)))
new_direction = normalized(add(new_direction, mults(next_vec3_normal(rng), intersection.material.roughness)))
// Compute the new ray, and offset its origin a bit along intersection normal
let position = add(current_ray.origin, mults(current_ray.direction, intersection.distance))
@ -759,17 +758,17 @@ const default_fovy = 2.0 * atan(0.5)
func main():
let scene = make_default_scene()
let image = create_image(512ul, 512ul)
// let image = create_image(512ul, 512ul)
// let image = create_image(256ul, 256ul)
// let image = create_image(128ul, 128ul)
let image = create_image(128ul, 128ul)
let aspect_ratio = (image.width as f32) / (image.height as f32)
let camera = camera(vec3(0.0, 0.0, 15.0), identity, default_fovy, compute_fovx(default_fovy, aspect_ratio))
// const samples_per_pixel = 16384ul
const samples_per_pixel = 4096ul
// const samples_per_pixel = 4096ul
// const samples_per_pixel = 2048ul
// const samples_per_pixel = 1024ul
const samples_per_pixel = 1024ul
// const samples_per_pixel = 512ul
// const samples_per_pixel = 256ul
// const samples_per_pixel = 16ul
@ -786,28 +785,27 @@ func main():
func thread_func(data_raw: unit mut*):
let data = data_raw as thread_data mut*
let image = (*data).image
mut y = (*data).ystart
while y < image.height:
mut y = data.ystart
while y < data.image.height:
mut x = 0ul
while x < image.width:
while x < data.image.width:
mut rng = rng([splitmix64(x | (y << 32u)), 10723151780598845931ul])
mut sample = 0ul
mut color = vec3(0.0, 0.0, 0.0)
while sample < samples_per_pixel:
let tx = - 1.0 + 2.0 * (x as f32 + next_f32(&mut rng)) / (image.width as f32)
let ty = 1.0 - 2.0 * (y as f32 + next_f32(&mut rng)) / (image.height as f32)
let tx = - 1.0 + 2.0 * (x as f32 + next_f32(&mut rng)) / (data.image.width as f32)
let ty = 1.0 - 2.0 * (y as f32 + next_f32(&mut rng)) / (data.image.height as f32)
let ray = camera_ray((*data).camera, tx, ty)
color = add(color, raytrace((*data).scene, ray, &mut rng))
sample += 1ul
color = mults(color, 1.0 / (samples_per_pixel as f32))
image.data[y * image.width + x] = to_bytes(gamma_correct(aces(color)))
data.image.data[y * data.image.width + x] = to_bytes(gamma_correct(aces(color)))
x += 1ul
lock_mutex((*data).done_mutex)
(*data).done += 1ul
unlock_mutex((*data).done_mutex)
y += (*data).ystep
lock_mutex(data.done_mutex)
data.done += 1ul
unlock_mutex(data.done_mutex)
y += data.ystep
let start_time = get_time()
mut last_report_time = start_time
@ -822,14 +820,14 @@ func main():
mut th = 0ul
while th < thread_count:
let data = threads_data + th
(*data).scene = &scene
(*data).image = image
(*data).camera = camera
(*data).ystart = th
(*data).ystep = thread_count
(*data).done_mutex = allocate(64ul) as mutex mut*
create_mutex((*data).done_mutex)
(*data).done = 0ul
data.scene = &scene
data.image = image
data.camera = camera
data.ystart = th
data.ystep = thread_count
data.done_mutex = allocate(64ul) as mutex mut*
create_mutex(data.done_mutex)
data.done = 0ul
threads[th] = create_thread(thread_func, data as unit mut*)
th += 1ul

View file

@ -760,6 +760,10 @@ namespace pslang::ast
auto object_type = get_type(*node.object);
auto struct_type = std::get_if<types::struct_type>(object_type.get());
auto pointer_type = std::get_if<types::pointer_type>(object_type.get());
if (pointer_type)
struct_type = std::get_if<types::struct_type>(pointer_type->referenced_type.get());
if (!struct_type)
{
@ -973,7 +977,11 @@ namespace pslang::ast
}
else if (auto field_access = std::get_if<ast::field_access>(node.get()))
{
return classify_lvalue(field_access->object);
auto object_type = get_type(*field_access->object);
if (auto pointer_type = std::get_if<types::pointer_type>(object_type.get()))
return pointer_type->is_mutable ? ast::value_category::_mutable : ast::value_category::constant;
else
return classify_lvalue(field_access->object);
}
else if (auto array_access = std::get_if<ast::array_access>(node.get()))
{

View file

@ -461,13 +461,30 @@ namespace pslang::ir
{
auto object = apply(*node.object);
auto object_type = ast::get_type(*node.object);
auto struct_node = std::get_if<types::struct_type>(object_type.get())->node;
auto struct_type = std::get_if<types::struct_type>(object_type.get());
auto pointer_type = std::get_if<types::pointer_type>(object_type.get());
if (pointer_type)
struct_type = std::get_if<types::struct_type>(pointer_type->referenced_type.get());
auto struct_node = struct_type->node;
for (std::size_t i = 0; i < struct_node->fields.size(); ++i)
{
if (struct_node->fields[i].name == node.field_name)
auto const & field = struct_node->fields[i];
if (field.name == node.field_name)
{
mcontext.nodes->emplace_back(copy{object, {i}}, node.inferred_type);
return last();
if (pointer_type)
{
mcontext.nodes->emplace_back(literal{ast::literal{ast::u64_literal{field.layout.offset}}},
std::make_shared<types::type>(types::primitive_type{types::u64_type{}}));
mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::addition, object, last()},
std::make_shared<types::type>(types::pointer_type{field.inferred_type, pointer_type->is_mutable}));
mcontext.nodes->emplace_back(load{last()}, field.inferred_type);
return last();
}
else
{
mcontext.nodes->emplace_back(copy{object, {i}}, node.inferred_type);
return last();
}
}
}
throw std::runtime_error("Unknown field name");
@ -514,14 +531,17 @@ namespace pslang::ir
else if (auto field_access = std::get_if<ast::field_access>(lhs_node.get()))
{
auto object_type = ast::get_type(*field_access->object);
auto struct_node = std::get_if<types::struct_type>(object_type.get())->node;
for (std::size_t i = 0; i < struct_node->fields.size(); ++i)
if (auto struct_type = std::get_if<types::struct_type>(object_type.get()))
{
auto const & field = struct_node->fields[i];
if (field.name == field_access->field_name)
auto struct_node = struct_type->node;
for (std::size_t i = 0; i < struct_node->fields.size(); ++i)
{
path.push_back(i);
return apply_field_chain_assignment(field_access->object, rhs, std::move(path));
auto const & field = struct_node->fields[i];
if (field.name == field_access->field_name)
{
path.push_back(i);
return apply_field_chain_assignment(field_access->object, rhs, std::move(path));
}
}
}
}
@ -589,8 +609,22 @@ namespace pslang::ir
if (auto field_access = std::get_if<ast::field_access>(node.get()))
{
auto object_type = ast::get_type(*field_access->object);
auto struct_node = std::get_if<types::struct_type>(object_type.get())->node;
if (auto object_ptr = apply_get_address(field_access->object))
auto struct_type = std::get_if<types::struct_type>(object_type.get());
auto pointer_type = std::get_if<types::pointer_type>(object_type.get());
std::optional<node_ref> object_ptr;
if (pointer_type)
{
struct_type = std::get_if<types::struct_type>(pointer_type->referenced_type.get());
object_ptr = apply(*field_access->object);
}
else
{
object_ptr = apply_get_address(field_access->object);
}
auto struct_node = struct_type->node;
if (object_ptr)
{
for (std::size_t i = 0; i < struct_node->fields.size(); ++i)
{