diff --git a/examples/raytracer.psl b/examples/raytracer.psl index cb33077..a5f93a4 100644 --- a/examples/raytracer.psl +++ b/examples/raytracer.psl @@ -304,11 +304,10 @@ func splitmix64(x: u64) -> u64: func next_u64(rng: rng mut*) -> u64: func rotate_left(x: u64, k: u32) -> u64: return (x << k) | (x >> (64u - k)) - // No struct-field-by-pointer (rng->state) access yet - let result = rotate_left((*rng).state[0] * 5ul, 7u) * 9ul - (*rng).state[1] ^= (*rng).state[0] - (*rng).state[0] = rotate_left((*rng).state[0], 24u) ^ (*rng).state[1] ^ ((*rng).state[1] << 16u) - (*rng).state[1] = rotate_left((*rng).state[1], 37u) + let result = rotate_left(rng.state[0] * 5ul, 7u) * 9ul + rng.state[1] ^= rng.state[0] + rng.state[0] = rotate_left(rng.state[0], 24u) ^ rng.state[1] ^ (rng.state[1] << 16u) + rng.state[1] = rotate_left(rng.state[1], 37u) return result // Uniform in [0..1) @@ -395,15 +394,15 @@ struct object: material: material func set_plane(object: object mut*, shape: plane): - (*object).shape.tag = plane_tag + object.shape.tag = plane_tag func set_sphere(object: object mut*, shape: sphere): - (*object).shape.tag = sphere_tag - *((*object).shape.data as f32 mut* as sphere mut*) = shape + object.shape.tag = sphere_tag + *(object.shape.data as f32 mut* as sphere mut*) = shape func set_box(object: object mut*, shape: box): - (*object).shape.tag = box_tag - *((*object).shape.data as f32 mut* as box mut*) = shape + object.shape.tag = box_tag + *(object.shape.data as f32 mut* as box mut*) = shape struct object_array: size: u64 @@ -447,22 +446,22 @@ const max_distance = 1000000.0 const no_intersection = intersection(false, max_distance, vec3(0.0, 0.0, 1.0), 0ul as material*) func intersect_plane(ray: ray, object: object*) -> intersection: - let normal = rotate((*object).rotation, vec3(0.0, 0.0, 1.0)) + let normal = rotate(object.rotation, vec3(0.0, 0.0, 1.0)) // dot(o + t * d - p, n) = 0 // dot(o - p, n) + t * dot(d, n) = 0 // t = - dot(o - p, n) / dot(d, n) - let t = - dot(sub(ray.origin, (*object).position), normal) / dot(ray.direction, normal) - return intersection(t > 0.0, t, normal, &(*object).material) + let t = - dot(sub(ray.origin, object.position), normal) / dot(ray.direction, normal) + return intersection(t > 0.0, t, normal, &object.material) func intersect_sphere(ray: ray, object: object*) -> intersection: - let shape = &(*object).shape.data as sphere* + let shape = &object.shape.data as sphere* // |o + t * d - p|^2 = r^2 // dot(o - p, o - p) + 2 * t * dot(o - p, d) + t^2 * dot(d, d) = r * r - let delta = sub(ray.origin, (*object).position) + let delta = sub(ray.origin, object.position) // Solve quadratic let a = 1.0 // assume ray.direction is normalized let b = 2.0 * dot(delta, ray.direction) - let c = dot(delta, delta) - (*shape).radius * (*shape).radius + let c = dot(delta, delta) - shape.radius * shape.radius let D = b * b - 4.0 * a * c if D < 0.0: return no_intersection @@ -472,7 +471,7 @@ func intersect_sphere(ray: ray, object: object*) -> intersection: return no_intersection let t = if t1 < 0.0 then t2 else t1 let normal = normalized(add(delta, mults(ray.direction, t))) - return intersection(true, t, normal, &(*object).material) + return intersection(true, t, normal, &object.material) func intersect_box(ray: ray, object: object*) -> intersection: func sort(x: f32 mut*, y: f32 mut*): @@ -481,18 +480,18 @@ func intersect_box(ray: ray, object: object*) -> intersection: *x = *y *y = temp - let shape = &(*object).shape.data as box* - let inverse_rotation = inverse((*object).rotation) - let local_delta = rotate(inverse_rotation, sub(ray.origin, (*object).position)) + let shape = &object.shape.data as box* + let inverse_rotation = inverse(object.rotation) + let local_delta = rotate(inverse_rotation, sub(ray.origin, object.position)) let local_direction = rotate(inverse_rotation, ray.direction) // (o + t * d).x = +/- e.x - mut txmin = (- (*shape).extent.x - local_delta.x) / local_direction.x - mut txmax = ( (*shape).extent.x - local_delta.x) / local_direction.x - mut tymin = (- (*shape).extent.y - local_delta.y) / local_direction.y - mut tymax = ( (*shape).extent.y - local_delta.y) / local_direction.y - mut tzmin = (- (*shape).extent.z - local_delta.z) / local_direction.z - mut tzmax = ( (*shape).extent.z - local_delta.z) / local_direction.z + mut txmin = (- shape.extent.x - local_delta.x) / local_direction.x + mut txmax = ( shape.extent.x - local_delta.x) / local_direction.x + mut tymin = (- shape.extent.y - local_delta.y) / local_direction.y + mut tymax = ( shape.extent.y - local_delta.y) / local_direction.y + mut tzmin = (- shape.extent.z - local_delta.z) / local_direction.z + mut tzmax = ( shape.extent.z - local_delta.z) / local_direction.z sort(&mut txmin, &mut txmax) sort(&mut tymin, &mut tymax) @@ -529,20 +528,20 @@ func intersect_box(ray: ray, object: object*) -> intersection: if inside: normal = mults(normal, -1.0) - return intersection(true, t, rotate((*object).rotation, normal), &(*object).material) + return intersection(true, t, rotate(object.rotation, normal), &object.material) func intersect_scene(scene: scene*, ray: ray) -> intersection: mut intersection = no_intersection mut i = 0ul - while i < (*scene).objects.size: + while i < scene.objects.size: mut current_intersection = no_intersection - if (*scene).objects.data[i].shape.tag == plane_tag: - current_intersection = intersect_plane(ray, &(*scene).objects.data[i]) - else if (*scene).objects.data[i].shape.tag == sphere_tag: - current_intersection = intersect_sphere(ray, &(*scene).objects.data[i]) - else if (*scene).objects.data[i].shape.tag == box_tag: - current_intersection = intersect_box(ray, &(*scene).objects.data[i]) + if scene.objects.data[i].shape.tag == plane_tag: + current_intersection = intersect_plane(ray, &scene.objects.data[i]) + else if scene.objects.data[i].shape.tag == sphere_tag: + current_intersection = intersect_sphere(ray, &scene.objects.data[i]) + else if scene.objects.data[i].shape.tag == box_tag: + current_intersection = intersect_box(ray, &scene.objects.data[i]) if current_intersection.intersected && current_intersection.distance < intersection.distance: intersection = current_intersection @@ -561,7 +560,7 @@ func raytrace(scene: scene*, camera_ray: ray, rng: rng mut*) -> vec3: let intersection = intersect_scene(scene, current_ray) if !intersection.intersected: - result = add(result, multv(factor, (*scene).background)) + result = add(result, multv(factor, scene.background)) break // Uncomment to debug normals @@ -570,7 +569,7 @@ func raytrace(scene: scene*, camera_ray: ray, rng: rng mut*) -> vec3: let cosine = - dot(intersection.normal, current_ray.direction) let inside = cosine < 0.0 - result = add(result, multv(factor, (*intersection.material).emission)) + result = add(result, multv(factor, intersection.material.emission)) // Russian roulette ray termination if next_f32(rng) < termination_probability: @@ -578,32 +577,32 @@ func raytrace(scene: scene*, camera_ray: ray, rng: rng mut*) -> vec3: mut new_direction = vec3(0.0, 0.0, 0.0) - if (*intersection.material).type == diffuse_tag: + if intersection.material.type == diffuse_tag: // NB: albedo is assumed to be premultiplied by pi to be in [0..1] range // This should also contain multiplication by cos(new_dir, normal), division by direction pdf (cos / pi) // and division by pi (because of albedo normalization), but these all cancel out - factor = multv(factor, (*intersection.material).color) + factor = multv(factor, intersection.material.color) // Cosine-weighted hemisphere direction new_direction = normalized(add(next_vec3(rng), intersection.normal)) - else if (*intersection.material).type == metallic_tag: + else if intersection.material.type == metallic_tag: // This should also contain multiplication by brdf and division by direction pdf, // but we'll just pretend that the random reflected ray pdf coincides with brdf and thus cancels out - factor = multv(factor, (*intersection.material).color) + factor = multv(factor, intersection.material.color) // Compute perfect-mirror reflected direction new_direction = add(current_ray.direction, mults(intersection.normal, 2.0 * cosine)) // Alter the direction based on roughness - new_direction = normalized(add(new_direction, mults(next_vec3_normal(rng), cosine * (*intersection.material).roughness))) + new_direction = normalized(add(new_direction, mults(next_vec3_normal(rng), cosine * intersection.material.roughness))) - else if (*intersection.material).type == glass_tag: + else if intersection.material.type == glass_tag: // This should also contain multiplication by brdf and division by direction pdf, // but we'll just pretend that the random refracted ray pdf coincides with brdf and thus cancels out - factor = multv(factor, (*intersection.material).color) + factor = multv(factor, intersection.material.color) - mut ior = (*intersection.material).ior + mut ior = intersection.material.ior if inside: ior = 1.0 / ior @@ -616,7 +615,7 @@ func raytrace(scene: scene*, camera_ray: ray, rng: rng mut*) -> vec3: new_direction = add(current_ray.direction, mults(intersection.normal, 2.0 * cosine)) // Alter the direction based on roughness - new_direction = normalized(add(new_direction, mults(next_vec3_normal(rng), (*intersection.material).roughness))) + new_direction = normalized(add(new_direction, mults(next_vec3_normal(rng), intersection.material.roughness))) // Compute the new ray, and offset its origin a bit along intersection normal let position = add(current_ray.origin, mults(current_ray.direction, intersection.distance)) @@ -759,17 +758,17 @@ const default_fovy = 2.0 * atan(0.5) func main(): let scene = make_default_scene() - let image = create_image(512ul, 512ul) + // let image = create_image(512ul, 512ul) // let image = create_image(256ul, 256ul) - // let image = create_image(128ul, 128ul) + let image = create_image(128ul, 128ul) let aspect_ratio = (image.width as f32) / (image.height as f32) let camera = camera(vec3(0.0, 0.0, 15.0), identity, default_fovy, compute_fovx(default_fovy, aspect_ratio)) // const samples_per_pixel = 16384ul - const samples_per_pixel = 4096ul + // const samples_per_pixel = 4096ul // const samples_per_pixel = 2048ul - // const samples_per_pixel = 1024ul + const samples_per_pixel = 1024ul // const samples_per_pixel = 512ul // const samples_per_pixel = 256ul // const samples_per_pixel = 16ul @@ -786,28 +785,27 @@ func main(): func thread_func(data_raw: unit mut*): let data = data_raw as thread_data mut* - let image = (*data).image - mut y = (*data).ystart - while y < image.height: + mut y = data.ystart + while y < data.image.height: mut x = 0ul - while x < image.width: + while x < data.image.width: mut rng = rng([splitmix64(x | (y << 32u)), 10723151780598845931ul]) mut sample = 0ul mut color = vec3(0.0, 0.0, 0.0) while sample < samples_per_pixel: - let tx = - 1.0 + 2.0 * (x as f32 + next_f32(&mut rng)) / (image.width as f32) - let ty = 1.0 - 2.0 * (y as f32 + next_f32(&mut rng)) / (image.height as f32) + let tx = - 1.0 + 2.0 * (x as f32 + next_f32(&mut rng)) / (data.image.width as f32) + let ty = 1.0 - 2.0 * (y as f32 + next_f32(&mut rng)) / (data.image.height as f32) let ray = camera_ray((*data).camera, tx, ty) color = add(color, raytrace((*data).scene, ray, &mut rng)) sample += 1ul color = mults(color, 1.0 / (samples_per_pixel as f32)) - image.data[y * image.width + x] = to_bytes(gamma_correct(aces(color))) + data.image.data[y * data.image.width + x] = to_bytes(gamma_correct(aces(color))) x += 1ul - lock_mutex((*data).done_mutex) - (*data).done += 1ul - unlock_mutex((*data).done_mutex) - y += (*data).ystep + lock_mutex(data.done_mutex) + data.done += 1ul + unlock_mutex(data.done_mutex) + y += data.ystep let start_time = get_time() mut last_report_time = start_time @@ -822,14 +820,14 @@ func main(): mut th = 0ul while th < thread_count: let data = threads_data + th - (*data).scene = &scene - (*data).image = image - (*data).camera = camera - (*data).ystart = th - (*data).ystep = thread_count - (*data).done_mutex = allocate(64ul) as mutex mut* - create_mutex((*data).done_mutex) - (*data).done = 0ul + data.scene = &scene + data.image = image + data.camera = camera + data.ystart = th + data.ystep = thread_count + data.done_mutex = allocate(64ul) as mutex mut* + create_mutex(data.done_mutex) + data.done = 0ul threads[th] = create_thread(thread_func, data as unit mut*) th += 1ul diff --git a/libs/ast/source/type_check.cpp b/libs/ast/source/type_check.cpp index db45260..dc1ddab 100644 --- a/libs/ast/source/type_check.cpp +++ b/libs/ast/source/type_check.cpp @@ -760,6 +760,10 @@ namespace pslang::ast auto object_type = get_type(*node.object); auto struct_type = std::get_if(object_type.get()); + auto pointer_type = std::get_if(object_type.get()); + + if (pointer_type) + struct_type = std::get_if(pointer_type->referenced_type.get()); if (!struct_type) { @@ -973,7 +977,11 @@ namespace pslang::ast } else if (auto field_access = std::get_if(node.get())) { - return classify_lvalue(field_access->object); + auto object_type = get_type(*field_access->object); + if (auto pointer_type = std::get_if(object_type.get())) + return pointer_type->is_mutable ? ast::value_category::_mutable : ast::value_category::constant; + else + return classify_lvalue(field_access->object); } else if (auto array_access = std::get_if(node.get())) { diff --git a/libs/ir/source/compiler.cpp b/libs/ir/source/compiler.cpp index 98df31e..bde5aa7 100644 --- a/libs/ir/source/compiler.cpp +++ b/libs/ir/source/compiler.cpp @@ -461,13 +461,30 @@ namespace pslang::ir { auto object = apply(*node.object); auto object_type = ast::get_type(*node.object); - auto struct_node = std::get_if(object_type.get())->node; + auto struct_type = std::get_if(object_type.get()); + auto pointer_type = std::get_if(object_type.get()); + if (pointer_type) + struct_type = std::get_if(pointer_type->referenced_type.get()); + auto struct_node = struct_type->node; for (std::size_t i = 0; i < struct_node->fields.size(); ++i) { - if (struct_node->fields[i].name == node.field_name) + auto const & field = struct_node->fields[i]; + if (field.name == node.field_name) { - mcontext.nodes->emplace_back(copy{object, {i}}, node.inferred_type); - return last(); + if (pointer_type) + { + mcontext.nodes->emplace_back(literal{ast::literal{ast::u64_literal{field.layout.offset}}}, + std::make_shared(types::primitive_type{types::u64_type{}})); + mcontext.nodes->emplace_back(binary_operation{ast::binary_operation_type::addition, object, last()}, + std::make_shared(types::pointer_type{field.inferred_type, pointer_type->is_mutable})); + mcontext.nodes->emplace_back(load{last()}, field.inferred_type); + return last(); + } + else + { + mcontext.nodes->emplace_back(copy{object, {i}}, node.inferred_type); + return last(); + } } } throw std::runtime_error("Unknown field name"); @@ -514,14 +531,17 @@ namespace pslang::ir else if (auto field_access = std::get_if(lhs_node.get())) { auto object_type = ast::get_type(*field_access->object); - auto struct_node = std::get_if(object_type.get())->node; - for (std::size_t i = 0; i < struct_node->fields.size(); ++i) + if (auto struct_type = std::get_if(object_type.get())) { - auto const & field = struct_node->fields[i]; - if (field.name == field_access->field_name) + auto struct_node = struct_type->node; + for (std::size_t i = 0; i < struct_node->fields.size(); ++i) { - path.push_back(i); - return apply_field_chain_assignment(field_access->object, rhs, std::move(path)); + auto const & field = struct_node->fields[i]; + if (field.name == field_access->field_name) + { + path.push_back(i); + return apply_field_chain_assignment(field_access->object, rhs, std::move(path)); + } } } } @@ -589,8 +609,22 @@ namespace pslang::ir if (auto field_access = std::get_if(node.get())) { auto object_type = ast::get_type(*field_access->object); - auto struct_node = std::get_if(object_type.get())->node; - if (auto object_ptr = apply_get_address(field_access->object)) + auto struct_type = std::get_if(object_type.get()); + auto pointer_type = std::get_if(object_type.get()); + + std::optional object_ptr; + if (pointer_type) + { + struct_type = std::get_if(pointer_type->referenced_type.get()); + object_ptr = apply(*field_access->object); + } + else + { + object_ptr = apply_get_address(field_access->object); + } + + auto struct_node = struct_type->node; + if (object_ptr) { for (std::size_t i = 0; i < struct_node->fields.size(); ++i) {