From 3a3fffa232234de2eed6fd4819639ca9ca3ec06a Mon Sep 17 00:00:00 2001 From: lisyarus Date: Thu, 20 Jan 2022 14:16:54 +0300 Subject: [PATCH] Add neural net random initialization --- .../psemek/ml/neural_net/randomize.hpp | 40 +++++++++++++++++++ libs/ml/source/neural_net/randomize.cpp | 11 +++++ 2 files changed, 51 insertions(+) create mode 100644 libs/ml/include/psemek/ml/neural_net/randomize.hpp create mode 100644 libs/ml/source/neural_net/randomize.cpp diff --git a/libs/ml/include/psemek/ml/neural_net/randomize.hpp b/libs/ml/include/psemek/ml/neural_net/randomize.hpp new file mode 100644 index 00000000..250e7c8b --- /dev/null +++ b/libs/ml/include/psemek/ml/neural_net/randomize.hpp @@ -0,0 +1,40 @@ +#pragma once + +#include +#include +#include +#include + +namespace psemek::ml +{ + + template + void randomize_uniform(neural_net & nn, RNG && rng) + { + auto const layer_sizes = nn.layer_sizes(); + for_each_layer(nn, [&layer_sizes, &rng](std::size_t l, util::span weights){ + T length = std::sqrt(T{6} / (layer_sizes[l] + layer_sizes[l + 1])); + random::uniform_real_distribution d{-length, length}; + for (T & w : weights) + w = d(rng); + }); + } + + template + void randomize_normal(neural_net & nn, RNG && rng) + { + auto const layer_sizes = nn.layer_sizes(); + for_each_layer(nn, [&layer_sizes, &rng](std::size_t l, util::span weights){ + T stddev = std::sqrt(T{2} / (layer_sizes[l] + layer_sizes[l + 1])); + random::normal_distribution d{T{0}, stddev}; + for (T & w : weights) + w = d(rng); + }); + } + + extern template void randomize_uniform(neural_net &, random::generator &&); + extern template void randomize_uniform(neural_net &, random::generator &&); + extern template void randomize_normal(neural_net &, random::generator &&); + extern template void randomize_normal(neural_net &, random::generator &&); + +} diff --git a/libs/ml/source/neural_net/randomize.cpp b/libs/ml/source/neural_net/randomize.cpp new file mode 100644 index 00000000..6a94c020 --- /dev/null +++ b/libs/ml/source/neural_net/randomize.cpp @@ -0,0 +1,11 @@ +#include + +namespace psemek::ml +{ + + template void randomize_uniform(neural_net &, random::generator &&); + template void randomize_uniform(neural_net &, random::generator &&); + template void randomize_normal(neural_net &, random::generator &&); + template void randomize_normal(neural_net &, random::generator &&); + +}