Add multi-threading in raytracer example

This commit is contained in:
Nikita Lisitsa 2026-04-01 23:43:58 +03:00
parent e0c1db4978
commit 20c84596df
2 changed files with 143 additions and 43 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 563 KiB

After

Width:  |  Height:  |  Size: 423 KiB

View file

@ -146,6 +146,10 @@ func get_time() -> timespec:
func time_delta(x: timespec, y: timespec) -> f32:
return ((x.seconds - y.seconds) as f32) + ((x.nanoseconds - y.nanoseconds) as f32) / 1000000000.0
func sleep(time: timespec):
foreign func nanosleep(time: timespec*, rem: unit*) -> i32
nanosleep(&time, 0ul as unit*)
// ===== Image =====
struct image:
@ -219,6 +223,12 @@ func max(x: f32, y: f32) -> f32:
return x
return y
// I want function overloading so badly
func min_u32(x: u32, y: u32) -> u32:
if x < y:
return x
return y
func clamp(x: f32) -> f32:
return max(0.0, min(1.0, x))
@ -638,6 +648,43 @@ func raytrace(scene: scene*, camera_ray: ray, rng: rng mut*) -> vec3:
return result
// ===== Multithreading =====
// Opaque handle
struct thread:
// Platform-dependent!
handle: unit*
func create_thread(thread_func: unit mut* -> unit, thread_data: unit mut*) -> thread:
foreign func pthread_create(thread: thread mut*, attr: unit*, start_routine: unit mut* -> unit, arg: unit mut*) -> i32
mut result = thread()
pthread_create(&result, 0ul as unit*, thread_func, thread_data)
return result
func join_thread(thread: thread):
foreign func pthread_join(thread: thread, retval: unit*) -> i32
pthread_join(thread, 0ul as unit*)
struct mutex:
// Platform-dependent!
payload: u64[8]
func create_mutex(mutex: mutex mut*):
foreign func pthread_mutex_init(mutex: mutex mut*, attr: unit*) -> i32
pthread_mutex_init(mutex, 0ul as unit*)
func destroy_mutex(mutex: mutex mut*):
foreign func pthread_mutex_destroy(mutex: mutex mut *) -> i32
pthread_mutex_destroy(mutex)
func lock_mutex(mutex: mutex mut*):
foreign func pthread_mutex_lock(mutex: mutex mut*) -> i32
pthread_mutex_lock(mutex)
func unlock_mutex(mutex: mutex mut*):
foreign func pthread_mutex_unlock(mutex: mutex mut*) -> i32
pthread_mutex_unlock(mutex)
// ===== Main =====
func make_default_scene() -> scene:
@ -703,7 +750,7 @@ func make_default_scene() -> scene:
// Box
objects[6].position = vec3(-2.25, -2.0, -1.0)
objects[6].rotation = rotation(vec3(0.0, 1.0, 0.0), pi / 6.0)
objects[6].material.color = vec3(1.0, 1.0, 1.0)
objects[6].material.color = vec3(0.7, 0.85, 1.0)
objects[6].material.emission = vec3(0.0, 0.0, 0.0)
objects[6].material.type = glass_tag
objects[6].material.roughness = 0.0
@ -713,7 +760,7 @@ func make_default_scene() -> scene:
// Sphere
objects[7].position = vec3(2.5, -3.0, -1.0)
objects[7].rotation = rotation(vec3(0.0, 1.0, 0.0), - pi / 6.0)
objects[7].material.color = vec3(1.0, 1.0, 1.0)
objects[7].material.color = vec3(1.0, 0.8, 0.6)
objects[7].material.emission = vec3(0.0, 0.0, 0.0)
objects[7].material.type = metallic_tag
objects[7].material.roughness = 0.5
@ -731,23 +778,78 @@ func main():
let scene = make_default_scene()
let image = create_image(512ul, 512ul)
// let image = create_image(256ul, 256ul)
// let image = create_image(128ul, 128ul)
let aspect_ratio = (image.width as f32) / (image.height as f32)
let camera = camera(vec3(0.0, 0.0, 15.0), identity, default_fovy, compute_fovx(default_fovy, aspect_ratio))
// let samples_per_pixel = 4096ul
let samples_per_pixel = 2048ul
// let samples_per_pixel = 1024ul
// let samples_per_pixel = 512ul
// let samples_per_pixel = 256ul
// let samples_per_pixel = 16ul
// let samples_per_pixel = 1ul
// const samples_per_pixel = 16384ul
// const samples_per_pixel = 4096ul
// const samples_per_pixel = 2048ul
// const samples_per_pixel = 1024ul
// const samples_per_pixel = 512ul
const samples_per_pixel = 256ul
// const samples_per_pixel = 16ul
// const samples_per_pixel = 1ul
struct thread_data:
scene: scene*
image: image
camera: camera
ystart: u64
ystep: u64
done_mutex: mutex mut*
done: u64
func thread_func(data_raw: unit mut*):
let data = data_raw as thread_data mut*
let image = (*data).image
mut y = (*data).ystart
while y < image.height:
mut x = 0ul
while x < image.width:
mut rng = rng([splitmix64(x | (y << 32u)), 10723151780598845931ul])
mut sample = 0ul
mut color = vec3(0.0, 0.0, 0.0)
while sample < samples_per_pixel:
let tx = - 1.0 + 2.0 * (x as f32 + next_f32(&mut rng)) / (image.width as f32)
let ty = 1.0 - 2.0 * (y as f32 + next_f32(&mut rng)) / (image.height as f32)
let ray = camera_ray((*data).camera, tx, ty)
color = add(color, raytrace((*data).scene, ray, &mut rng))
sample += 1ul
color = mults(color, 1.0 / (samples_per_pixel as f32))
image.data[y * image.width + x] = to_bytes(gamma_correct(aces(color)))
x += 1ul
lock_mutex((*data).done_mutex)
(*data).done += 1ul
unlock_mutex((*data).done_mutex)
y += (*data).ystep
let start_time = get_time()
mut last_report_time = start_time
mut last_report_progress = 0.0
mut average_speed = 0.0
mut report_count = 0u
let thread_count = 9ul
let threads = allocate(thread_count * 8ul) as thread mut*
let threads_data = allocate(thread_count * 104ul) as thread_data mut*
mut th = 0ul
while th < thread_count:
let data = threads_data + th
(*data).scene = &scene
(*data).image = image
(*data).camera = camera
(*data).ystart = th
(*data).ystep = thread_count
(*data).done_mutex = allocate(64ul) as mutex mut*
create_mutex((*data).done_mutex)
(*data).done = 0ul
threads[th] = create_thread(thread_func, data as unit mut*)
th += 1ul
func clear_line(spaces: u32):
mut i = 0u
@ -757,65 +859,63 @@ func main():
i += 1u
print_byte('\r')
mut y = 0ul
while y < image.height:
mut x = 0ul
while x < image.width:
mut rng = rng([splitmix64(x | (y << 32u)), 10723151780598845931ul])
mut sample = 0ul
mut color = vec3(0.0, 0.0, 0.0)
while sample < samples_per_pixel:
let tx = - 1.0 + 2.0 * (x as f32 + next_f32(&mut rng)) / (image.width as f32)
let ty = 1.0 - 2.0 * (y as f32 + next_f32(&mut rng)) / (image.height as f32)
let ray = camera_ray(camera, tx, ty)
color = add(color, raytrace(&scene, ray, &mut rng))
sample += 1ul
color = mults(color, 1.0 / (samples_per_pixel as f32))
image.data[y * image.width + x] = to_bytes(gamma_correct(aces(color)))
mut done = 0ul
let total = image.width * image.height
while done < total:
sleep(timespec(0l, 125000000l))
let time = get_time()
let delta = time_delta(time, last_report_time)
if delta > 0.125 || last_report_progress == 0.0:
done = 0ul
x += 1ul
mut th = 0ul
while th < thread_count:
lock_mutex(threads_data[th].done_mutex)
done += threads_data[th].done
unlock_mutex(threads_data[th].done_mutex)
th += 1ul
let time = get_time()
let delta = time_delta(time, last_report_time)
if delta > 0.125 || (y == 0ul && x == 1ul) || (y + 1ul == image.height && x == image.width):
let done = y * image.width + x
let total = image.width * image.height
if done > 0ul:
let progress = (done as f32) / (total as f32)
let time_passed = time_delta(time, start_time)
// Running exponential-weighted average speed for
// accurate remaining time estimate
// Running exponential-weighted average speed
// for +/- accurate remaining time estimate
let speed_estimate = (progress - last_report_progress) / delta
if average_speed == 0.0:
average_speed = speed_estimate
else:
average_speed = average_speed + (speed_estimate - average_speed) * 0.125
average_speed = average_speed + (speed_estimate - average_speed) / (min_u32(report_count + 1u, 16u) as f32)
let time_remaining = (1.0 - progress) / average_speed
let str1 = ['%', ' ', 'd', 'o', 'n', 'e', ',', ' ', '\0']
let str2 = [' ', 'l', 'e', 'f', 't', ' ', '\0']
let str2 = [' ', 'p', 'a', 's', 's', 'e', 'd', ',', ' ', '\0']
let str3 = [' ', 'l', 'e', 'f', 't', ' ', '\0']
clear_line(30u)
clear_line(50u)
print_f32(100.0 * progress)
print_str(str1 as u8*)
print_time(time_remaining)
print_time(time_delta(time, start_time))
print_str(str2 as u8*)
print_time(time_remaining)
print_str(str3 as u8*)
flush()
last_report_time = time
last_report_progress = progress
report_count += 1u
y += 1ul
// No for loops yet...
th = 0ul
while th < thread_count:
join_thread(threads[th])
th += 1ul
let total_time = time_delta(get_time(), start_time)
let total_samples = image.width * image.height * samples_per_pixel
let str1 = ['R', 'e', 'n', 'd', 'e', 'r', ' ', 't', 'o', 'o', 'k', ' ', '\0']
let str2 = [' ', 'p', 'e', 'r', ' ', 's', 'a', 'm', 'p', 'l', 'e', '\0']
clear_line(30u)
clear_line(50u)
print_str(str1 as u8*)
print_time(total_time)
print_byte('\n')
print_time(total_time / (total_samples as f32))
print_time((thread_count as f32) * total_time / (total_samples as f32))
print_str(str2 as u8*)
print_byte('\n')