Add lru cache implementation & tests

This commit is contained in:
Nikita Lisitsa 2022-05-17 18:22:31 +03:00
parent cbfd8f83de
commit fcfd7138d1
2 changed files with 448 additions and 27 deletions

View file

@ -1,59 +1,322 @@
#pragma once
#include <psemek/util/functional.hpp>
#include <psemek/util/key_error.hpp>
#include <list>
#include <unordered_map>
namespace psemek::util
{
// A size-bounded key -> value cache.
// Elements live in a list ordered by recency; `shrink` evicts from the back of
// that list, but only elements for which the `removable` predicate returns true.
//
// No function except `insert` touches the accessed elements
// To mark an element as accessed, one must call `touch` manually
// Calling touch from inside the `removable` predicate or while iterating is UB
//
// NOTE(review): removable_begin_ and map_ store iterators into list_, and no
// copy/move operations are declared here. The implicitly generated ones would
// copy those iterators verbatim - confirm whether copying/moving a cache is
// meant to be supported.
template <typename Key, typename Mapped, typename Predicate = decltype(always_true)>
struct lru_cache
{
    using key_type = Key;
    using mapped_type = Mapped;
    using value_type = std::pair<Key const, Mapped>;

    using iterator = typename std::list<value_type>::iterator;
    using const_iterator = typename std::list<value_type>::const_iterator;

    // Predicate must be callable as removable(key, mapped)
    // with a Key const & and a Mapped const &
    lru_cache(std::size_t max_size, Predicate removable = Predicate{});

    // Find an element without touching it
    // Returns end() if the key is not present
    iterator find(Key const & key);
    const_iterator find(Key const & key) const;

    bool contains(Key const & key) const;

    // Access an element without touching it
    // Throws if the key is not present
    Mapped & at(Key const & key);
    Mapped const & at(Key const & key) const;

    // Insert an element (automatically touches it)
    // N.B. the element might be immediately deleted by a shrink
    void insert(Key const & key, Mapped && mapped);
    void insert(Key const & key, Mapped const & mapped);

    // Remove an element; the key overload throws if the key is not present
    void erase(const_iterator it);
    void erase(Key const & key);

    // Mark an element as most recently used
    // The key overload throws if the key is not present
    void touch(const_iterator it);
    void touch(Key const & key);

    iterator begin();
    iterator end();

    const_iterator begin() const;
    const_iterator end() const;

    const_iterator cbegin() const;
    const_iterator cend() const;

    bool empty() const;
    std::size_t size() const;

    std::size_t max_size() const;
    // Sets a new size limit, shrinks, and returns the previous limit
    std::size_t max_size(std::size_t new_max_size);

    // Evict removable elements while size() > max_size()
    void shrink();

    Predicate & removable();
    Predicate const & removable() const;

    // Evaluate the predicate for a stored element / an arbitrary pair
    bool removable(const_iterator it) const;
    bool removable(Key const & key, Mapped const & value) const;

private:
    std::size_t max_size_;
    Predicate removable_;

    // N.B.: keys are stored twice - one in map_, one in list_
    // list_ contains non-removable items first, then all removable items
    std::list<value_type> list_;
    // First element of the (potentially) removable sublist, or list_.end()
    iterator removable_begin_ = list_.end();
    std::unordered_map<Key, iterator> map_;

    // Same as find, but throws key_error instead of returning end()
    iterator find_safe(Key const & key);
    const_iterator find_safe(Key const & key) const;

