diff --git a/libs/util/include/psemek/util/binary_heap.hpp b/libs/util/include/psemek/util/binary_heap.hpp new file mode 100644 index 00000000..43c1c5d5 --- /dev/null +++ b/libs/util/include/psemek/util/binary_heap.hpp @@ -0,0 +1,302 @@ +#pragma once + +#include + +#include +#include + +namespace psemek::util +{ + + namespace detail + { + + template + struct binary_heap_node + { + T value; + + // intentionally uninitialized + binary_heap_node * parent, * left, * right; + + template + binary_heap_node(Args && ... args) + : value(std::forward(args)...) + {} + }; + + template + struct binary_heap_node_handle + { + using value_type = T; + + binary_heap_node_handle() = default; + binary_heap_node_handle(binary_heap_node_handle &&) = default; + + binary_heap_node_handle(std::unique_ptr> p) + : p_{std::move(p)} + {} + + binary_heap_node_handle & operator = (binary_heap_node_handle &&) = default; + + bool empty() const { return !static_cast(p_); } + + explicit operator bool() const { return static_cast(p_); } + + T & value() const { return p_->value; } + + private: + std::unique_ptr> p_; + + template > + friend struct util::binary_heap; + }; + + } + + // Like std::priority_queue, but supports iterators + // that aren't invalidated on heap operations + template > + struct binary_heap + : private ebo_helper + { + struct iterator + { + detail::binary_heap_node * node; + + T const & operator *() const { return node->value; } + }; + + using value_type = T; + using const_iterator = iterator; + + using node_handle = detail::binary_heap_node_handle; + + binary_heap(Compare const & comp = Compare{}) + : ebo_helper(std::move(comp)) + {} + + binary_heap(Compare && comp) + : ebo_helper(std::move(comp)) + {} + + iterator min() const + { + return iterator{root_}; + } + + iterator insert(T const & value) + { + std::unique_ptr p(new node{value}); + return insert(node_handle{std::move(p)}); + } + + iterator insert(T && value) + { + std::unique_ptr p(new node{std::move(value)}); + return insert(node_handle{std::move(p)}); + } + + iterator insert(node_handle handle) + { + if (!root_) + { + root_ = handle.p_.release(); + last_ = root_; + return iterator{root_}; + } + + last_ = next_last(last_, handle.p_.release()); + sift_up(last_); + return + } + + void erase(iterator pos) + { + extract(pos); + } + + node_handle extract(iterator pos) + { + node * new_last = prev_last(last_); + + if (pos.node == last_) + { + tear_node(pos.node); + } + else + { + tear_node(last_); + replace_node(pos.node, last_); + pos.node->parent = nullptr; + pos.node->left = nullptr; + pos.node->right = nullptr; + } + + last_ = new_last; + + return node_handle{std::unique_ptr(pos.node)}; + } + + std::size_t size() const { return size_; } + bool empty() const { return size() == 0; } + + private: + using node = detail::binary_heap_node; + + node * root_ = nullptr; + node * last_ = nullptr; + std::size_t size_ = 0; + + bool compare(T const & x, T const & y) const + { + return ebo_helper::data()(x, y); + } + + // n should have no children + static void tear_node(node * n) + { + + if (n->parent) + { + if (n == n->parent->left) + n->parent->left = nullptr; + else + n->parent->right = nullptr; + } + n->parent = nullptr; + } + + // src should be a free node: no parent, no children + static void replace_node(node const * dst, node * src) + { + src->parent = dst->parent; + if (src->parent) + { + if (dst == dst->parent->left) + dst->parent->left = src; + else + dst->parent->right = src; + } + + src->left = dst->left; + if (src->left) + src->left->parent = src; + + src->right = dst->right; + if (src->right) + src->right->parent = src; + } + + // swap n with its parent + static void move_node_up(node * n) + { + if (n == n->parent->left) + { + auto a = n->left; + auto b = n->right; + auto c = n->parent->right; + auto p = n->parent; + auto g = p->parent; + + n->parent = g; + if (g) + { + if (g->left == p) + g->left = n; + else + g->right = n; + } + n->left = p; + n->right = c; + if (c) c->parent = n; + + p->parent = n; + p->left = a; + if (a) a->parent = p; + p->right = b; + if (b) b->parent = p; + } + else + { + auto a = n->parent->left; + auto b = n->left; + auto c = n->right; + auto p = n->parent; + auto g = p->parent; + + n->parent = g; + if (g) + { + if (g->left == p) + g->left = n; + else + g->right = n; + } + n->left = a; + if (a) a->parent = n; + n->right = p; + + p->parent = n; + p->left = b; + if (b) b->parent = p; + p->right = c; + if (c) c->parent = p; + } + } + + static void sift_up(node * n) + { + while (n->parent && compare(n->value, n->parent->value)) + move_node_up(n); + } + + static void sift_down(node * n) + { + while (n->left || n->right) + { + if (!n->right) + { + if (compare(n->value, n->left->value)) break; + move_node_up(n->left); + } + else if (!n->left) + { + if (compare(n->value, n->right->value)) break; + move_node_up(n->right); + } + else + { + if (compare(n->value, n->left->value) && compare(n->value, n->right->value)) break; + move_node_up(compare(n->left->value, n->right->value) ? n->left : n->right); + } + } + } + + static void sift(node * n) + { + if (n->parent && compare(n->value, n->parent->value)) + sift_up(n); + else + sift_down(n); + } + + static node * prev_last(node * n) + { + if (!n->parent) + return nullptr; + + while (n->parent && n == n->parent->left) + n = n->parent; + + if (n->parent) + n = n->parent->left; + + while (n->right) + n = n->right; + + return n; + } + }; + +}