Add a basic k-d tree implementation
This commit is contained in:
parent
c59b28e13f
commit
073ac16223
1 changed files with 227 additions and 0 deletions
227
libs/cg/include/psemek/cg/kdtree.hpp
Normal file
227
libs/cg/include/psemek/cg/kdtree.hpp
Normal file
|
|
@ -0,0 +1,227 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <psemek/math/point.hpp>
|
||||||
|
#include <psemek/math/interval.hpp>
|
||||||
|
#include <psemek/math/math.hpp>
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
namespace psemek::cg
|
||||||
|
{
|
||||||
|
|
||||||
|
namespace detail
|
||||||
|
{
|
||||||
|
|
||||||
|
template <typename Point, typename Data>
|
||||||
|
struct kdtree_value
|
||||||
|
{
|
||||||
|
Point point;
|
||||||
|
Data data;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Point>
|
||||||
|
Point & get_point(Point & point)
|
||||||
|
{
|
||||||
|
return point;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Point, typename Data>
|
||||||
|
Point & get_point(kdtree_value<Point, Data> & value)
|
||||||
|
{
|
||||||
|
return value.point;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Point, typename Data>
|
||||||
|
Point const & get_point(kdtree_value<Point, Data> const & value)
|
||||||
|
{
|
||||||
|
return value.point;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Point, typename Data>
|
||||||
|
struct kdtree_by_axis_comparator
|
||||||
|
{
|
||||||
|
std::uint32_t axis;
|
||||||
|
|
||||||
|
bool operator() (kdtree_value<Point, Data> const & left, kdtree_value<Point, Data> const & right)
|
||||||
|
{
|
||||||
|
return left.point[axis] < right.point[axis];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Point>
|
||||||
|
struct kdtree_by_axis_comparator<Point, void>
|
||||||
|
{
|
||||||
|
std::uint32_t axis;
|
||||||
|
|
||||||
|
bool operator() (Point const & left, Point const & right)
|
||||||
|
{
|
||||||
|
return left[axis] < right[axis];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, std::size_t N, typename Data = void>
|
||||||
|
struct kdtree
|
||||||
|
{
|
||||||
|
using scalar_type = T;
|
||||||
|
using point_type = math::point<T, N>;
|
||||||
|
using value_type = std::conditional_t<std::is_same_v<Data, void>,
|
||||||
|
point_type,
|
||||||
|
detail::kdtree_value<point_type, Data>
|
||||||
|
>;
|
||||||
|
|
||||||
|
kdtree() = default;
|
||||||
|
|
||||||
|
template <typename Iterator>
|
||||||
|
kdtree(Iterator begin, Iterator end);
|
||||||
|
|
||||||
|
bool empty() const { return nodes_.empty(); }
|
||||||
|
|
||||||
|
bool insert(value_type && value);
|
||||||
|
|
||||||
|
// TODO: implement
|
||||||
|
bool remove(point_type const & point) const;
|
||||||
|
|
||||||
|
// TODO: implement
|
||||||
|
// TODO: alternative non-const version that allows modifying value.data
|
||||||
|
value_type const * find(point_type const & point) const;
|
||||||
|
|
||||||
|
value_type const & closest(point_type const & target);
|
||||||
|
|
||||||
|
private:
|
||||||
|
|
||||||
|
using node_id = std::uint32_t;
|
||||||
|
|
||||||
|
static constexpr auto null = static_cast<node_id>(-1);
|
||||||
|
|
||||||
|
static std::uint32_t next_axis(std::uint32_t axis)
|
||||||
|
{
|
||||||
|
return (axis + 1) % N;
|
||||||
|
}
|
||||||
|
|
||||||
|
// NB: splitting dimension is implicit and depends on node's depth
|
||||||
|
struct node
|
||||||
|
{
|
||||||
|
value_type value;
|
||||||
|
node_id children[2] {null, null};
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<node> nodes_;
|
||||||
|
|
||||||
|
template <typename Iterator>
|
||||||
|
node_id build_node_impl(Iterator begin, Iterator end, std::uint32_t split_axis);
|
||||||
|
|
||||||
|
bool insert_impl(value_type && value, node_id id, std::uint32_t split_axis);
|
||||||
|
|
||||||
|
value_type const * closest_impl(point_type const & target, scalar_type & best_distance_sqr, node_id id, std::uint32_t split_axis) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, std::size_t N, typename Data>
|
||||||
|
template <typename Iterator>
|
||||||
|
kdtree<T, N, Data>::kdtree(Iterator begin, Iterator end)
|
||||||
|
{
|
||||||
|
build_node_impl(begin, end, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, std::size_t N, typename Data>
|
||||||
|
bool kdtree<T, N, Data>::insert(value_type && value)
|
||||||
|
{
|
||||||
|
return insert_impl(std::move(value), 0, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, std::size_t N, typename Data>
|
||||||
|
kdtree<T, N, Data>::value_type const & kdtree<T, N, Data>::closest(point_type const & target)
|
||||||
|
{
|
||||||
|
if (nodes_.empty())
|
||||||
|
throw util::exception("empty kdtree");
|
||||||
|
|
||||||
|
auto best_distance_sqr = math::limits<scalar_type>::max();
|
||||||
|
return *closest_impl(target, best_distance_sqr, 0, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, std::size_t N, typename Data>
|
||||||
|
template <typename Iterator>
|
||||||
|
kdtree<T, N, Data>::node_id kdtree<T, N, Data>::build_node_impl(Iterator begin, Iterator end, std::uint32_t split_axis)
|
||||||
|
{
|
||||||
|
if (begin == end)
|
||||||
|
return null;
|
||||||
|
|
||||||
|
detail::kdtree_by_axis_comparator<point_type, Data> comparator{split_axis};
|
||||||
|
|
||||||
|
auto middle = std::next(begin, (end - begin) / 2);
|
||||||
|
std::nth_element(begin, middle, end, comparator);
|
||||||
|
{
|
||||||
|
auto middle_value = *middle;
|
||||||
|
middle = std::partition(begin, end, [&](auto const & value){ return comparator(value, middle_value); });
|
||||||
|
}
|
||||||
|
|
||||||
|
auto result = static_cast<node_id>(nodes_.size());
|
||||||
|
auto & node = nodes_.emplace_back();
|
||||||
|
node.value = std::move(*middle);
|
||||||
|
|
||||||
|
nodes_[result].children[0] = build_node_impl(begin, middle, next_axis(split_axis));
|
||||||
|
nodes_[result].children[1] = build_node_impl(std::next(middle), end, next_axis(split_axis));
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, std::size_t N, typename Data>
|
||||||
|
bool kdtree<T, N, Data>::insert_impl(value_type && value, node_id id, std::uint32_t split_axis)
|
||||||
|
{
|
||||||
|
if (id == 0 && nodes_.empty())
|
||||||
|
{
|
||||||
|
auto & node = nodes_.emplace_back();
|
||||||
|
node.value = std::move(value);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto const & point = detail::get_point(value);
|
||||||
|
|
||||||
|
auto & node = nodes_[id];
|
||||||
|
auto const & node_point = detail::get_point(node.value);
|
||||||
|
if (node_point == point)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
int child = (node.point[split_axis] < point[split_axis]) ? 0 : 1;
|
||||||
|
|
||||||
|
if (node.children[child] == null)
|
||||||
|
{
|
||||||
|
auto child_id = static_cast<node_id>(nodes_.size());
|
||||||
|
node.children[child] = child_id;
|
||||||
|
|
||||||
|
auto & child_node = nodes_.emplace_back();
|
||||||
|
child_node.value = std::move(value);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return insert_impl(std::move(value), node.children[child], next_axis(split_axis));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, std::size_t N, typename Data>
|
||||||
|
kdtree<T, N, Data>::value_type const * kdtree<T, N, Data>::closest_impl(point_type const & target, scalar_type & best_distance_sqr, node_id id, std::uint32_t split_axis) const
|
||||||
|
{
|
||||||
|
kdtree<T, N, Data>::value_type const * result = nullptr;
|
||||||
|
|
||||||
|
auto const & node = nodes_[id];
|
||||||
|
auto const & node_point = detail::get_point(node.value);
|
||||||
|
|
||||||
|
if (math::make_min(best_distance_sqr, math::distance_sqr(node_point, target)))
|
||||||
|
result = &node.value;
|
||||||
|
|
||||||
|
auto delta = target[split_axis] - node_point[split_axis];
|
||||||
|
auto delta_sqr = math::sqr(delta);
|
||||||
|
|
||||||
|
if (node.children[0] != null && (delta < 0 || delta_sqr < best_distance_sqr))
|
||||||
|
if (auto new_result = closest_impl(target, best_distance_sqr, node.children[0], next_axis(split_axis)))
|
||||||
|
result = new_result;
|
||||||
|
|
||||||
|
if (node.children[1] != null && (delta >= 0 || delta_sqr <= best_distance_sqr))
|
||||||
|
if (auto new_result = closest_impl(target, best_distance_sqr, node.children[1], next_axis(split_axis)))
|
||||||
|
result = new_result;
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
Loading…
Add table
Reference in a new issue