Add generic backpropagation (supporting any loss function, not only l2)
This commit is contained in:
parent
8630525dcf
commit
3dbbcadfe3
3 changed files with 36 additions and 13 deletions
|
|
@ -15,11 +15,16 @@ namespace psemek::ml
|
|||
std::vector<T> const & apply(neural_net<T> const & nn, std::vector<T> input) const;
|
||||
std::vector<T> const & result() const { return layers_.back(); }
|
||||
|
||||
// Compute the gradient of a loss function (defined as 1/2 of L^2 norm
|
||||
// of the difference between neural net output and desired output)
|
||||
// wrt neural net weights and accumulate them to the already computed
|
||||
// gradient
|
||||
void backpropagate(neural_net<T> const & nn, std::vector<T> const & output);
|
||||
// Compute the gradient (wrt neural net weights) of a loss function
|
||||
// from its gradient (wrt neural net output) and accumulate it
|
||||
// into the already computed per-weight gradient
|
||||
void backpropagate(neural_net<T> const & nn, std::vector<T> const & gradient);
|
||||
|
||||
// Compute the gradient (wrt neural net weights) of a loss function
|
||||
// (defined as 1/2 of the squared L^2 norm of the difference between
|
||||
// neural net output and desired output) and accumulate it into the
|
||||
// already computed per-weight gradient
|
||||
void backpropagate_l2(neural_net<T> const & nn, std::vector<T> const & output);
|
||||
|
||||
util::span<T const> gradient() const { return gradient_; }
|
||||
util::span<T> gradient() { return gradient_; }
|
||||
|
|
@ -79,7 +84,7 @@ namespace psemek::ml
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void neural_net_learner<T>::backpropagate(neural_net<T> const & nn, std::vector<T> const & output)
|
||||
void neural_net_learner<T>::backpropagate(neural_net<T> const & nn, std::vector<T> const & gradient)
|
||||
{
|
||||
if (nn.empty())
|
||||
throw empty_neural_net_error{};
|
||||
|
|
@ -88,8 +93,8 @@ namespace psemek::ml
|
|||
auto const activation_types = nn.activation_types();
|
||||
auto const weights = nn.weights();
|
||||
|
||||
if (output.size() != layer_sizes.back())
|
||||
throw wrong_neural_net_output_size(layer_sizes.back(), output.size());
|
||||
if (gradient.size() != layer_sizes.back())
|
||||
throw wrong_neural_net_output_size(layer_sizes.back(), gradient.size());
|
||||
|
||||
gradient_.resize(nn.weights().size());
|
||||
|
||||
|
|
@ -98,11 +103,11 @@ namespace psemek::ml
|
|||
{
|
||||
if (l + 2 == layer_sizes.size())
|
||||
{
|
||||
error_.resize(output.size());
|
||||
for (std::size_t i = 0; i < output.size(); ++i)
|
||||
error_.resize(gradient.size());
|
||||
for (std::size_t i = 0; i < gradient.size(); ++i)
|
||||
{
|
||||
T const value = layers_.back()[i];
|
||||
error_[i] = (value - output[i]) * activation_derivative(value, activation_types.back());
|
||||
error_[i] = gradient[i] * activation_derivative(value, activation_types.back());
|
||||
}
|
||||
}
|
||||
else
|
||||
|
|
@ -136,6 +141,24 @@ namespace psemek::ml
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void neural_net_learner<T>::backpropagate_l2(neural_net<T> const & nn, std::vector<T> const & output)
|
||||
{
|
||||
if (nn.empty())
|
||||
throw empty_neural_net_error{};
|
||||
|
||||
auto const layer_sizes = nn.layer_sizes();
|
||||
|
||||
if (output.size() != layer_sizes.back())
|
||||
throw wrong_neural_net_output_size(layer_sizes.back(), output.size());
|
||||
|
||||
error_tmp_.resize(output.size());
|
||||
for (std::size_t i = 0; i < output.size(); ++i)
|
||||
error_tmp_[i] = layers_.back()[i] - output[i];
|
||||
|
||||
backpropagate(nn, error_tmp_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T neural_net_learner<T>::gradient_norm() const
|
||||
{
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ test_case(ml_neural__net_gradient)
|
|||
|
||||
neural_net_learner<double> learner;
|
||||
learner.apply(nn, input);
|
||||
learner.backpropagate(nn, output);
|
||||
learner.backpropagate_l2(nn, output);
|
||||
|
||||
double const eps = 1e-6;
|
||||
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ namespace
|
|||
{
|
||||
learner.apply(nn, data.first);
|
||||
error += l2_loss(data.second, learner.result()) / batch.size();
|
||||
learner.backpropagate(nn, data.second);
|
||||
learner.backpropagate_l2(nn, data.second);
|
||||
}
|
||||
if ((iteration % debug_report_frequency) == 0)
|
||||
debug() << "Iteration " << iteration << " error: " << error;
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue