From 3dbbcadfe3c02693950dc8a92fe9204e6246d252 Mon Sep 17 00:00:00 2001 From: lisyarus Date: Fri, 21 Jan 2022 18:42:48 +0300 Subject: [PATCH] Add generic backpropagation (supporting any loss function, not only l2) --- .../include/psemek/ml/neural_net/learner.hpp | 45 ++++++++++++++----- libs/ml/tests/neural_net/gradient.cpp | 2 +- libs/ml/tests/neural_net/learn.cpp | 2 +- 3 files changed, 36 insertions(+), 13 deletions(-) diff --git a/libs/ml/include/psemek/ml/neural_net/learner.hpp b/libs/ml/include/psemek/ml/neural_net/learner.hpp index 5de2a28a..69472594 100644 --- a/libs/ml/include/psemek/ml/neural_net/learner.hpp +++ b/libs/ml/include/psemek/ml/neural_net/learner.hpp @@ -15,11 +15,16 @@ namespace psemek::ml std::vector const & apply(neural_net const & nn, std::vector input) const; std::vector const & result() const { return layers_.back(); } - // Compute the gradient of a loss function (defined as 1/2 of L^2 norm - // of the difference between neural net output and desired output) - // wrt neural net weights and accumulate them to the already computed - // gradient - void backpropagate(neural_net const & nn, std::vector const & output); + // Compute the gradient (wrt neural net weights) of a loss function + // based on its gradient (wrt neural net output) and accumulate + // them to the already computed per-weight gradient + void backpropagate(neural_net const & nn, std::vector const & gradient); + + // Compute the gradient (wrt neural net weights) of a loss function + // (defined as 1/2 of L^2 norm of the difference between neural + // net output and desired output) and accumulate them to the already + // computed per-weight gradient + void backpropagate_l2(neural_net const & nn, std::vector const & output); util::span gradient() const { return gradient_; } util::span gradient() { return gradient_; } @@ -79,7 +84,7 @@ namespace psemek::ml } template - void neural_net_learner::backpropagate(neural_net const & nn, std::vector const & output) + void neural_net_learner::backpropagate(neural_net const & nn, std::vector const & gradient) { if (nn.empty()) throw empty_neural_net_error{}; @@ -88,8 +93,8 @@ namespace psemek::ml auto const activation_types = nn.activation_types(); auto const weights = nn.weights(); - if (output.size() != layer_sizes.back()) - throw wrong_neural_net_output_size(layer_sizes.back(), output.size()); + if (gradient.size() != layer_sizes.back()) + throw wrong_neural_net_output_size(layer_sizes.back(), gradient.size()); gradient_.resize(nn.weights().size()); @@ -98,11 +103,11 @@ namespace psemek::ml { if (l + 2 == layer_sizes.size()) { - error_.resize(output.size()); - for (std::size_t i = 0; i < output.size(); ++i) + error_.resize(gradient.size()); + for (std::size_t i = 0; i < gradient.size(); ++i) { T const value = layers_.back()[i]; - error_[i] = (value - output[i]) * activation_derivative(value, activation_types.back()); + error_[i] = gradient[i] * activation_derivative(value, activation_types.back()); } } else @@ -136,6 +141,24 @@ namespace psemek::ml } } + template + void neural_net_learner::backpropagate_l2(neural_net const & nn, std::vector const & output) + { + if (nn.empty()) + throw empty_neural_net_error{}; + + auto const layer_sizes = nn.layer_sizes(); + + if (output.size() != layer_sizes.back()) + throw wrong_neural_net_output_size(layer_sizes.back(), output.size()); + + error_tmp_.resize(output.size()); + for (std::size_t i = 0; i < output.size(); ++i) + error_tmp_[i] = layers_.back()[i] - output[i]; + + backpropagate(nn, error_tmp_); + } + template T neural_net_learner::gradient_norm() const { diff --git a/libs/ml/tests/neural_net/gradient.cpp b/libs/ml/tests/neural_net/gradient.cpp index 29125fab..214efe11 100644 --- a/libs/ml/tests/neural_net/gradient.cpp +++ b/libs/ml/tests/neural_net/gradient.cpp @@ -39,7 +39,7 @@ test_case(ml_neural__net_gradient) neural_net_learner learner; learner.apply(nn, input); - learner.backpropagate(nn, output); + learner.backpropagate_l2(nn, output); double const eps = 1e-6; diff --git a/libs/ml/tests/neural_net/learn.cpp b/libs/ml/tests/neural_net/learn.cpp index 8123b32d..20f997ad 100644 --- a/libs/ml/tests/neural_net/learn.cpp +++ b/libs/ml/tests/neural_net/learn.cpp @@ -31,7 +31,7 @@ namespace { learner.apply(nn, data.first); error += l2_loss(data.second, learner.result()) / batch.size(); - learner.backpropagate(nn, data.second); + learner.backpropagate_l2(nn, data.second); } if ((iteration % debug_report_frequency) == 0) debug() << "Iteration " << iteration << " error: " << error;