Add neural net random initialization
This commit is contained in:
parent
587a80b0de
commit
3a3fffa232
2 changed files with 51 additions and 0 deletions
40
libs/ml/include/psemek/ml/neural_net/randomize.hpp
Normal file
40
libs/ml/include/psemek/ml/neural_net/randomize.hpp
Normal file
|
|
@ -0,0 +1,40 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <psemek/ml/neural_net/neural_net.hpp>
|
||||||
|
#include <psemek/random/normal.hpp>
|
||||||
|
#include <psemek/random/uniform_real.hpp>
|
||||||
|
#include <psemek/random/generator.hpp>
|
||||||
|
|
||||||
|
namespace psemek::ml
|
||||||
|
{
|
||||||
|
|
||||||
|
template <typename T, typename RNG>
|
||||||
|
void randomize_uniform(neural_net<T> & nn, RNG && rng)
|
||||||
|
{
|
||||||
|
auto const layer_sizes = nn.layer_sizes();
|
||||||
|
for_each_layer(nn, [&layer_sizes, &rng](std::size_t l, util::span<T> weights){
|
||||||
|
T length = std::sqrt(T{6} / (layer_sizes[l] + layer_sizes[l + 1]));
|
||||||
|
random::uniform_real_distribution<T> d{-length, length};
|
||||||
|
for (T & w : weights)
|
||||||
|
w = d(rng);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename RNG>
|
||||||
|
void randomize_normal(neural_net<T> & nn, RNG && rng)
|
||||||
|
{
|
||||||
|
auto const layer_sizes = nn.layer_sizes();
|
||||||
|
for_each_layer(nn, [&layer_sizes, &rng](std::size_t l, util::span<T> weights){
|
||||||
|
T stddev = std::sqrt(T{2} / (layer_sizes[l] + layer_sizes[l + 1]));
|
||||||
|
random::normal_distribution<T> d{T{0}, stddev};
|
||||||
|
for (T & w : weights)
|
||||||
|
w = d(rng);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Explicit instantiation declaration: translation units including this header
// do not implicitly instantiate randomize_uniform<float, random::generator>;
// the matching definition lives in randomize.cpp.
// NOTE(review): with RNG = random::generator the second parameter is
// random::generator &&, so this only covers calls that pass an rvalue
// generator. Passing an lvalue deduces RNG = random::generator &, which is a
// different specialization and is still instantiated implicitly.
extern template void randomize_uniform<float, random::generator>(neural_net<float> &, random::generator &&);
|
||||||
|
// Explicit instantiation declaration for randomize_uniform with double weights
// and an rvalue random::generator (RNG = random::generator); the definition is
// provided by randomize.cpp. Calls with an lvalue generator deduce
// RNG = random::generator & and are not covered by this declaration.
extern template void randomize_uniform<double, random::generator>(neural_net<double> &, random::generator &&);
|
||||||
|
// Explicit instantiation declaration for randomize_normal with float weights
// and an rvalue random::generator (RNG = random::generator); the definition is
// provided by randomize.cpp. Calls with an lvalue generator deduce
// RNG = random::generator & and are not covered by this declaration.
extern template void randomize_normal<float, random::generator>(neural_net<float> &, random::generator &&);
|
||||||
|
// Explicit instantiation declaration for randomize_normal with double weights
// and an rvalue random::generator (RNG = random::generator); the definition is
// provided by randomize.cpp. Calls with an lvalue generator deduce
// RNG = random::generator & and are not covered by this declaration.
extern template void randomize_normal<double, random::generator>(neural_net<double> &, random::generator &&);
|
||||||
|
|
||||||
|
}
|
||||||
11
libs/ml/source/neural_net/randomize.cpp
Normal file
11
libs/ml/source/neural_net/randomize.cpp
Normal file
|
|
@ -0,0 +1,11 @@
|
||||||
|
#include <psemek/ml/neural_net/randomize.hpp>
|
||||||
|
|
||||||
|
namespace psemek::ml
|
||||||
|
{
|
||||||
|
|
||||||
|
// Explicit instantiation definition: emits randomize_uniform for float weights
// with RNG = random::generator (generator taken as random::generator &&),
// pairing with the extern template declaration in randomize.hpp.
template void randomize_uniform<float, random::generator>(neural_net<float> &, random::generator &&);
|
||||||
|
// Explicit instantiation definition: emits randomize_uniform for double weights
// with RNG = random::generator (generator taken as random::generator &&),
// pairing with the extern template declaration in randomize.hpp.
template void randomize_uniform<double, random::generator>(neural_net<double> &, random::generator &&);
|
||||||
|
// Explicit instantiation definition: emits randomize_normal for float weights
// with RNG = random::generator (generator taken as random::generator &&),
// pairing with the extern template declaration in randomize.hpp.
template void randomize_normal<float, random::generator>(neural_net<float> &, random::generator &&);
|
||||||
|
// Explicit instantiation definition: emits randomize_normal for double weights
// with RNG = random::generator (generator taken as random::generator &&),
// pairing with the extern template declaration in randomize.hpp.
template void randomize_normal<double, random::generator>(neural_net<double> &, random::generator &&);
|
||||||
|
|
||||||
|
}
|
||||||
Loading…
Add table
Reference in a new issue