From fcfd7138d15d72e455305b2d232c792b3c773dce Mon Sep 17 00:00:00 2001 From: lisyarus Date: Tue, 17 May 2022 18:22:31 +0300 Subject: [PATCH] Add lru cache implementation & tests --- libs/util/include/psemek/util/lru_cache.hpp | 317 ++++++++++++++++++-- libs/util/tests/lru_cache.cpp | 158 ++++++++++ 2 files changed, 448 insertions(+), 27 deletions(-) create mode 100644 libs/util/tests/lru_cache.cpp diff --git a/libs/util/include/psemek/util/lru_cache.hpp b/libs/util/include/psemek/util/lru_cache.hpp index 2acecb47..a2dd2e70 100644 --- a/libs/util/include/psemek/util/lru_cache.hpp +++ b/libs/util/include/psemek/util/lru_cache.hpp @@ -1,59 +1,322 @@ #pragma once +#include +#include + #include #include namespace psemek::util { - template + // No function except `insert` touches the accessed elements + // To mark an element as accessed, one must call `touch` manually + // Calling touch from inside the `removable` predicate or while iterating is UB + template struct lru_cache { - bool empty() const; - std::size_t count(T const & value) const; + using key_type = Key; + using mapped_type = Mapped; + using value_type = std::pair; - T pop(); - void push(T value); + using iterator = typename std::list::iterator; + using const_iterator = typename std::list::const_iterator; + + // Predicate must be callable with arguments of type (Key const, Value) + lru_cache(std::size_t max_size, Predicate removable = Predicate{}); + + // Find an element without touching it + iterator find(Key const & key); + const_iterator find(Key const & key) const; + + bool contains(Key const & key) const; + + // Access an element without touching it + // Throws if the key is not present + Mapped & at(Key const & key); + Mapped const & at(Key const & key) const; + + // Insert an element (automatically touches it) + // N.B. the element might be immediately deleted by a shrink + void insert(Key const & key, Mapped && mapped); + void insert(Key const & key, Mapped const & mapped); + + void erase(const_iterator it); + void erase(Key const & key); + + void touch(const_iterator it); + void touch(Key const & key); + + iterator begin(); + iterator end(); + + const_iterator begin() const; + const_iterator end() const; + + const_iterator cbegin() const; + const_iterator cend() const; + + bool empty() const; + + std::size_t size() const; + + std::size_t max_size() const; + std::size_t max_size(std::size_t new_max_size); + + void shrink(); + + Predicate & removable(); + Predicate const & removable() const; + + bool removable(const_iterator it) const; + bool removable(Key const & key, Mapped const & value) const; private: - std::list queue_; - std::unordered_map::iterator> queue_it_map_; + std::size_t max_size_; + Predicate removable_; + + // N.B.: keys are stored twice - one in map_, one in list_ + // list_ contains non-removable items first, then all removable items + std::list list_; + iterator removable_begin_ = list_.end(); + std::unordered_map map_; + + iterator find_safe(Key const & key); + const_iterator find_safe(Key const & key) const; + + void insert(iterator it); }; - template - bool lru_cache::empty() const + template + lru_cache::lru_cache(std::size_t max_size, Predicate removable) + : max_size_(max_size) + , removable_(std::move(removable)) + {} + + template + auto lru_cache::find(Key const & key) -> iterator { - return queue_.empty(); + if (auto it = map_.find(key); it != map_.end()) + return it->second; + return list_.end(); } - template - std::size_t lru_cache::count(T const & value) const + template + auto lru_cache::find(Key const & key) const -> const_iterator { - return queue_it_map_.count(value); + return const_cast(*this).find(key); } - template - T lru_cache::pop() + template + bool lru_cache::contains(Key const & key) const { - auto value = std::move(queue_.front()); - queue_.pop_front(); - queue_it_map_.erase(value); - return value; + return map_.contains(key); } - template - void lru_cache::push(T value) + template + auto lru_cache::at(Key const & key) -> Mapped & { - auto it = queue_it_map_.find(value); - if (it == queue_it_map_.end()) + return find_safe(key)->second; + } + + template + auto lru_cache::at(Key const & key) const -> Mapped const & + { + return const_cast(*this).at(key); + } + + template + void lru_cache::insert(Key const & key, Mapped && mapped) + { + list_.emplace_front(key, std::move(mapped)); + insert(list_.begin()); + } + + template + void lru_cache::insert(Key const & key, Mapped const & mapped) + { + list_.emplace_front(key, mapped); + insert(list_.begin()); + } + + template + void lru_cache::erase(const_iterator it) + { + if (it == removable_begin_) + ++removable_begin_; + map_.erase(it->first); + list_.erase(it); + } + + template + void lru_cache::erase(Key const & key) + { + erase(find_safe(key)); + } + + template + void lru_cache::touch(const_iterator it) + { + bool const removable = removable_(it->first, it->second); + list_.splice(removable ? removable_begin_ : list_.begin(), list_, it); + if (removable && it != removable_begin_) + --removable_begin_; + } + + template + void lru_cache::touch(Key const & key) + { + touch(find_safe(key)); + } + + template + auto lru_cache::begin() -> iterator + { + return list_.begin(); + } + + template + auto lru_cache::end() -> iterator + { + return list_.end(); + } + + template + auto lru_cache::begin() const -> const_iterator + { + return list_.begin(); + } + + template + auto lru_cache::end() const -> const_iterator + { + return list_.end(); + } + + template + auto lru_cache::cbegin() const -> const_iterator + { + return list_.begin(); + } + + template + auto lru_cache::cend() const -> const_iterator + { + return list_.end(); + } + + template + bool lru_cache::empty() const + { + return list_.empty(); + } + + template + std::size_t lru_cache::size() const + { + return list_.size(); + } + + template + std::size_t lru_cache::max_size() const + { + return max_size_; + } + + template + std::size_t lru_cache::max_size(std::size_t new_max_size) + { + std::swap(max_size_, new_max_size); + shrink(); + return new_max_size; + } + + template + void lru_cache::shrink() + { + while (size() > max_size()) { - queue_.push_back(value); - queue_it_map_[std::move(value)] = queue_.back(); + if (removable_begin_ != list_.end()) + { + bool const removable = removable_(list_.back().first, list_.back().second); + if (removable) + { + erase(std::prev(list_.end())); + } + else + { + // Move the item to non-removable sublist + // We will have one potentially removable item less in the next loop iteration + if (removable_begin_ == std::prev(list_.end())) + ++removable_begin_; + else + list_.splice(removable_begin_, list_, std::prev(list_.end())); + } + } + else + { + // No potentially removable elements, try non-removable ones + if (list_.begin() == removable_begin_) + break; + + auto it = std::prev(removable_begin_); + if (!removable_(it->first, it->second)) + break; + + erase(it); + } } - else + } + + template + Predicate & lru_cache::removable() + { + return removable_; + } + + template + Predicate const & lru_cache::removable() const + { + return removable_; + } + + template + bool lru_cache::removable(const_iterator it) const + { + return removable_(it->first, it->second); + } + + template + bool lru_cache::removable(Key const & key, Mapped const & value) const + { + return removable_(key, value); + } + + template + auto lru_cache::find_safe(Key const & key) -> iterator + { + if (auto it = find(key); it != list_.end()) + return it; + throw key_error{key}; + } + + template + auto lru_cache::find_safe(Key const & key) const -> const_iterator + { + return const_cast(*this).find_safe(key); + } + + template + void lru_cache::insert(iterator it) + { + bool const removable = removable_(it->first, it->second); + if (removable) { - queue_.splice(queue_.begin(), queue_, *it); + list_.splice(removable_begin_, list_, it); + it = --removable_begin_; } + map_[it->first] = it; + shrink(); } } diff --git a/libs/util/tests/lru_cache.cpp b/libs/util/tests/lru_cache.cpp new file mode 100644 index 00000000..9759455e --- /dev/null +++ b/libs/util/tests/lru_cache.cpp @@ -0,0 +1,158 @@ +#include + +#include +#include + +#include + +using namespace psemek::util; + +test_case(util_lru__cache_empty) +{ + lru_cache c(64); + expect_equal(c.size(), 0); + expect(c.empty()); +} + +test_case(util_lru__cache_insert) +{ + lru_cache c(64); + expect_equal(c.size(), 0); + expect(c.empty()); + + for (int key = 0; key < 32; ++key) + { + int const value = key * key; + c.insert(key, value); + expect_lequal(c.size(), key + 1); + expect(c.contains(key)); + expect(c.find(key) != c.end()); + expect_equal(c.at(key), value); + expect_equal(c.find(key)->first, key); + expect_equal(c.find(key)->second, value); + } +} + +test_case(util_lru__cache_iterate) +{ + lru_cache c(64); + expect_equal(c.size(), 0); + expect(c.empty()); + + std::unordered_map inserted; + + for (int key = 0; key < 32; ++key) + { + int const value = key * key; + c.insert(key, value); + inserted[key] = value; + } + + for (auto const & p : c) + { + expect(inserted.contains(p.first)); + expect_equal(inserted.at(p.first), p.second); + } + + for (auto const & p : inserted) + { + expect(c.contains(p.first)); + expect_equal(c.at(p.first), p.second); + } +} + +test_case(util_lru__cache_shrink) +{ + lru_cache c(16); + expect_equal(c.size(), 0); + expect(c.empty()); + expect_equal(c.max_size(), 16); + + std::unordered_map inserted; + + for (int key = 0; key < 32; ++key) + { + int const value = key * key; + c.insert(key, value); + inserted[key] = value; + expect_lequal(c.size(), key + 1); + expect(c.contains(key)); + expect(c.find(key) != c.end()); + expect_equal(c.at(key), value); + expect_equal(c.find(key)->first, key); + expect_equal(c.find(key)->second, value); + + expect_lequal(c.size(), c.max_size()); + } + + for (int key = 0; key < 16; ++key) + expect(!c.contains(key)); + + for (int key = 16; key < 32; ++key) + { + expect(c.contains(key)); + expect_equal(c.at(key), inserted.at(key)); + } + + for (auto const & p : c) + { + expect(inserted.contains(p.first)); + expect_equal(inserted.at(p.first), p.second); + } +} + +test_case(util_lru__cache_touch) +{ + lru_cache c(64); + expect_equal(c.size(), 0); + expect(c.empty()); + + std::unordered_map inserted; + + for (int key = 0; key < 32; ++key) + { + int const value = key * key; + c.insert(key, value); + inserted[key] = value; + expect(c.find(key) == c.begin()); + } + + for (int key = 0; key < 32; ++key) + { + c.touch(key); + expect(c.find(key) == c.begin()); + } +} + +test_case(util_lru__cache_removable) +{ + auto removable = [](int key, int){ return key % 2 == 1; }; + + lru_cache> c(16, removable); + expect_equal(c.size(), 0); + expect(c.empty()); + + std::unordered_map inserted; + + for (int key = 0; key < 32; ++key) + { + int const value = key * key; + bool const is_removable = removable(key, value); + bool const should_contain = c.size() < c.max_size() || !is_removable; + + c.insert(key, value); + inserted[key] = value; + if (should_contain) + expect(c.contains(key)); + + if (!is_removable) + expect(c.find(key) == c.begin()); + } + + expect_equal(c.size(), c.max_size()); + for (auto const & p : c) + { + expect_equal(inserted.at(p.first), p.second); + expect(!removable(p.first, p.second)); + } +}