Add generic ndtree implementation

This commit is contained in:
Nikita Lisitsa 2022-06-29 12:13:10 +03:00
parent fa27ef4d79
commit 94cb475ecc

View file

@ -0,0 +1,227 @@
#pragma once
#include <psemek/geom/point.hpp>
#include <psemek/geom/box.hpp>
#include <psemek/geom/contains.hpp>
#include <memory>
#include <optional>
namespace psemek::cg
{
template <typename T, std::size_t N, typename Data>
struct ndtree;
template <typename T, typename Data>
using quadtree = ndtree<T, 2, Data>;
template <typename T, typename Data>
using octree = ndtree<T, 3, Data>;
template <typename T, std::size_t N, typename Data>
struct ndtree
{
public:
using point_type = geom::point<T, N>;
using box_type = geom::box<T, N>;
struct value_type
{
point_type point;
Data data;
template <typename A, typename B>
value_type(A && point, B && data)
: point(std::forward<A>(point))
, data(std::forward<B>(data))
{}
};
struct node;
using node_ptr = std::unique_ptr<node>;
using node_ref = node *;
struct node
{
node_ptr children[1 << N];
std::optional<value_type> value;
node() = default;
template <typename A, typename B>
node(A && point, B && data)
: value(std::in_place, std::forward<A>(point), std::forward<B>(data))
{}
bool leaf() const
{
return std::all_of(std::begin(children), std::end(children), [](node_ptr const & p){ return !static_cast<bool>(p); });
}
bool full() const
{
return std::all_of(std::begin(children), std::end(children), [](node_ptr const & p){ return static_cast<bool>(p); });
}
};
template <typename A, typename B>
std::pair<node_ref, bool> insert(A && point, B && data)
{
if (!root)
{
root = std::make_unique<node>(std::forward<A>(point), std::forward<B>(data));
for (std::size_t d = 0; d < N; ++d)
{
root_bbox[d].min = root->value->point[d];
root_bbox[d].max = root->value->point[d] + T{1};
}
return {root.get(), true};
}
if (!geom::half_open_contains(root_bbox, point))
{
while (!geom::half_open_contains(root_bbox, point))
extend_root(point);
}
node * current_node = root.get();
box_type current_bbox = root_bbox;
while (true)
{
if (current_node->leaf())
{
if (current_node->value)
{
if (current_node->value->point == point)
{
return {current_node, false};
}
else
{
split(current_node, current_bbox);
assert(!current_node->value);
}
}
else
{
current_node->value.emplace(std::forward<A>(point), std::forward<B>(data));
return {current_node, true};
}
}
std::size_t index = 0;
for (std::size_t d = 0; d < N; ++d)
{
auto const c = current_bbox[d].center();
if (point[d] < c)
{
current_bbox[d].max = c;
}
else
{
index |= (1 << d);
current_bbox[d].min = c;
}
}
auto & child = current_node->children[index];
if (!child)
{
child = std::make_unique<node>(std::forward<A>(point), std::forward<B>(data));
return {child.get(), true};
}
current_node = child.get();
}
}
template <typename Visitor>
void traverse(Visitor && visitor) const
{
if (!root) return;
traverse_impl(std::forward<Visitor>(visitor), root.get(), root_bbox);
}
template <typename Visitor>
void traverse(Visitor && visitor, node_ref current_node, box_type const & bbox) const
{
if (!current_node) return;
traverse_impl(std::forward<Visitor>(visitor), current_node, bbox);
}
private:
node_ptr root;
geom::box<T, N> root_bbox;
void extend_root(point_type const & point)
{
std::size_t index_in_new_root = 0;
for (std::size_t d = 0; d < N; ++d)
{
if (point[d] >= root_bbox[d].min)
root_bbox[d].max += root_bbox[d].length();
else
{
root_bbox[d].min -= root_bbox[d].length();
index_in_new_root |= (1 << d);
}
}
auto new_root = std::make_unique<node>();
new_root->children[index_in_new_root] = std::move(root);
root = std::move(new_root);
}
static void split(node * n, box_type const & bbox)
{
assert(n->value);
std::size_t index_in_parent = 0;
for (std::size_t d = 0; d < N; ++d)
{
if (n->value->point[d] >= bbox[d].center())
index_in_parent |= (1 << d);
}
auto & child = n->children[index_in_parent];
child = std::make_unique<node>(n->value->point, std::move(n->value->data));
n->value = std::nullopt;
}
template <typename Visitor>
static void traverse_impl(Visitor && visitor, node * current_node, box_type const & box)
{
std::size_t children_count = 0;
for (std::size_t i = 0; i < (1 << N); ++i)
if (current_node->children[i])
++children_count;
if (visitor(*current_node, box))
{
for (std::size_t i = 0; i < (1 << N); ++i)
{
if (!current_node->children[i]) continue;
box_type child_box;
for (std::size_t d = 0; d < N; ++d)
{
auto const c = box[d].center();
if (i & (1 << d))
child_box[d] = {c, box[d].max};
else
child_box[d] = {box[d].min, c};
}
traverse_impl(visitor, current_node->children[i].get(), child_box);
}
}
}
};
}