    // Registers an already-linked list node in map_ and shrinks
    void insert(iterator it);
};
template <typename T>
bool lru_cache<T>::empty() const
template <typename Key, typename Mapped, typename Predicate>
lru_cache<Key, Mapped, Predicate>::lru_cache(std::size_t max_size, Predicate removable)
: max_size_(max_size)
, removable_(std::move(removable))
{}
// Looks the key up in map_ and returns the matching list iterator,
// or end() if the key is not stored. Does not change the element order.
template <typename Key, typename Mapped, typename Predicate>
auto lru_cache<Key, Mapped, Predicate>::find(Key const & key) -> iterator
{
    if (auto it = map_.find(key); it != map_.end())
        return it->second;
    return list_.end();
}
// Const lookup: forwards to the non-const find (which does not modify the cache)
// and converts the result to a const_iterator.
template <typename Key, typename Mapped, typename Predicate>
auto lru_cache<Key, Mapped, Predicate>::find(Key const & key) const -> const_iterator
{
    return const_cast<lru_cache &>(*this).find(key);
}
// Returns true if an element with this key is registered in map_.
template <typename Key, typename Mapped, typename Predicate>
bool lru_cache<Key, Mapped, Predicate>::contains(Key const & key) const
{
    return map_.contains(key);
}
// Returns the value stored under `key` without touching it.
// Throws key_error (via find_safe) if the key is not present.
template <typename Key, typename Mapped, typename Predicate>
auto lru_cache<Key, Mapped, Predicate>::at(Key const & key) -> Mapped &
{
    return find_safe(key)->second;
}
// Const access: forwards to the non-const at and returns the result as a const reference.
// Throws if the key is not present.
template <typename Key, typename Mapped, typename Predicate>
auto lru_cache<Key, Mapped, Predicate>::at(Key const & key) const -> Mapped const &
{
    return const_cast<lru_cache &>(*this).at(key);
}
// Moves `mapped` into a new node at the front of the list,
// then lets the private insert(iterator) register it and shrink.
template <typename Key, typename Mapped, typename Predicate>
void lru_cache<Key, Mapped, Predicate>::insert(Key const & key, Mapped && mapped)
{
    iterator const fresh = list_.emplace(list_.begin(), key, std::move(mapped));
    insert(fresh);
}
// Copies `mapped` into a new node at the front of the list,
// then lets the private insert(iterator) register it and shrink.
template <typename Key, typename Mapped, typename Predicate>
void lru_cache<Key, Mapped, Predicate>::insert(Key const & key, Mapped const & mapped)
{
    iterator const fresh = list_.emplace(list_.begin(), key, mapped);
    insert(fresh);
}
// Removes the element `it` points to from both map_ and list_.
// If it was the first element of the removable sublist, the boundary
// moves on to the element that followed it.
template <typename Key, typename Mapped, typename Predicate>
void lru_cache<Key, Mapped, Predicate>::erase(const_iterator it)
{
    bool const was_boundary = (it == removable_begin_);
    map_.erase(it->first);
    iterator const following = list_.erase(it);
    if (was_boundary)
        removable_begin_ = following;
}
// Removes the element stored under `key`.
// Throws (via find_safe) if the key is not present.
template <typename Key, typename Mapped, typename Predicate>
void lru_cache<Key, Mapped, Predicate>::erase(Key const & key)
{
    const_iterator const target = find_safe(key);
    erase(target);
}
// Marks the element as most recently used.
// A removable element moves to the front of the removable sublist;
// a non-removable element moves to the front of the whole list.
template <typename Key, typename Mapped, typename Predicate>
void lru_cache<Key, Mapped, Predicate>::touch(const_iterator it)
{
    bool const removable = removable_(it->first, it->second);
    if (removable)
    {
        // Already at the front of the removable sublist: nothing to do
        if (it != removable_begin_)
        {
            list_.splice(removable_begin_, list_, it);
            --removable_begin_;
        }
    }
    else
    {
        // The element leaves the removable sublist; if it was the boundary,
        // advance the boundary first so it keeps marking the sublist start
        // instead of following the element to the front of the list
        if (it == removable_begin_)
            ++removable_begin_;
        list_.splice(list_.begin(), list_, it);
    }
}
// Marks the element stored under `key` as most recently used.
// Throws (via find_safe) if the key is not present.
template <typename Key, typename Mapped, typename Predicate>
void lru_cache<Key, Mapped, Predicate>::touch(Key const & key)
{
    const_iterator const target = find_safe(key);
    touch(target);
}
// Iterator to the first element of the underlying list.
template <typename Key, typename Mapped, typename Predicate>
auto lru_cache<Key, Mapped, Predicate>::begin() -> iterator
{
    iterator const first = list_.begin();
    return first;
}
// Past-the-end iterator of the underlying list.
template <typename Key, typename Mapped, typename Predicate>
auto lru_cache<Key, Mapped, Predicate>::end() -> iterator
{
    iterator const past_last = list_.end();
    return past_last;
}
// Const iterator to the first element of the underlying list.
template <typename Key, typename Mapped, typename Predicate>
auto lru_cache<Key, Mapped, Predicate>::begin() const -> const_iterator
{
    return list_.cbegin();
}
// Const past-the-end iterator of the underlying list.
template <typename Key, typename Mapped, typename Predicate>
auto lru_cache<Key, Mapped, Predicate>::end() const -> const_iterator
{
    return list_.cend();
}
// Const iterator to the first element of the underlying list.
template <typename Key, typename Mapped, typename Predicate>
auto lru_cache<Key, Mapped, Predicate>::cbegin() const -> const_iterator
{
    return list_.cbegin();
}
// Const past-the-end iterator of the underlying list.
template <typename Key, typename Mapped, typename Predicate>
auto lru_cache<Key, Mapped, Predicate>::cend() const -> const_iterator
{
    return list_.cend();
}
// True when the cache holds no elements.
template <typename Key, typename Mapped, typename Predicate>
bool lru_cache<Key, Mapped, Predicate>::empty() const
{
    bool const no_elements = list_.empty();
    return no_elements;
}
// Number of elements currently held in the list.
template <typename Key, typename Mapped, typename Predicate>
std::size_t lru_cache<Key, Mapped, Predicate>::size() const
{
    std::size_t const count = list_.size();
    return count;
}
// Current size limit used by shrink.
template <typename Key, typename Mapped, typename Predicate>
std::size_t lru_cache<Key, Mapped, Predicate>::max_size() const
{
    return this->max_size_;
}
// Installs a new size limit, shrinks the cache to it, and returns the previous limit.
template <typename Key, typename Mapped, typename Predicate>
std::size_t lru_cache<Key, Mapped, Predicate>::max_size(std::size_t new_max_size)
{
    std::size_t const previous = max_size_;
    max_size_ = new_max_size;
    shrink();
    return previous;
}
// Evicts elements while size() > max_size().
// Candidates are taken from the back of the list. Each one is re-checked with
// the predicate, and elements that turn out to be non-removable are moved into
// the non-removable sublist instead of being erased.
// Stops early when no removable element is left, so the cache may stay above max_size().
template <typename Key, typename Mapped, typename Predicate>
void lru_cache<Key, Mapped, Predicate>::shrink()
{
    while (size() > max_size())
    {
        if (removable_begin_ != list_.end())
        {
            // The removable sublist is non-empty, so the last list element belongs to it
            auto const last = std::prev(list_.end());
            bool const removable = removable_(last->first, last->second);
            if (removable)
            {
                erase(last);
            }
            else
            {
                // Move the item to non-removable sublist
                // We will have one potentially removable item less in the next loop iteration
                if (removable_begin_ == last)
                    ++removable_begin_;
                else
                    list_.splice(removable_begin_, list_, last);
            }
        }
        else
        {
            // No potentially removable elements, try non-removable ones
            if (list_.begin() == removable_begin_)
                break;

            // Re-check the last non-removable element: the predicate result may have changed
            auto it = std::prev(removable_begin_);
            if (!removable_(it->first, it->second))
                break;

            erase(it);
        }
    }
}
// Mutable access to the stored eviction predicate.
template <typename Key, typename Mapped, typename Predicate>
Predicate & lru_cache<Key, Mapped, Predicate>::removable()
{
    return this->removable_;
}
// Read-only access to the stored eviction predicate.
template <typename Key, typename Mapped, typename Predicate>
Predicate const & lru_cache<Key, Mapped, Predicate>::removable() const
{
    return this->removable_;
}
// Evaluates the eviction predicate for the element `it` points to.
template <typename Key, typename Mapped, typename Predicate>
bool lru_cache<Key, Mapped, Predicate>::removable(const_iterator it) const
{
    auto const & [key, mapped] = *it;
    return removable_(key, mapped);
}
// Evaluates the eviction predicate for an arbitrary key/value pair.
template <typename Key, typename Mapped, typename Predicate>
bool lru_cache<Key, Mapped, Predicate>::removable(Key const & key, Mapped const & value) const
{
    bool const verdict = removable_(key, value);
    return verdict;
}
// Like find, but a missing key is reported by throwing key_error
// instead of returning end().
template <typename Key, typename Mapped, typename Predicate>
auto lru_cache<Key, Mapped, Predicate>::find_safe(Key const & key) -> iterator
{
    iterator const found = find(key);
    if (found == list_.end())
        throw key_error{key};
    return found;
}
// Const variant: reuses the non-const find_safe and converts the result
// to a const_iterator. Throws if the key is not present.
template <typename Key, typename Mapped, typename Predicate>
auto lru_cache<Key, Mapped, Predicate>::find_safe(Key const & key) const -> const_iterator
{
    auto & self = const_cast<lru_cache &>(*this);
    return self.find_safe(key);
}
// Registers a freshly created list node (`it`, currently at the front of list_):
// places it in the right sublist, records it in map_, then shrinks.
// If the key was already stored, the previous element is removed first, so the
// list never holds two nodes with the same key.
template <typename Key, typename Mapped, typename Predicate>
void lru_cache<Key, Mapped, Predicate>::insert(iterator it)
{
    // Drop the element previously stored under this key. Otherwise it would stay
    // in list_ as an orphan, and erasing it later would remove the live key from map_
    if (auto const stale = map_.find(it->first); stale != map_.end())
    {
        iterator const old = stale->second;
        if (old == removable_begin_)
            ++removable_begin_;
        list_.erase(old);
        map_.erase(stale);
    }

    bool const removable = removable_(it->first, it->second);
    if (removable)
    {
        // Removable elements go to the front of the removable sublist,
        // and the new node becomes the sublist boundary
        list_.splice(removable_begin_, list_, it);
        it = --removable_begin_;
    }
    map_[it->first] = it;
    shrink();
}
}

