Refactor util::statistics and use a more robust mean & variance computation algorithm
This commit is contained in:
parent
f8c52bcfe2
commit
488290be4f
1 changed files with 77 additions and 73 deletions
|
|
@ -6,7 +6,6 @@
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <numeric>
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
namespace psemek::util
|
namespace psemek::util
|
||||||
|
|
@ -51,6 +50,61 @@ namespace psemek::util
|
||||||
return std::sqrt(2.0) * boost::math::erf_inv(2.0 * x - 1.0);
|
return std::sqrt(2.0) * boost::math::erf_inv(2.0 * x - 1.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T sqr(T const & value)
|
||||||
|
{
|
||||||
|
return value * value;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct base_statistics
|
||||||
|
{
|
||||||
|
void push(T const & value, std::size_t count = 1)
|
||||||
|
{
|
||||||
|
*this = merge(*this, singleton(value, count));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::size_t count() const
|
||||||
|
{
|
||||||
|
return count_;
|
||||||
|
}
|
||||||
|
|
||||||
|
T mean() const
|
||||||
|
{
|
||||||
|
return mean_;
|
||||||
|
}
|
||||||
|
|
||||||
|
T variance() const
|
||||||
|
{
|
||||||
|
return variance_;
|
||||||
|
}
|
||||||
|
|
||||||
|
friend base_statistics merge(base_statistics const & s1, base_statistics const & s2)
|
||||||
|
{
|
||||||
|
// See https://stackoverflow.com/questions/1480626/merging-two-statistical-result-sets
|
||||||
|
|
||||||
|
base_statistics result;
|
||||||
|
result.count_ = s1.count_ + s2.count_;
|
||||||
|
result.mean_ = (s1.count_ * s1.mean_ + s2.count_ * s2.mean_) / result.count_;
|
||||||
|
result.variance_ = (s1.count_ * (s1.variance_ + sqr(s1.mean_ - result.mean_)) + s2.count_ * (s2.variance_ + sqr(s2.mean_ - result.mean_))) / result.count_;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::size_t count_ = 0;
|
||||||
|
T mean_ = T{0};
|
||||||
|
T variance_ = T{0};
|
||||||
|
|
||||||
|
static base_statistics singleton(T const & value, std::size_t count)
|
||||||
|
{
|
||||||
|
base_statistics result;
|
||||||
|
result.count_ = count;
|
||||||
|
result.mean_ = value;
|
||||||
|
result.variance_ = 0;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
|
@ -58,9 +112,10 @@ namespace psemek::util
|
||||||
{
|
{
|
||||||
void push(T const & value, std::size_t count = 1);
|
void push(T const & value, std::size_t count = 1);
|
||||||
|
|
||||||
std::size_t count() const { return count_; }
|
std::size_t count() const { return base_.count(); }
|
||||||
T mean() const;
|
T mean() const { return base_.mean(); }
|
||||||
T var() const;
|
T variance() const { return base_.variance(); }
|
||||||
|
T stddev() const { return std::sqrt(variance()); }
|
||||||
T min() const { return min_; }
|
T min() const { return min_; }
|
||||||
T max() const { return max_; }
|
T max() const { return max_; }
|
||||||
|
|
||||||
|
|
@ -70,9 +125,7 @@ namespace psemek::util
|
||||||
friend statistics_lite<H> merge(statistics_lite<H> const & s1, statistics_lite<H> const & s2);
|
friend statistics_lite<H> merge(statistics_lite<H> const & s1, statistics_lite<H> const & s2);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::size_t count_ = 0;
|
detail::base_statistics<T> base_;
|
||||||
T sum_ = T{0};
|
|
||||||
T sum_sqr_ = T{0};
|
|
||||||
T min_ = detail::max<T>();
|
T min_ = detail::max<T>();
|
||||||
T max_ = detail::min<T>();
|
T max_ = detail::min<T>();
|
||||||
};
|
};
|
||||||
|
|
@ -80,30 +133,16 @@ namespace psemek::util
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void statistics_lite<T>::push(T const & value, std::size_t count)
|
void statistics_lite<T>::push(T const & value, std::size_t count)
|
||||||
{
|
{
|
||||||
count_ += count;
|
base_.push(value, count);
|
||||||
sum_ += value * static_cast<T>(count);
|
|
||||||
sum_sqr_ += value * value * static_cast<T>(count);
|
|
||||||
min_ = std::min(min_, value);
|
min_ = std::min(min_, value);
|
||||||
max_ = std::max(max_, value);
|
max_ = std::max(max_, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
T statistics_lite<T>::mean() const
|
|
||||||
{
|
|
||||||
return sum_ / count_;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
T statistics_lite<T>::var() const
|
|
||||||
{
|
|
||||||
T const m = mean();
|
|
||||||
return std::sqrt(sum_sqr_ / count_ - m * m);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
std::ostream & operator << (std::ostream & os, statistics_lite<T> const & s)
|
std::ostream & operator << (std::ostream & os, statistics_lite<T> const & s)
|
||||||
{
|
{
|
||||||
os << "mean = " << s.mean() << ", var = " << s.var() << ", range = [" << s.min() << " .. " << s.max() << "]";
|
os << "mean = " << s.mean() << ", dev = " << s.stddev() << ", range = [" << s.min() << " .. " << s.max() << "]";
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -118,7 +157,7 @@ namespace psemek::util
|
||||||
// https://en.wikipedia.org/wiki/Truncated_normal_distribution
|
// https://en.wikipedia.org/wiki/Truncated_normal_distribution
|
||||||
|
|
||||||
float const mu = mean();
|
float const mu = mean();
|
||||||
float const sigma = var();
|
float const sigma = stddev();
|
||||||
float const alpha = (min_ - mu) / sigma;
|
float const alpha = (min_ - mu) / sigma;
|
||||||
float const beta = (max_ - mu) / sigma;
|
float const beta = (max_ - mu) / sigma;
|
||||||
|
|
||||||
|
|
@ -129,9 +168,7 @@ namespace psemek::util
|
||||||
statistics_lite<T> merge(statistics_lite<T> const & s1, statistics_lite<T> const & s2)
|
statistics_lite<T> merge(statistics_lite<T> const & s1, statistics_lite<T> const & s2)
|
||||||
{
|
{
|
||||||
statistics_lite<T> result;
|
statistics_lite<T> result;
|
||||||
result.count_ = s1.count_ + s2.count_;
|
result.base_ = merge(s1.base_, s2.base_);
|
||||||
result.sum_ = s1.sum_ + s2.sum_;
|
|
||||||
result.sum_sqr_ = s1.sum_sqr_ + s2.sum_sqr_;
|
|
||||||
result.min_ = std::min(s1.min_, s2.min_);
|
result.min_ = std::min(s1.min_, s2.min_);
|
||||||
result.max_ = std::max(s1.max_, s2.max_);
|
result.max_ = std::max(s1.max_, s2.max_);
|
||||||
return result;
|
return result;
|
||||||
|
|
@ -139,14 +176,11 @@ namespace psemek::util
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct statistics
|
struct statistics
|
||||||
|
: statistics_lite<T>
|
||||||
{
|
{
|
||||||
void push(T const & value);
|
void push(T const & value, std::size_t count = 1);
|
||||||
|
|
||||||
std::size_t count() const { return values_.size(); }
|
std::size_t count() const { return values_.size(); }
|
||||||
T mean() const;
|
|
||||||
T var() const;
|
|
||||||
T min() const;
|
|
||||||
T max() const;
|
|
||||||
|
|
||||||
T percentile(double p) const;
|
T percentile(double p) const;
|
||||||
|
|
||||||
|
|
@ -156,52 +190,19 @@ namespace psemek::util
|
||||||
private:
|
private:
|
||||||
mutable std::vector<T> values_;
|
mutable std::vector<T> values_;
|
||||||
mutable bool sorted_ = true;
|
mutable bool sorted_ = true;
|
||||||
|
|
||||||
|
statistics_lite<T> & lite() { return *this; }
|
||||||
|
statistics_lite<T> const & lite() const { return *this; }
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void statistics<T>::push(T const & value)
|
void statistics<T>::push(T const & value, std::size_t count)
|
||||||
{
|
{
|
||||||
values_.push_back(value);
|
lite().push(value, count);
|
||||||
|
values_.insert(values_.end(), value, count);
|
||||||
sorted_ = false;
|
sorted_ = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
T statistics<T>::mean() const
|
|
||||||
{
|
|
||||||
return std::accumulate(values_.begin(), values_.end(), T{}) / count();
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
T statistics<T>::var() const
|
|
||||||
{
|
|
||||||
T sum{};
|
|
||||||
T sum_sqr{};
|
|
||||||
for (auto const & v : values_)
|
|
||||||
{
|
|
||||||
sum += v;
|
|
||||||
sum_sqr += v * v;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto m = sum / count();
|
|
||||||
return std::sqrt(sum_sqr / count() - m * m);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
T statistics<T>::min() const
|
|
||||||
{
|
|
||||||
if (values_.empty())
|
|
||||||
return std::numeric_limits<T>::max();
|
|
||||||
return *std::min_element(values_.begin(), values_.end());
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
T statistics<T>::max() const
|
|
||||||
{
|
|
||||||
if (values_.empty())
|
|
||||||
return std::numeric_limits<T>::min();
|
|
||||||
return *std::max_element(values_.begin(), values_.end());
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T statistics<T>::percentile(double p) const
|
T statistics<T>::percentile(double p) const
|
||||||
{
|
{
|
||||||
|
|
@ -217,7 +218,7 @@ namespace psemek::util
|
||||||
template <typename T>
|
template <typename T>
|
||||||
std::ostream & operator << (std::ostream & os, statistics<T> const & s)
|
std::ostream & operator << (std::ostream & os, statistics<T> const & s)
|
||||||
{
|
{
|
||||||
os << "mean = " << s.mean() << ", var = " << s.var() << ", range = [" << s.min() << " .. " << s.max() << "]";
|
os << "mean = " << s.mean() << ", dev = " << s.stddev() << ", range = [" << s.min() << " .. " << s.max() << "]";
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -225,6 +226,8 @@ namespace psemek::util
|
||||||
statistics<T> merge(statistics<T> const & s1, statistics<T> const & s2)
|
statistics<T> merge(statistics<T> const & s1, statistics<T> const & s2)
|
||||||
{
|
{
|
||||||
statistics<T> result;
|
statistics<T> result;
|
||||||
|
result.lite() = merge(s1.lite(), s2.lite());
|
||||||
|
|
||||||
result.values_.reserve(s1.values_.size() + s2.values_.size());
|
result.values_.reserve(s1.values_.size() + s2.values_.size());
|
||||||
result.sorted_ = s1.sorted_ && s2.sorted_;
|
result.sorted_ = s1.sorted_ && s2.sorted_;
|
||||||
if (result.sorted_)
|
if (result.sorted_)
|
||||||
|
|
@ -237,6 +240,7 @@ namespace psemek::util
|
||||||
it = std::copy(s1.values_.begin(), s1.values_.end(), it);
|
it = std::copy(s1.values_.begin(), s1.values_.end(), it);
|
||||||
it = std::copy(s2.values_.begin(), s2.values_.end(), it);
|
it = std::copy(s2.values_.begin(), s2.values_.end(), it);
|
||||||
}
|
}
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue