psemek/libs/ml/tests/neural_net/learn.cpp

150 lines
4.1 KiB
C++

#include <psemek/test/test.hpp>
#include <psemek/ml/neural_net/learner.hpp>
#include <psemek/ml/neural_net/evaluator.hpp>
#include <psemek/ml/neural_net/randomize.hpp>
#include <psemek/ml/neural_net/loss.hpp>
#include <psemek/random/generator.hpp>
#include <psemek/random/uniform.hpp>
#include <psemek/math/math.hpp>
#include <psemek/log/log.hpp>
using namespace psemek::ml;
using namespace psemek::random;
using namespace psemek::math;
using namespace psemek::log;
namespace
{
	// Trains `nn` on `batch` with full-batch gradient descent and returns the
	// mean per-sample L2 loss measured during the last iteration.
	//
	// Every iteration clears the learner, then for each sample:
	// - runs it forward through the learner;
	// - adds its L2 loss divided by the batch size;
	// - backpropagates the L2 gradient.
	// After all samples, it takes a single descend() step with rate `mu`.
	//
	// Parameters:
	//   nn         - network being trained; updated in place by descend()
	//   batch      - (input, expected output) pairs
	//   mu         - step size passed to neural_net_learner::descend
	//   iterations - number of descent steps; with 0 nothing is trained and 0.0
	//                is returned
	//
	// Note: the returned error is computed with the weights as they were before
	// the final descend() call, so it lags the trained network by one step.
	double learn_batch(neural_net<double> & nn, std::vector<std::pair<std::vector<double>, std::vector<double>>> const & batch, double mu, std::size_t iterations)
	{
		neural_net_learner<double> learner;
		// Log about 8 times over the run. Clamp to at least 1: for fewer than
		// 8 iterations `iterations / 8` is 0, and `% 0` below would be UB.
		std::size_t const debug_report_frequency = (iterations >= 8) ? iterations / 8 : 1;
		double error = 0.0;
		for (std::size_t iteration = 0; iteration < iterations; ++iteration)
		{
			learner.clear();
			error = 0.0;
			for (auto const & data : batch)
			{
				learner.apply(nn, data.first);
				error += l2_loss(data.second, learner.result()) / batch.size();
				learner.backpropagate_l2(nn, data.second);
			}
			if ((iteration % debug_report_frequency) == 0)
				debug() << "Iteration " << iteration << " error: " << error;
			learner.descend(nn, mu);
		}
		return error;
	}
}
test_case(ml_neural__net_learn_not__2__layers)
{
	// A single sigmoid layer (1 input -> 1 output) should learn logical NOT.
	generator rng;
	neural_net<double> nn({1, 1}, activation_type::sigmoid);
	randomize_normal(nn, rng);

	// NOT truth table: input -> expected output.
	std::vector<std::pair<std::vector<double>, std::vector<double>>> const dataset{
		{{0.0}, {1.0}},
		{{1.0}, {0.0}},
	};

	double const error = learn_batch(nn, dataset, 16.0, 1024);
	expect_less(error, 1e-3);
	info() << "Error: " << error;
}
test_case(ml_neural__net_learn_and__2__layers)
{
	// AND is linearly separable, so a single 2 -> 1 sigmoid layer suffices.
	generator rng;
	neural_net<double> nn({2, 1}, activation_type::sigmoid);
	randomize_normal(nn, rng);

	// AND truth table: only (1, 1) maps to 1.
	std::vector<std::pair<std::vector<double>, std::vector<double>>> const dataset{
		{{0.0, 0.0}, {0.0}},
		{{1.0, 0.0}, {0.0}},
		{{0.0, 1.0}, {0.0}},
		{{1.0, 1.0}, {1.0}},
	};

	double const error = learn_batch(nn, dataset, 16.0, 1024);
	expect_less(error, 1e-3);
	info() << "Error: " << error;
}
test_case(ml_neural__net_learn_and__3__layers)
{
	// Same AND problem, but through a 2 -> 2 -> 1 network with a hidden layer
	// and a smaller step size.
	generator rng;
	neural_net<double> nn({2, 2, 1}, activation_type::sigmoid);
	randomize_normal(nn, rng);

	std::vector<std::pair<std::vector<double>, std::vector<double>>> const dataset{
		{{0.0, 0.0}, {0.0}},
		{{1.0, 0.0}, {0.0}},
		{{0.0, 1.0}, {0.0}},
		{{1.0, 1.0}, {1.0}},
	};

	double const error = learn_batch(nn, dataset, 1.0, 1024);
	expect_less(error, 1e-3);
	info() << "Error: " << error;
}
test_case(ml_neural__net_learn_or__2__layers)
{
	// OR is linearly separable, so a single 2 -> 1 sigmoid layer suffices.
	generator rng;
	neural_net<double> nn({2, 1}, activation_type::sigmoid);
	randomize_normal(nn, rng);

	// OR truth table: only (0, 0) maps to 0.
	std::vector<std::pair<std::vector<double>, std::vector<double>>> const dataset{
		{{0.0, 0.0}, {0.0}},
		{{1.0, 0.0}, {1.0}},
		{{0.0, 1.0}, {1.0}},
		{{1.0, 1.0}, {1.0}},
	};

	double const error = learn_batch(nn, dataset, 16.0, 1024);
	expect_less(error, 1e-3);
	info() << "Error: " << error;
}
test_case(ml_neural__net_learn_or__3__layers)
{
	// Same OR problem, but through a 2 -> 2 -> 1 network with a hidden layer
	// and a smaller step size.
	generator rng;
	neural_net<double> nn({2, 2, 1}, activation_type::sigmoid);
	randomize_normal(nn, rng);

	std::vector<std::pair<std::vector<double>, std::vector<double>>> const dataset{
		{{0.0, 0.0}, {0.0}},
		{{1.0, 0.0}, {1.0}},
		{{0.0, 1.0}, {1.0}},
		{{1.0, 1.0}, {1.0}},
	};

	double const error = learn_batch(nn, dataset, 1.0, 1024);
	expect_less(error, 1e-3);
	info() << "Error: " << error;
}
test_case(ml_neural__net_learn_xor__3__layers)
{
	// XOR is not linearly separable, so it needs a hidden layer: 2 -> 2 -> 1.
	// It also gets a larger iteration budget than the other tests.
	generator rng;
	neural_net<double> nn({2, 2, 1}, activation_type::sigmoid);
	randomize_normal(nn, rng);
	// XOR truth table: output is 1 exactly when the inputs differ.
	// (Targets 1, 0, 0, 1 would be XNOR, not XOR.)
	std::vector<std::pair<std::vector<double>, std::vector<double>>> dataset;
	dataset.push_back({{0.0, 0.0}, {0.0}});
	dataset.push_back({{1.0, 0.0}, {1.0}});
	dataset.push_back({{0.0, 1.0}, {1.0}});
	dataset.push_back({{1.0, 1.0}, {0.0}});
	double error = learn_batch(nn, dataset, 1.0, 2048);
	expect_less(error, 1e-3);
	info() << "Error: " << error;
}