From a27378a3a7904f87f56c049af78da6c8e2e03f3b Mon Sep 17 00:00:00 2001 From: lisyarus Date: Thu, 24 Aug 2023 16:24:47 +0300 Subject: [PATCH] Add util::hash_set/map with some tests --- libs/util/include/psemek/util/hash_table.hpp | 401 +++++++++++++++++++ libs/util/tests/hash_table.cpp | 184 +++++++++ 2 files changed, 585 insertions(+) create mode 100644 libs/util/include/psemek/util/hash_table.hpp create mode 100644 libs/util/tests/hash_table.cpp diff --git a/libs/util/include/psemek/util/hash_table.hpp b/libs/util/include/psemek/util/hash_table.hpp new file mode 100644 index 00000000..c0b4513a --- /dev/null +++ b/libs/util/include/psemek/util/hash_table.hpp @@ -0,0 +1,401 @@ +#pragma once + +#include +#include + +#include +#include + +namespace psemek::util +{ + + namespace detail + { + + template + struct hash_table_entry + { + std::size_t hash; + std::optional value; + }; + + template + struct hash_table_iterator + { + using value_type = T; + using pointer = T *; + using reference = T &; + using difference_type = std::ptrdiff_t; + using iterator_category = std::forward_iterator_tag; + + using entry_type = hash_table_entry>; + + hash_table_iterator(entry_type * p, entry_type * end) + : p_(p) + , end_(end) + { + advance(); + } + + T & operator *() const + { + return *(p_->value); + } + + T * operator ->() const + { + return std::addressof(*(p_->value)); + } + + hash_table_iterator & operator ++() + { + ++p_; + advance(); + return *this; + } + + hash_table_iterator operator ++(int) + { + auto copy = *this; + this->operator++(); + return copy; + } + + friend bool operator == (hash_table_iterator const & it1, hash_table_iterator const & it2) + { + return it1.p_ == it2.p_; + } + + hash_table_iterator as_const() const + { + return {p_, end_}; + } + + private: + entry_type * p_; + entry_type * end_; + + void advance() + { + while (!(p_->value) && p_ != end_) + ++p_; + } + }; + + template + struct hash_table_storage + { + std::unique_ptr[]> table; + std::size_t capacity = 0; + + void reset(std::size_t capacity) + { + table.reset(new hash_table_entry[capacity]); + this->capacity = capacity; + } + + util::span> entries() + { + return {table.get(), table.get() + capacity}; + } + + hash_table_iterator iterator(std::size_t index) const + { + return {table.get() + index, table.get() + capacity}; + } + }; + + template + struct hash_table_impl + : Hash, Equal + { + template + hash_table_impl(H && h, K && k) + : Hash(std::forward(h)) + , Equal(std::forward(k)) + {} + + Hash const & hash() const { return *this; } + Equal const & equal() const { return *this; } + + template + std::pair, bool> insert(H && value) + { + ensure_capacity_for(size_ + 1); + std::size_t hash = this->hash()(value); + return insert_impl(std::forward(value), hash); + } + + template + hash_table_iterator find(Key const & key) const + { + std::size_t hash = this->hash()(key); + return find_impl(key, hash); + } + + void clear() + { + for (auto & entry : storage_.entries()) + entry.value.reset(); + size_ = 0; + } + + hash_table_iterator begin() const + { + return storage_.iterator(0); + } + + hash_table_iterator end() const + { + return storage_.iterator(storage_.capacity); + } + + std::size_t size() const + { + return size_; + } + + std::size_t capacity() const + { + return storage_.capacity; + } + + private: + hash_table_storage storage_; + std::size_t size_ = 0; + + static std::size_t min_capacity_for_size(std::size_t size) + { + // Ensure at most 0.5 load factor + return 2 * size; + } + + static std::size_t find_capacity(std::size_t current_capacity, std::size_t min_capacity) + { + current_capacity = std::max(current_capacity, std::size_t(16)); + while (current_capacity < min_capacity) + current_capacity *= 2; + return current_capacity; + } + + void ensure_capacity_for(std::size_t size) + { + std::size_t capacity = min_capacity_for_size(size); + if (storage_.capacity < capacity) + reallocate(find_capacity(storage_.capacity, capacity)); + } + + void reallocate(std::size_t capacity) + { + hash_table_storage storage; + storage.reset(capacity); + + std::swap(storage_, storage); + + size_ = 0; + + for (hash_table_entry & entry : storage.entries()) + { + if (entry.value) + { + insert_impl(std::move(*entry.value), entry.hash); + entry.value.reset(); + } + } + } + + std::size_t probe_index(std::size_t hash, std::size_t i) const + { + return (hash + (i * (i + 1)) / 2) % storage_.capacity; + } + + template + std::pair, bool> insert_impl(H && value, std::size_t hash) + { + std::size_t i = 0; + while (true) + { + std::size_t index = probe_index(hash, i); + auto & entry = storage_.table[index]; + if (!entry.value) + { + entry.value.emplace(std::forward(value)); + entry.hash = hash; + ++size_; + return {storage_.iterator(index), true}; + } + else if (entry.hash == hash && equal()(value, *entry.value)) + { + return {storage_.iterator(index), false}; + } + else + ++i; + } + } + + template + hash_table_iterator find_impl(Key const & key, std::size_t hash) const + { + std::size_t i = 0; + while (true) + { + std::size_t index = probe_index(hash, i); + auto & entry = storage_.table[index]; + if (!entry.value) + { + return storage_.iterator(storage_.capacity); + } + else if (entry.hash == hash && equal()(key, *entry.value)) + { + return storage_.iterator(index); + } + else + ++i; + } + } + }; + + template + struct pair_hash + : Hash + { + pair_hash(Hash const & hash) + : Hash(hash) + {} + + std::size_t operator()(Key const & key) const + { + return static_cast(*this)(key); + } + + std::size_t operator()(std::pair const & pair) const + { + return static_cast(*this)(pair.first); + } + }; + + template + struct pair_equal + : Equal + { + pair_equal(Equal const & equal) + : Equal(equal) + {} + + bool operator()(Key const & key1, std::pair const & pair2) const + { + return static_cast(*this)(key1, pair2.first); + } + + bool operator()(std::pair const & pair1, Key const & key2) const + { + return static_cast(*this)(pair1.first, key2); + } + + bool operator()(std::pair const & pair1, std::pair const & pair2) const + { + return static_cast(*this)(pair1.first, pair2.first); + } + }; + + } + + template , typename Equal = std::equal_to> + struct hash_set + { + using iterator = detail::hash_table_iterator; + + hash_set(Hash const & hash = {}, Equal const & equal = {}) + : impl_(hash, equal) + {} + + std::pair insert(T const & value) + { + auto result = impl_.insert(value); + return {result.first.as_const(), result.second}; + } + + std::pair insert(T && value) + { + auto result = impl_.insert(std::move(value)); + return {result.first.as_const(), result.second}; + } + + iterator find(T const & value) const + { + return impl_.find(value).as_const(); + } + + iterator begin() const + { + return impl_.begin().as_const(); + } + + iterator end() const + { + return impl_.end().as_const(); + } + + void clear() + { + impl_.clear(); + } + + std::size_t size() const + { + return impl_.size(); + } + + private: + detail::hash_table_impl impl_; + }; + + template , typename KeyEqual = std::equal_to> + struct hash_map + { + using iterator = detail::hash_table_iterator>; + + hash_map(Hash const & hash = {}, KeyEqual const & equal = {}) + : impl_(hash, equal) + {} + + std::pair insert(std::pair const & value) + { + return impl_.insert(value); + } + + std::pair insert(std::pair && value) + { + return impl_.insert(std::move(value)); + } + + iterator find(Key const & key) const + { + return impl_.find(key); + } + + iterator begin() const + { + return impl_.begin(); + } + + iterator end() const + { + return impl_.end(); + } + + void clear() + { + impl_.clear(); + } + + std::size_t size() const + { + return impl_.size(); + } + + private: + detail::hash_table_impl, detail::pair_hash, detail::pair_equal> impl_; + }; + +} diff --git a/libs/util/tests/hash_table.cpp b/libs/util/tests/hash_table.cpp new file mode 100644 index 00000000..110847f2 --- /dev/null +++ b/libs/util/tests/hash_table.cpp @@ -0,0 +1,184 @@ +#include + +#include +#include + +#include +#include +#include + +using namespace psemek; +using namespace psemek::util; + +test_case(util_hash__set_benchmark) +{ + random::generator rng; + std::vector values; + int const count = 1024 * 1024; + for (int i = 0; i < count; ++i) + values.push_back(i); + std::shuffle(values.begin(), values.end(), rng); + + test_profile(hash_set_total) + { + hash_set set; + + test_profile(hash_set_insert) + { + for (auto value : values) + set.insert(value); + expect_equal(set.size(), count); + } + + test_profile(hash_set_iterate) + { + int size = 0; + for (auto value : set) + { + expect(0 <= value && value < count); + ++size; + } + expect_equal(size, count); + } + + test_profile(hash_set_find) + { + for (auto value : values) + { + auto it = set.find(value); + expect(it != set.end()); + expect_equal(*it, value); + } + } + + test_profile(hash_set_clear) + { + set.clear(); + } + } + + test_profile(unordered_set_total) + { + std::unordered_set set; + + test_profile(unordered_set_insert) + { + for (auto value : values) + set.insert(value); + expect_equal(set.size(), count); + } + + test_profile(unordered_set_iterate) + { + int size = 0; + for (auto value : set) + { + expect(0 <= value && value < count); + ++size; + } + expect_equal(size, count); + } + + test_profile(unordered_set_find) + { + for (auto value : values) + { + auto it = set.find(value); + expect(it != set.end()); + expect_equal(*it, value); + } + } + + test_profile(unordered_set_clear) + { + set.clear(); + } + } +} + +test_case(util_hash__map_benchmark) +{ + random::generator rng; + std::vector keys; + int const count = 1024 * 1024; + for (int i = 0; i < count; ++i) + keys.push_back(i); + std::shuffle(keys.begin(), keys.end(), rng); + + test_profile(hash_map_total) + { + hash_map map; + + test_profile(hash_map_insert) + { + for (auto key : keys) + map.insert({key, -key}); + expect_equal(map.size(), count); + } + + test_profile(hash_map_iterate) + { + int size = 0; + for (auto const & pair : map) + { + expect(0 <= pair.first && pair.first < count); + expect_equal(pair.second, -pair.first); + ++size; + } + expect_equal(size, count); + } + + test_profile(hash_map_find) + { + for (auto key : keys) + { + auto it = map.find(key); + expect(map.find(key) != map.end()); + expect_equal(it->second, -key); + } + } + + test_profile(hash_map_clear) + { + map.clear(); + } + } + + test_profile(unordered_map_total) + { + std::unordered_map map; + + test_profile(unordered_map_insert) + { + for (auto key : keys) + map.insert({key, -key}); + } + + test_profile(unordered_map_iterate) + { + int size = 0; + for (auto const & pair : map) + { + expect(0 <= pair.first && pair.first < count); + expect_equal(pair.second, -pair.first); + ++size; + } + expect_equal(size, count); + } + + test_profile(unordered_map_find) + { + for (auto key : keys) + { + auto it = map.find(key); + expect(map.find(key) != map.end()); + expect_equal(it->second, -key); + } + } + + test_profile(unordered_map_clear) + { + map.clear(); + } + } +}