diff --git a/libs/cg/include/psemek/cg/kdtree.hpp b/libs/cg/include/psemek/cg/kdtree.hpp new file mode 100644 index 00000000..13373606 --- /dev/null +++ b/libs/cg/include/psemek/cg/kdtree.hpp @@ -0,0 +1,227 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace psemek::cg +{ + + namespace detail + { + + template + struct kdtree_value + { + Point point; + Data data; + }; + + template + Point & get_point(Point & point) + { + return point; + } + + template + Point & get_point(kdtree_value & value) + { + return value.point; + } + + template + Point const & get_point(kdtree_value const & value) + { + return value.point; + } + + template + struct kdtree_by_axis_comparator + { + std::uint32_t axis; + + bool operator() (kdtree_value const & left, kdtree_value const & right) + { + return left.point[axis] < right.point[axis]; + } + }; + + template + struct kdtree_by_axis_comparator + { + std::uint32_t axis; + + bool operator() (Point const & left, Point const & right) + { + return left[axis] < right[axis]; + } + }; + + } + + template + struct kdtree + { + using scalar_type = T; + using point_type = math::point; + using value_type = std::conditional_t, + point_type, + detail::kdtree_value + >; + + kdtree() = default; + + template + kdtree(Iterator begin, Iterator end); + + bool empty() const { return nodes_.empty(); } + + bool insert(value_type && value); + + // TODO: implement + bool remove(point_type const & point) const; + + // TODO: implement + // TODO: alternative non-const version that allows modifying value.data + value_type const * find(point_type const & point) const; + + value_type const & closest(point_type const & target); + + private: + + using node_id = std::uint32_t; + + static constexpr auto null = static_cast(-1); + + static std::uint32_t next_axis(std::uint32_t axis) + { + return (axis + 1) % N; + } + + // NB: splitting dimension is implicit and depends on node's depth + struct node + { + value_type value; + node_id children[2] {null, null}; + }; + + std::vector nodes_; + + template + node_id build_node_impl(Iterator begin, Iterator end, std::uint32_t split_axis); + + bool insert_impl(value_type && value, node_id id, std::uint32_t split_axis); + + value_type const * closest_impl(point_type const & target, scalar_type & best_distance_sqr, node_id id, std::uint32_t split_axis) const; + }; + + template + template + kdtree::kdtree(Iterator begin, Iterator end) + { + build_node_impl(begin, end, 0); + } + + template + bool kdtree::insert(value_type && value) + { + return insert_impl(std::move(value), 0, 0); + } + + template + kdtree::value_type const & kdtree::closest(point_type const & target) + { + if (nodes_.empty()) + throw util::exception("empty kdtree"); + + auto best_distance_sqr = math::limits::max(); + return *closest_impl(target, best_distance_sqr, 0, 0); + } + + template + template + kdtree::node_id kdtree::build_node_impl(Iterator begin, Iterator end, std::uint32_t split_axis) + { + if (begin == end) + return null; + + detail::kdtree_by_axis_comparator comparator{split_axis}; + + auto middle = std::next(begin, (end - begin) / 2); + std::nth_element(begin, middle, end, comparator); + { + auto middle_value = *middle; + middle = std::partition(begin, end, [&](auto const & value){ return comparator(value, middle_value); }); + } + + auto result = static_cast(nodes_.size()); + auto & node = nodes_.emplace_back(); + node.value = std::move(*middle); + + nodes_[result].children[0] = build_node_impl(begin, middle, next_axis(split_axis)); + nodes_[result].children[1] = build_node_impl(std::next(middle), end, next_axis(split_axis)); + + return result; + } + + template + bool kdtree::insert_impl(value_type && value, node_id id, std::uint32_t split_axis) + { + if (id == 0 && nodes_.empty()) + { + auto & node = nodes_.emplace_back(); + node.value = std::move(value); + return true; + } + + auto const & point = detail::get_point(value); + + auto & node = nodes_[id]; + auto const & node_point = detail::get_point(node.value); + if (node_point == point) + return false; + + int child = (node.point[split_axis] < point[split_axis]) ? 0 : 1; + + if (node.children[child] == null) + { + auto child_id = static_cast(nodes_.size()); + node.children[child] = child_id; + + auto & child_node = nodes_.emplace_back(); + child_node.value = std::move(value); + return true; + } + + return insert_impl(std::move(value), node.children[child], next_axis(split_axis)); + } + + template + kdtree::value_type const * kdtree::closest_impl(point_type const & target, scalar_type & best_distance_sqr, node_id id, std::uint32_t split_axis) const + { + kdtree::value_type const * result = nullptr; + + auto const & node = nodes_[id]; + auto const & node_point = detail::get_point(node.value); + + if (math::make_min(best_distance_sqr, math::distance_sqr(node_point, target))) + result = &node.value; + + auto delta = target[split_axis] - node_point[split_axis]; + auto delta_sqr = math::sqr(delta); + + if (node.children[0] != null && (delta < 0 || delta_sqr < best_distance_sqr)) + if (auto new_result = closest_impl(target, best_distance_sqr, node.children[0], next_axis(split_axis))) + result = new_result; + + if (node.children[1] != null && (delta >= 0 || delta_sqr <= best_distance_sqr)) + if (auto new_result = closest_impl(target, best_distance_sqr, node.children[1], next_axis(split_axis))) + result = new_result; + + return result; + } + +}