Add a basic k-d tree implementation

This commit is contained in:
Nikita Lisitsa 2025-01-09 18:22:28 +03:00
parent c59b28e13f
commit 073ac16223

View file

@ -0,0 +1,227 @@
#pragma once
#include <psemek/math/point.hpp>
#include <psemek/math/interval.hpp>
#include <psemek/math/math.hpp>
#include <vector>
#include <algorithm>
namespace psemek::cg
{
namespace detail
{
template <typename Point, typename Data>
struct kdtree_value
{
Point point;
Data data;
};
template <typename Point>
Point & get_point(Point & point)
{
return point;
}
template <typename Point, typename Data>
Point & get_point(kdtree_value<Point, Data> & value)
{
return value.point;
}
template <typename Point, typename Data>
Point const & get_point(kdtree_value<Point, Data> const & value)
{
return value.point;
}
template <typename Point, typename Data>
struct kdtree_by_axis_comparator
{
std::uint32_t axis;
bool operator() (kdtree_value<Point, Data> const & left, kdtree_value<Point, Data> const & right)
{
return left.point[axis] < right.point[axis];
}
};
template <typename Point>
struct kdtree_by_axis_comparator<Point, void>
{
std::uint32_t axis;
bool operator() (Point const & left, Point const & right)
{
return left[axis] < right[axis];
}
};
}
template <typename T, std::size_t N, typename Data = void>
struct kdtree
{
using scalar_type = T;
using point_type = math::point<T, N>;
using value_type = std::conditional_t<std::is_same_v<Data, void>,
point_type,
detail::kdtree_value<point_type, Data>
>;
kdtree() = default;
template <typename Iterator>
kdtree(Iterator begin, Iterator end);
bool empty() const { return nodes_.empty(); }
bool insert(value_type && value);
// TODO: implement
bool remove(point_type const & point) const;
// TODO: implement
// TODO: alternative non-const version that allows modifying value.data
value_type const * find(point_type const & point) const;
value_type const & closest(point_type const & target);
private:
using node_id = std::uint32_t;
static constexpr auto null = static_cast<node_id>(-1);
static std::uint32_t next_axis(std::uint32_t axis)
{
return (axis + 1) % N;
}
// NB: splitting dimension is implicit and depends on node's depth
struct node
{
value_type value;
node_id children[2] {null, null};
};
std::vector<node> nodes_;
template <typename Iterator>
node_id build_node_impl(Iterator begin, Iterator end, std::uint32_t split_axis);
bool insert_impl(value_type && value, node_id id, std::uint32_t split_axis);
value_type const * closest_impl(point_type const & target, scalar_type & best_distance_sqr, node_id id, std::uint32_t split_axis) const;
};
template <typename T, std::size_t N, typename Data>
template <typename Iterator>
kdtree<T, N, Data>::kdtree(Iterator begin, Iterator end)
{
build_node_impl(begin, end, 0);
}
template <typename T, std::size_t N, typename Data>
bool kdtree<T, N, Data>::insert(value_type && value)
{
return insert_impl(std::move(value), 0, 0);
}
template <typename T, std::size_t N, typename Data>
kdtree<T, N, Data>::value_type const & kdtree<T, N, Data>::closest(point_type const & target)
{
if (nodes_.empty())
throw util::exception("empty kdtree");
auto best_distance_sqr = math::limits<scalar_type>::max();
return *closest_impl(target, best_distance_sqr, 0, 0);
}
template <typename T, std::size_t N, typename Data>
template <typename Iterator>
kdtree<T, N, Data>::node_id kdtree<T, N, Data>::build_node_impl(Iterator begin, Iterator end, std::uint32_t split_axis)
{
if (begin == end)
return null;
detail::kdtree_by_axis_comparator<point_type, Data> comparator{split_axis};
auto middle = std::next(begin, (end - begin) / 2);
std::nth_element(begin, middle, end, comparator);
{
auto middle_value = *middle;
middle = std::partition(begin, end, [&](auto const & value){ return comparator(value, middle_value); });
}
auto result = static_cast<node_id>(nodes_.size());
auto & node = nodes_.emplace_back();
node.value = std::move(*middle);
nodes_[result].children[0] = build_node_impl(begin, middle, next_axis(split_axis));
nodes_[result].children[1] = build_node_impl(std::next(middle), end, next_axis(split_axis));
return result;
}
template <typename T, std::size_t N, typename Data>
bool kdtree<T, N, Data>::insert_impl(value_type && value, node_id id, std::uint32_t split_axis)
{
if (id == 0 && nodes_.empty())
{
auto & node = nodes_.emplace_back();
node.value = std::move(value);
return true;
}
auto const & point = detail::get_point(value);
auto & node = nodes_[id];
auto const & node_point = detail::get_point(node.value);
if (node_point == point)
return false;
int child = (node.point[split_axis] < point[split_axis]) ? 0 : 1;
if (node.children[child] == null)
{
auto child_id = static_cast<node_id>(nodes_.size());
node.children[child] = child_id;
auto & child_node = nodes_.emplace_back();
child_node.value = std::move(value);
return true;
}
return insert_impl(std::move(value), node.children[child], next_axis(split_axis));
}
template <typename T, std::size_t N, typename Data>
kdtree<T, N, Data>::value_type const * kdtree<T, N, Data>::closest_impl(point_type const & target, scalar_type & best_distance_sqr, node_id id, std::uint32_t split_axis) const
{
kdtree<T, N, Data>::value_type const * result = nullptr;
auto const & node = nodes_[id];
auto const & node_point = detail::get_point(node.value);
if (math::make_min(best_distance_sqr, math::distance_sqr(node_point, target)))
result = &node.value;
auto delta = target[split_axis] - node_point[split_axis];
auto delta_sqr = math::sqr(delta);
if (node.children[0] != null && (delta < 0 || delta_sqr < best_distance_sqr))
if (auto new_result = closest_impl(target, best_distance_sqr, node.children[0], next_axis(split_axis)))
result = new_result;
if (node.children[1] != null && (delta >= 0 || delta_sqr <= best_distance_sqr))
if (auto new_result = closest_impl(target, best_distance_sqr, node.children[1], next_axis(split_axis)))
result = new_result;
return result;
}
}