Refactor util::statistics and use a more robust mean & variance computation algorithm

This commit is contained in:
Nikita Lisitsa 2024-06-03 21:00:19 +03:00
parent f8c52bcfe2
commit 488290be4f

View file

@ -6,7 +6,6 @@
#include <cmath>
#include <limits>
#include <vector>
#include <numeric>
#include <algorithm>
namespace psemek::util
@ -51,6 +50,61 @@ namespace psemek::util
return std::sqrt(2.0) * boost::math::erf_inv(2.0 * x - 1.0);
}
// Returns `value` multiplied by itself.
template <typename T>
T sqr(T const & value)
{
	T product = value * value;
	return product;
}
// Running accumulator for the sample count, mean and population variance
// (the variance is normalised by the sample count, not by count - 1).
//
// Samples are never stored; two accumulators can be combined with merge(),
// and push() is implemented as a merge with a single-value accumulator.
template <typename T>
struct base_statistics
{
	// Records `count` occurrences of `value`. A zero `count` leaves the
	// accumulator unchanged.
	void push(T const & value, std::size_t count = 1)
	{
		*this = merge(*this, singleton(value, count));
	}

	// Number of samples recorded so far.
	std::size_t count() const
	{
		return count_;
	}

	// Mean of the recorded samples; T{0} while the accumulator is empty.
	T mean() const
	{
		return mean_;
	}

	// Population variance of the recorded samples; T{0} while empty.
	T variance() const
	{
		return variance_;
	}

	// Combines two accumulators into one describing the union of their samples.
	//
	// Uses the pairwise update of Chan et al.: the new mean and variance are
	// derived from the difference of the two means instead of from
	// count-weighted sums, so large counts or large offsets do not inflate
	// intermediate values.
	// See https://stackoverflow.com/questions/1480626/merging-two-statistical-result-sets
	friend base_statistics merge(base_statistics const & s1, base_statistics const & s2)
	{
		base_statistics result;
		result.count_ = s1.count_ + s2.count_;

		// Nothing to combine: return the pristine empty state instead of
		// dividing by zero (which would yield NaN and poison later updates).
		if (result.count_ == 0)
			return result;

		// One side is empty: the other side is already the exact answer.
		if (s1.count_ == 0)
			return s2;
		if (s2.count_ == 0)
			return s1;

		T const n1 = static_cast<T>(s1.count_);
		T const n2 = static_cast<T>(s2.count_);
		T const n = static_cast<T>(result.count_);
		T const delta = s2.mean_ - s1.mean_;

		result.mean_ = s1.mean_ + delta * n2 / n;
		// Sum of squared deviations of the union, divided by the total count.
		result.variance_ = (n1 * s1.variance_ + n2 * s2.variance_ + delta * delta * n1 * n2 / n) / n;
		return result;
	}

private:
	std::size_t count_ = 0;
	T mean_ = T{0};
	T variance_ = T{0};

	// Accumulator describing `count` copies of the same `value`
	// (mean equals the value, variance is zero).
	static base_statistics singleton(T const & value, std::size_t count)
	{
		base_statistics result;
		result.count_ = count;
		result.mean_ = value;
		result.variance_ = 0;
		return result;
	}
};
}
template <typename T>
@ -58,9 +112,10 @@ namespace psemek::util
{
void push(T const & value, std::size_t count = 1);
std::size_t count() const { return count_; }
T mean() const;
T var() const;
std::size_t count() const { return base_.count(); }
T mean() const { return base_.mean(); }
T variance() const { return base_.variance(); }
T stddev() const { return std::sqrt(variance()); }
T min() const { return min_; }
T max() const { return max_; }
@ -70,9 +125,7 @@ namespace psemek::util
friend statistics_lite<H> merge(statistics_lite<H> const & s1, statistics_lite<H> const & s2);
private:
std::size_t count_ = 0;
T sum_ = T{0};
T sum_sqr_ = T{0};
detail::base_statistics<T> base_;
T min_ = detail::max<T>();
T max_ = detail::min<T>();
};
@ -80,30 +133,16 @@ namespace psemek::util
// Records `count` occurrences of `value`: forwards them to the mean/variance
// accumulator and widens the observed [min, max] range to include `value`.
template <typename T>
void statistics_lite<T>::push(T const & value, std::size_t count)
{
	// A zero count adds no samples, so it must leave every aggregate
	// (including the observed range) untouched.
	if (count == 0)
		return;

	base_.push(value, count);
	min_ = std::min(min_, value);
	max_ = std::max(max_, value);
}
// NOTE(review): this looks like the pre-commit out-of-line definition; the
// struct shown earlier in the diff also has an inline mean() forwarding to base_.
// Returns the accumulated sum divided by the sample count.
template <typename T>
T statistics_lite<T>::mean() const
{
return sum_ / count_;
}
// NOTE(review): this looks like the pre-commit definition; callers shown later
// in the diff switch from var() to stddev().
// Despite its name, this returns the square root of
// (mean of squares - square of the mean), i.e. a standard deviation rather
// than a variance.
template <typename T>
T statistics_lite<T>::var() const
{
T const m = mean();
// Difference of two potentially large, nearly equal terms.
return std::sqrt(sum_sqr_ / count_ - m * m);
}
// Writes a one-line summary (mean, spread, observed range) of `s` to `os`
// and returns the stream.
template <typename T>
std::ostream & operator << (std::ostream & os, statistics_lite<T> const & s)
{
// NOTE(review): the next two lines are two versions of the same statement as
// rendered by the diff; the first (label "var", calling var()) appears to be
// the pre-commit one, the second (label "dev", calling stddev()) its replacement.
os << "mean = " << s.mean() << ", var = " << s.var() << ", range = [" << s.min() << " .. " << s.max() << "]";
os << "mean = " << s.mean() << ", dev = " << s.stddev() << ", range = [" << s.min() << " .. " << s.max() << "]";
return os;
}
@ -118,7 +157,7 @@ namespace psemek::util
// https://en.wikipedia.org/wiki/Truncated_normal_distribution
float const mu = mean();
float const sigma = var();
float const sigma = stddev();
float const alpha = (min_ - mu) / sigma;
float const beta = (max_ - mu) / sigma;
@ -129,9 +168,7 @@ namespace psemek::util
statistics_lite<T> merge(statistics_lite<T> const & s1, statistics_lite<T> const & s2)
{
statistics_lite<T> result;
result.count_ = s1.count_ + s2.count_;
result.sum_ = s1.sum_ + s2.sum_;
result.sum_sqr_ = s1.sum_sqr_ + s2.sum_sqr_;
result.base_ = merge(s1.base_, s2.base_);
result.min_ = std::min(s1.min_, s2.min_);
result.max_ = std::max(s1.max_, s2.max_);
return result;
@ -139,14 +176,11 @@ namespace psemek::util
template <typename T>
struct statistics
: statistics_lite<T>
{
void push(T const & value);
void push(T const & value, std::size_t count = 1);
std::size_t count() const { return values_.size(); }
T mean() const;
T var() const;
T min() const;
T max() const;
T percentile(double p) const;
@ -156,52 +190,19 @@ namespace psemek::util
private:
mutable std::vector<T> values_;
mutable bool sorted_ = true;
statistics_lite<T> & lite() { return *this; }
statistics_lite<T> const & lite() const { return *this; }
};
// Records `count` occurrences of `value`: updates the aggregate statistics in
// the statistics_lite base and appends `count` copies of `value` to the stored
// sample list.
template <typename T>
void statistics<T>::push(T const & value, std::size_t count)
{
	lite().push(value, count);
	// std::vector::insert(pos, n, x) takes the repeat count before the value;
	// passing them the other way round would insert `value`-many copies of `count`.
	values_.insert(values_.end(), count, value);
	// The appended values may break any ordering the list previously had.
	sorted_ = false;
}
// NOTE(review): this looks like a pre-commit definition; the struct
// declaration shown earlier in the diff lists `T mean() const;` next to the
// new statistics_lite base.
// Returns the sum of all stored values divided by count().
template <typename T>
T statistics<T>::mean() const
{
return std::accumulate(values_.begin(), values_.end(), T{}) / count();
}
// NOTE(review): this looks like a pre-commit definition (the stream operator
// shown later in the diff switches from var() to stddev()).
// Walks all stored values once, accumulating their sum and sum of squares.
// Despite its name, the result is the square root of
// (mean of squares - square of the mean), i.e. a standard deviation.
template <typename T>
T statistics<T>::var() const
{
T sum{};
T sum_sqr{};
for (auto const & v : values_)
{
sum += v;
sum_sqr += v * v;
}
auto m = sum / count();
return std::sqrt(sum_sqr / count() - m * m);
}
// NOTE(review): this looks like a pre-commit definition.
// Returns the smallest stored value, or std::numeric_limits<T>::max() when no
// values are stored.
template <typename T>
T statistics<T>::min() const
{
if (values_.empty())
return std::numeric_limits<T>::max();
return *std::min_element(values_.begin(), values_.end());
}
// NOTE(review): this looks like a pre-commit definition.
// Returns the largest stored value, or std::numeric_limits<T>::min() when no
// values are stored.
// NOTE(review): for floating-point T, std::numeric_limits<T>::min() is the
// smallest positive normalized value, not the most negative one - confirm
// whether lowest() was intended for the empty case.
template <typename T>
T statistics<T>::max() const
{
if (values_.empty())
return std::numeric_limits<T>::min();
return *std::max_element(values_.begin(), values_.end());
}
template <typename T>
T statistics<T>::percentile(double p) const
{
@ -217,7 +218,7 @@ namespace psemek::util
// Writes a one-line summary (mean, spread, observed range) of `s` to `os`
// and returns the stream.
template <typename T>
std::ostream & operator << (std::ostream & os, statistics<T> const & s)
{
// NOTE(review): the next two lines are two versions of the same statement as
// rendered by the diff; the first (label "var", calling var()) appears to be
// the pre-commit one, the second (label "dev", calling stddev()) its replacement.
os << "mean = " << s.mean() << ", var = " << s.var() << ", range = [" << s.min() << " .. " << s.max() << "]";
os << "mean = " << s.mean() << ", dev = " << s.stddev() << ", range = [" << s.min() << " .. " << s.max() << "]";
return os;
}
@ -225,6 +226,8 @@ namespace psemek::util
statistics<T> merge(statistics<T> const & s1, statistics<T> const & s2)
{
statistics<T> result;
result.lite() = merge(s1.lite(), s2.lite());
result.values_.reserve(s1.values_.size() + s2.values_.size());
result.sorted_ = s1.sorted_ && s2.sorted_;
if (result.sorted_)
@ -237,6 +240,7 @@ namespace psemek::util
it = std::copy(s1.values_.begin(), s1.values_.end(), it);
it = std::copy(s2.values_.begin(), s2.values_.end(), it);
}
return result;
}