Add a basic k-d tree implementation
This commit is contained in:
parent
c59b28e13f
commit
073ac16223
1 changed files with 227 additions and 0 deletions
227
libs/cg/include/psemek/cg/kdtree.hpp
Normal file
227
libs/cg/include/psemek/cg/kdtree.hpp
Normal file
|
|
@ -0,0 +1,227 @@
|
|||
#pragma once
|
||||
|
||||
#include <psemek/math/point.hpp>
|
||||
#include <psemek/math/interval.hpp>
|
||||
#include <psemek/math/math.hpp>
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
||||
namespace psemek::cg
|
||||
{
|
||||
|
||||
namespace detail
|
||||
{
|
||||
|
||||
template <typename Point, typename Data>
|
||||
struct kdtree_value
|
||||
{
|
||||
Point point;
|
||||
Data data;
|
||||
};
|
||||
|
||||
template <typename Point>
|
||||
Point & get_point(Point & point)
|
||||
{
|
||||
return point;
|
||||
}
|
||||
|
||||
template <typename Point, typename Data>
|
||||
Point & get_point(kdtree_value<Point, Data> & value)
|
||||
{
|
||||
return value.point;
|
||||
}
|
||||
|
||||
template <typename Point, typename Data>
|
||||
Point const & get_point(kdtree_value<Point, Data> const & value)
|
||||
{
|
||||
return value.point;
|
||||
}
|
||||
|
||||
template <typename Point, typename Data>
|
||||
struct kdtree_by_axis_comparator
|
||||
{
|
||||
std::uint32_t axis;
|
||||
|
||||
bool operator() (kdtree_value<Point, Data> const & left, kdtree_value<Point, Data> const & right)
|
||||
{
|
||||
return left.point[axis] < right.point[axis];
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Point>
|
||||
struct kdtree_by_axis_comparator<Point, void>
|
||||
{
|
||||
std::uint32_t axis;
|
||||
|
||||
bool operator() (Point const & left, Point const & right)
|
||||
{
|
||||
return left[axis] < right[axis];
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
template <typename T, std::size_t N, typename Data = void>
|
||||
struct kdtree
|
||||
{
|
||||
using scalar_type = T;
|
||||
using point_type = math::point<T, N>;
|
||||
using value_type = std::conditional_t<std::is_same_v<Data, void>,
|
||||
point_type,
|
||||
detail::kdtree_value<point_type, Data>
|
||||
>;
|
||||
|
||||
kdtree() = default;
|
||||
|
||||
template <typename Iterator>
|
||||
kdtree(Iterator begin, Iterator end);
|
||||
|
||||
bool empty() const { return nodes_.empty(); }
|
||||
|
||||
bool insert(value_type && value);
|
||||
|
||||
// TODO: implement
|
||||
bool remove(point_type const & point) const;
|
||||
|
||||
// TODO: implement
|
||||
// TODO: alternative non-const version that allows modifying value.data
|
||||
value_type const * find(point_type const & point) const;
|
||||
|
||||
value_type const & closest(point_type const & target);
|
||||
|
||||
private:
|
||||
|
||||
using node_id = std::uint32_t;
|
||||
|
||||
static constexpr auto null = static_cast<node_id>(-1);
|
||||
|
||||
static std::uint32_t next_axis(std::uint32_t axis)
|
||||
{
|
||||
return (axis + 1) % N;
|
||||
}
|
||||
|
||||
// NB: splitting dimension is implicit and depends on node's depth
|
||||
struct node
|
||||
{
|
||||
value_type value;
|
||||
node_id children[2] {null, null};
|
||||
};
|
||||
|
||||
std::vector<node> nodes_;
|
||||
|
||||
template <typename Iterator>
|
||||
node_id build_node_impl(Iterator begin, Iterator end, std::uint32_t split_axis);
|
||||
|
||||
bool insert_impl(value_type && value, node_id id, std::uint32_t split_axis);
|
||||
|
||||
value_type const * closest_impl(point_type const & target, scalar_type & best_distance_sqr, node_id id, std::uint32_t split_axis) const;
|
||||
};
|
||||
|
||||
template <typename T, std::size_t N, typename Data>
|
||||
template <typename Iterator>
|
||||
kdtree<T, N, Data>::kdtree(Iterator begin, Iterator end)
|
||||
{
|
||||
build_node_impl(begin, end, 0);
|
||||
}
|
||||
|
||||
template <typename T, std::size_t N, typename Data>
|
||||
bool kdtree<T, N, Data>::insert(value_type && value)
|
||||
{
|
||||
return insert_impl(std::move(value), 0, 0);
|
||||
}
|
||||
|
||||
template <typename T, std::size_t N, typename Data>
|
||||
kdtree<T, N, Data>::value_type const & kdtree<T, N, Data>::closest(point_type const & target)
|
||||
{
|
||||
if (nodes_.empty())
|
||||
throw util::exception("empty kdtree");
|
||||
|
||||
auto best_distance_sqr = math::limits<scalar_type>::max();
|
||||
return *closest_impl(target, best_distance_sqr, 0, 0);
|
||||
}
|
||||
|
||||
template <typename T, std::size_t N, typename Data>
|
||||
template <typename Iterator>
|
||||
kdtree<T, N, Data>::node_id kdtree<T, N, Data>::build_node_impl(Iterator begin, Iterator end, std::uint32_t split_axis)
|
||||
{
|
||||
if (begin == end)
|
||||
return null;
|
||||
|
||||
detail::kdtree_by_axis_comparator<point_type, Data> comparator{split_axis};
|
||||
|
||||
auto middle = std::next(begin, (end - begin) / 2);
|
||||
std::nth_element(begin, middle, end, comparator);
|
||||
{
|
||||
auto middle_value = *middle;
|
||||
middle = std::partition(begin, end, [&](auto const & value){ return comparator(value, middle_value); });
|
||||
}
|
||||
|
||||
auto result = static_cast<node_id>(nodes_.size());
|
||||
auto & node = nodes_.emplace_back();
|
||||
node.value = std::move(*middle);
|
||||
|
||||
nodes_[result].children[0] = build_node_impl(begin, middle, next_axis(split_axis));
|
||||
nodes_[result].children[1] = build_node_impl(std::next(middle), end, next_axis(split_axis));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T, std::size_t N, typename Data>
|
||||
bool kdtree<T, N, Data>::insert_impl(value_type && value, node_id id, std::uint32_t split_axis)
|
||||
{
|
||||
if (id == 0 && nodes_.empty())
|
||||
{
|
||||
auto & node = nodes_.emplace_back();
|
||||
node.value = std::move(value);
|
||||
return true;
|
||||
}
|
||||
|
||||
auto const & point = detail::get_point(value);
|
||||
|
||||
auto & node = nodes_[id];
|
||||
auto const & node_point = detail::get_point(node.value);
|
||||
if (node_point == point)
|
||||
return false;
|
||||
|
||||
int child = (node.point[split_axis] < point[split_axis]) ? 0 : 1;
|
||||
|
||||
if (node.children[child] == null)
|
||||
{
|
||||
auto child_id = static_cast<node_id>(nodes_.size());
|
||||
node.children[child] = child_id;
|
||||
|
||||
auto & child_node = nodes_.emplace_back();
|
||||
child_node.value = std::move(value);
|
||||
return true;
|
||||
}
|
||||
|
||||
return insert_impl(std::move(value), node.children[child], next_axis(split_axis));
|
||||
}
|
||||
|
||||
template <typename T, std::size_t N, typename Data>
|
||||
kdtree<T, N, Data>::value_type const * kdtree<T, N, Data>::closest_impl(point_type const & target, scalar_type & best_distance_sqr, node_id id, std::uint32_t split_axis) const
|
||||
{
|
||||
kdtree<T, N, Data>::value_type const * result = nullptr;
|
||||
|
||||
auto const & node = nodes_[id];
|
||||
auto const & node_point = detail::get_point(node.value);
|
||||
|
||||
if (math::make_min(best_distance_sqr, math::distance_sqr(node_point, target)))
|
||||
result = &node.value;
|
||||
|
||||
auto delta = target[split_axis] - node_point[split_axis];
|
||||
auto delta_sqr = math::sqr(delta);
|
||||
|
||||
if (node.children[0] != null && (delta < 0 || delta_sqr < best_distance_sqr))
|
||||
if (auto new_result = closest_impl(target, best_distance_sqr, node.children[0], next_axis(split_axis)))
|
||||
result = new_result;
|
||||
|
||||
if (node.children[1] != null && (delta >= 0 || delta_sqr <= best_distance_sqr))
|
||||
if (auto new_result = closest_impl(target, best_distance_sqr, node.children[1], next_axis(split_axis)))
|
||||
result = new_result;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue