From 488290be4f63459ee46089166919b0838a4baff7 Mon Sep 17 00:00:00 2001 From: lisyarus Date: Mon, 3 Jun 2024 21:00:19 +0300 Subject: [PATCH] Refactor util::statistics and use a more robust mean & variance computation algorithm --- libs/util/include/psemek/util/statistics.hpp | 150 ++++++++++--------- 1 file changed, 77 insertions(+), 73 deletions(-) diff --git a/libs/util/include/psemek/util/statistics.hpp b/libs/util/include/psemek/util/statistics.hpp index b81aeb4b..ee1950f3 100644 --- a/libs/util/include/psemek/util/statistics.hpp +++ b/libs/util/include/psemek/util/statistics.hpp @@ -6,7 +6,6 @@ #include #include #include -#include #include namespace psemek::util @@ -51,6 +50,61 @@ namespace psemek::util return std::sqrt(2.0) * boost::math::erf_inv(2.0 * x - 1.0); } + template + T sqr(T const & value) + { + return value * value; + } + + template + struct base_statistics + { + void push(T const & value, std::size_t count = 1) + { + *this = merge(*this, singleton(value, count)); + } + + std::size_t count() const + { + return count_; + } + + T mean() const + { + return mean_; + } + + T variance() const + { + return variance_; + } + + friend base_statistics merge(base_statistics const & s1, base_statistics const & s2) + { + // See https://stackoverflow.com/questions/1480626/merging-two-statistical-result-sets + + base_statistics result; + result.count_ = s1.count_ + s2.count_; + result.mean_ = (s1.count_ * s1.mean_ + s2.count_ * s2.mean_) / result.count_; + result.variance_ = (s1.count_ * (s1.variance_ + sqr(s1.mean_ - result.mean_)) + s2.count_ * (s2.variance_ + sqr(s2.mean_ - result.mean_))) / result.count_; + return result; + } + + private: + std::size_t count_ = 0; + T mean_ = T{0}; + T variance_ = T{0}; + + static base_statistics singleton(T const & value, std::size_t count) + { + base_statistics result; + result.count_ = count; + result.mean_ = value; + result.variance_ = 0; + return result; + } + }; + } template @@ -58,9 +112,10 @@ namespace psemek::util { 
void push(T const & value, std::size_t count = 1); - std::size_t count() const { return count_; } - T mean() const; - T var() const; + std::size_t count() const { return base_.count(); } + T mean() const { return base_.mean(); } + T variance() const { return base_.variance(); } + T stddev() const { return std::sqrt(variance()); } T min() const { return min_; } T max() const { return max_; } @@ -70,9 +125,7 @@ namespace psemek::util friend statistics_lite merge(statistics_lite const & s1, statistics_lite const & s2); private: - std::size_t count_ = 0; - T sum_ = T{0}; - T sum_sqr_ = T{0}; + detail::base_statistics base_; T min_ = detail::max(); T max_ = detail::min(); }; @@ -80,30 +133,16 @@ namespace psemek::util template void statistics_lite::push(T const & value, std::size_t count) { - count_ += count; - sum_ += value * static_cast(count); - sum_sqr_ += value * value * static_cast(count); + base_.push(value, count); + min_ = std::min(min_, value); max_ = std::max(max_, value); } - template - T statistics_lite::mean() const - { - return sum_ / count_; - } - - template - T statistics_lite::var() const - { - T const m = mean(); - return std::sqrt(sum_sqr_ / count_ - m * m); - } - template std::ostream & operator << (std::ostream & os, statistics_lite const & s) { - os << "mean = " << s.mean() << ", var = " << s.var() << ", range = [" << s.min() << " .. " << s.max() << "]"; + os << "mean = " << s.mean() << ", dev = " << s.stddev() << ", range = [" << s.min() << " .. 
" << s.max() << "]"; return os; } @@ -118,7 +157,7 @@ namespace psemek::util // https://en.wikipedia.org/wiki/Truncated_normal_distribution float const mu = mean(); - float const sigma = var(); + float const sigma = stddev(); float const alpha = (min_ - mu) / sigma; float const beta = (max_ - mu) / sigma; @@ -129,9 +168,7 @@ namespace psemek::util statistics_lite merge(statistics_lite const & s1, statistics_lite const & s2) { statistics_lite result; - result.count_ = s1.count_ + s2.count_; - result.sum_ = s1.sum_ + s2.sum_; - result.sum_sqr_ = s1.sum_sqr_ + s2.sum_sqr_; + result.base_ = merge(s1.base_, s2.base_); result.min_ = std::min(s1.min_, s2.min_); result.max_ = std::max(s1.max_, s2.max_); return result; @@ -139,14 +176,11 @@ namespace psemek::util template struct statistics + : statistics_lite { - void push(T const & value); + void push(T const & value, std::size_t count = 1); std::size_t count() const { return values_.size(); } - T mean() const; - T var() const; - T min() const; - T max() const; T percentile(double p) const; @@ -156,52 +190,19 @@ namespace psemek::util private: mutable std::vector values_; mutable bool sorted_ = true; + + statistics_lite & lite() { return *this; } + statistics_lite const & lite() const { return *this; } }; template - void statistics::push(T const & value) + void statistics::push(T const & value, std::size_t count) { - values_.push_back(value); + lite().push(value, count); + values_.insert(values_.end(), count, value); sorted_ = false; } - template - T statistics::mean() const - { - return std::accumulate(values_.begin(), values_.end(), T{}) / count(); - } - - template - T statistics::var() const - { - T sum{}; - T sum_sqr{}; - for (auto const & v : values_) - { - sum += v; - sum_sqr += v * v; - } - - auto m = sum / count(); - return std::sqrt(sum_sqr / count() - m * m); - } - - template - T statistics::min() const - { - if (values_.empty()) - return std::numeric_limits::max(); - return *std::min_element(values_.begin(), 
values_.end()); - } - - template - T statistics::max() const - { - if (values_.empty()) - return std::numeric_limits::min(); - return *std::max_element(values_.begin(), values_.end()); - } - template T statistics::percentile(double p) const { @@ -217,7 +218,7 @@ namespace psemek::util template std::ostream & operator << (std::ostream & os, statistics const & s) { - os << "mean = " << s.mean() << ", var = " << s.var() << ", range = [" << s.min() << " .. " << s.max() << "]"; + os << "mean = " << s.mean() << ", dev = " << s.stddev() << ", range = [" << s.min() << " .. " << s.max() << "]"; return os; } @@ -225,6 +226,8 @@ namespace psemek::util statistics merge(statistics const & s1, statistics const & s2) { statistics result; + result.lite() = merge(s1.lite(), s2.lite()); + result.values_.reserve(s1.values_.size() + s2.values_.size()); result.sorted_ = s1.sorted_ && s2.sorted_; if (result.sorted_) @@ -237,6 +240,7 @@ namespace psemek::util it = std::copy(s1.values_.begin(), s1.values_.end(), it); it = std::copy(s2.values_.begin(), s2.values_.end(), it); } + return result; }