Add neural net random initialization

This commit is contained in:
Nikita Lisitsa 2022-01-20 14:16:54 +03:00
parent 587a80b0de
commit 3a3fffa232
2 changed files with 51 additions and 0 deletions

View file

@ -0,0 +1,40 @@
#pragma once
#include <psemek/ml/neural_net/neural_net.hpp>
#include <psemek/random/normal.hpp>
#include <psemek/random/uniform_real.hpp>
#include <psemek/random/generator.hpp>

#include <cmath>
#include <cstddef>
namespace psemek::ml
{
/// Overwrites the weights of every layer of `nn` with uniformly distributed random values.
///
/// For the layer index `l` reported by for_each_layer, each weight is drawn from a
/// uniform distribution constructed with the bounds -a and a, where
/// a = sqrt(6 / (layer_sizes[l] + layer_sizes[l + 1])).
/// This matches the Glorot/Xavier uniform rule, assuming layer_sizes[l] and
/// layer_sizes[l + 1] are the input and output widths of layer l (TODO confirm
/// against neural_net / for_each_layer).
///
/// @tparam T    scalar type of the network weights
/// @tparam RNG  random generator type; deduced as a forwarding reference, so both
///              lvalue and temporary generators are accepted
/// @param nn    network whose weights are modified in place
/// @param rng   generator handed to the distribution for every sample
template <typename T, typename RNG>
void randomize_uniform(neural_net<T> & nn, RNG && rng)
{
    // Bind by const reference: no copy when layer_sizes() returns a reference,
    // and a lifetime-extended temporary when it returns by value.
    auto const & layer_sizes = nn.layer_sizes();
    for_each_layer(nn, [&layer_sizes, &rng](std::size_t l, util::span<T> weights){
        // Convert the size sum to T explicitly so the division is visibly done in T.
        T const fan_sum = static_cast<T>(layer_sizes[l] + layer_sizes[l + 1]);
        T const length = std::sqrt(T{6} / fan_sum);
        random::uniform_real_distribution<T> d{-length, length};
        for (T & w : weights)
            w = d(rng);
    });
}
/// Overwrites the weights of every layer of `nn` with normally distributed random values.
///
/// For the layer index `l` reported by for_each_layer, each weight is drawn from a
/// normal distribution constructed with the arguments (0, s) -- presumably mean and
/// standard deviation -- where s = sqrt(2 / (layer_sizes[l] + layer_sizes[l + 1])).
/// This matches the Glorot/Xavier normal rule, assuming layer_sizes[l] and
/// layer_sizes[l + 1] are the input and output widths of layer l (TODO confirm
/// against neural_net / for_each_layer).
///
/// @tparam T    scalar type of the network weights
/// @tparam RNG  random generator type; deduced as a forwarding reference, so both
///              lvalue and temporary generators are accepted
/// @param nn    network whose weights are modified in place
/// @param rng   generator handed to the distribution for every sample
template <typename T, typename RNG>
void randomize_normal(neural_net<T> & nn, RNG && rng)
{
    // Bind by const reference: no copy when layer_sizes() returns a reference,
    // and a lifetime-extended temporary when it returns by value.
    auto const & layer_sizes = nn.layer_sizes();
    for_each_layer(nn, [&layer_sizes, &rng](std::size_t l, util::span<T> weights){
        // Convert the size sum to T explicitly so the division is visibly done in T.
        T const fan_sum = static_cast<T>(layer_sizes[l] + layer_sizes[l + 1]);
        T const stddev = std::sqrt(T{2} / fan_sum);
        random::normal_distribution<T> d{T{0}, stddev};
        for (T & w : weights)
            w = d(rng);
    });
}
// Explicit instantiation declarations: translation units including this header do not
// instantiate these specializations themselves; the matching explicit instantiation
// definitions are provided by the accompanying source file.
//
// NOTE(review): these specializations have RNG = random::generator, which template
// argument deduction selects only when the generator argument is an rvalue. A call
// with an lvalue generator deduces RNG = random::generator & and is instantiated
// implicitly at the call site -- confirm whether that is intended.

// float networks
extern template void randomize_uniform<float, random::generator>(neural_net<float> &, random::generator &&);
extern template void randomize_normal<float, random::generator>(neural_net<float> &, random::generator &&);

// double networks
extern template void randomize_uniform<double, random::generator>(neural_net<double> &, random::generator &&);
extern template void randomize_normal<double, random::generator>(neural_net<double> &, random::generator &&);
}

View file

@ -0,0 +1,11 @@
#include <psemek/ml/neural_net/randomize.hpp>
namespace psemek::ml
{
// Explicit instantiation definitions for the specializations that the header declares
// with `extern template`: float and double networks combined with random::generator.

// float networks
template void randomize_uniform<float, random::generator>(neural_net<float> &, random::generator &&);
template void randomize_normal<float, random::generator>(neural_net<float> &, random::generator &&);

// double networks
template void randomize_uniform<double, random::generator>(neural_net<double> &, random::generator &&);
template void randomize_normal<double, random::generator>(neural_net<double> &, random::generator &&);
}