View file

@ -0,0 +1,158 @@
#include <psemek/test/test.hpp>
#include <psemek/util/lru_cache.hpp>
#include <psemek/util/function.hpp>
#include <unordered_map>
using namespace psemek::util;
// A freshly constructed cache reports zero size and is empty.
test_case(util_lru__cache_empty)
{
    lru_cache<int, int> cache(64);

    expect_equal(cache.size(), 0);
    expect(cache.empty());
}
// Inserting below the size limit: every inserted key is immediately
// retrievable through contains/find/at with the inserted value.
test_case(util_lru__cache_insert)
{
    lru_cache<int, int> cache(64);

    expect_equal(cache.size(), 0);
    expect(cache.empty());

    for (int k = 0; k < 32; ++k)
    {
        int const square = k * k;
        cache.insert(k, square);

        expect_lequal(cache.size(), k + 1);
        expect(cache.contains(k));
        expect(cache.find(k) != cache.end());
        expect_equal(cache.at(k), square);
        expect_equal(cache.find(k)->first, k);
        expect_equal(cache.find(k)->second, square);
    }
}
// Iterating the cache yields exactly the inserted pairs, checked in both
// directions against a reference map.
test_case(util_lru__cache_iterate)
{
    lru_cache<int, int> cache(64);

    expect_equal(cache.size(), 0);
    expect(cache.empty());

    std::unordered_map<int, int> reference;
    for (int k = 0; k < 32; ++k)
    {
        int const square = k * k;
        cache.insert(k, square);
        reference[k] = square;
    }

    // Everything in the cache was inserted
    for (auto const & entry : cache)
    {
        expect(reference.contains(entry.first));
        expect_equal(reference.at(entry.first), entry.second);
    }

    // Everything inserted is in the cache
    for (auto const & entry : reference)
    {
        expect(cache.contains(entry.first));
        expect_equal(cache.at(entry.first), entry.second);
    }
}
// Inserting 32 keys into a cache limited to 16: the size never exceeds the
// limit, the 16 oldest keys are gone and the 16 newest keys remain.
test_case(util_lru__cache_shrink)
{
    lru_cache<int, int> cache(16);

    expect_equal(cache.size(), 0);
    expect(cache.empty());
    expect_equal(cache.max_size(), 16);

    std::unordered_map<int, int> reference;
    for (int k = 0; k < 32; ++k)
    {
        int const square = k * k;
        cache.insert(k, square);
        reference[k] = square;

        expect_lequal(cache.size(), k + 1);
        expect(cache.contains(k));
        expect(cache.find(k) != cache.end());
        expect_equal(cache.at(k), square);
        expect_equal(cache.find(k)->first, k);
        expect_equal(cache.find(k)->second, square);
        expect_lequal(cache.size(), cache.max_size());
    }

    // The first half was evicted
    for (int k = 0; k < 16; ++k)
        expect(!cache.contains(k));

    // The second half survived with the right values
    for (int k = 16; k < 32; ++k)
    {
        expect(cache.contains(k));
        expect_equal(cache.at(k), reference.at(k));
    }

    for (auto const & entry : cache)
    {
        expect(reference.contains(entry.first));
        expect_equal(reference.at(entry.first), entry.second);
    }
}
// A newly inserted element sits at the front of the cache, and touching an
// existing element moves it to the front.
test_case(util_lru__cache_touch)
{
    lru_cache<int, int> c(64);
    expect_equal(c.size(), 0);
    expect(c.empty());

    for (int key = 0; key < 32; ++key)
    {
        int const value = key * key;
        c.insert(key, value);
        expect(c.find(key) == c.begin());
    }

    for (int key = 0; key < 32; ++key)
    {
        c.touch(key);
        expect(c.find(key) == c.begin());
    }
}
// With a predicate that only allows odd keys to be evicted, the cache ends up
// exactly at its size limit and holds even keys only.
test_case(util_lru__cache_removable)
{
    auto odd_key = [](int key, int){ return key % 2 == 1; };
    lru_cache<int, int, function<bool(int, int)>> cache(16, odd_key);

    expect_equal(cache.size(), 0);
    expect(cache.empty());

    std::unordered_map<int, int> reference;
    for (int k = 0; k < 32; ++k)
    {
        int const square = k * k;
        bool const evictable = odd_key(k, square);
        // Evaluated before the insert: there is free room, or the key cannot be evicted
        bool const must_stay = cache.size() < cache.max_size() || !evictable;

        cache.insert(k, square);
        reference[k] = square;

        if (must_stay)
            expect(cache.contains(k));
        if (!evictable)
            expect(cache.find(k) == cache.begin());
    }

    expect_equal(cache.size(), cache.max_size());

    for (auto const & entry : cache)
    {
        expect_equal(reference.at(entry.first), entry.second);
        expect(!odd_key(entry.first, entry.second));
    }
}