Remove obsolete 'optimize' mode from neural animation example
This commit is contained in:
parent
f73f536d8d
commit
25e0480805
1 changed files with 3 additions and 75 deletions
|
|
@ -669,13 +669,12 @@ struct animation_2d_app
|
|||
enum class mode
|
||||
{
|
||||
train,
|
||||
optimize,
|
||||
test,
|
||||
} mode = mode::train;
|
||||
|
||||
std::vector<controller> population;
|
||||
std::size_t const population_size = 256;
|
||||
std::size_t const max_train_frames = 60.f / physics.dt;
|
||||
std::size_t const population_size = 16;
|
||||
std::size_t const max_train_frames = 10.f / physics.dt;
|
||||
std::size_t const max_train_variations = 1;
|
||||
float const position_variation_amplitude = 0.f;
|
||||
float const angle_variation_amplitude = 0.f;
|
||||
|
|
@ -686,10 +685,6 @@ struct animation_2d_app
|
|||
static constexpr auto mutation_amplitude = [](float t){ return 10.f * geom::lerp(1.f, 0.01f, t); };
|
||||
// static constexpr auto mutation_amplitude = [](float){ return 0.00001f; };
|
||||
|
||||
std::size_t optimize_iterations = 0;
|
||||
std::size_t const max_optimize_iterations = 1;
|
||||
float const optimize_amplitude = 0.1f;
|
||||
|
||||
float best_score = 0.f;
|
||||
bool const warm_start = false;
|
||||
bool const enable_testing = true;
|
||||
|
|
@ -820,18 +815,6 @@ void animation_2d_app::update()
|
|||
{
|
||||
do_train();
|
||||
if (train_iterations >= max_train_iterations)
|
||||
{
|
||||
mode = mode::optimize;
|
||||
}
|
||||
else if (train_iterations >= max_train_iterations / 2 && best_score < -1000.f)
|
||||
{
|
||||
// train_iterations = 0;
|
||||
}
|
||||
}
|
||||
else if (mode == mode::optimize)
|
||||
{
|
||||
do_optimize();
|
||||
if (optimize_iterations >= max_optimize_iterations)
|
||||
{
|
||||
mode = mode::test;
|
||||
frame_clock.restart();
|
||||
|
|
@ -1243,61 +1226,6 @@ void animation_2d_app::do_train()
|
|||
best_score = scores.front().first;
|
||||
}
|
||||
|
||||
void animation_2d_app::do_optimize()
|
||||
{
|
||||
auto & c = population[0];
|
||||
|
||||
auto rng = this->rng;
|
||||
|
||||
float current_score = eval_score(c, rng);
|
||||
|
||||
float eps = 0.1f;
|
||||
|
||||
std::vector<float> scores(c.param_count, 0.f);
|
||||
std::atomic<std::size_t> dispatched{0};
|
||||
for (std::size_t i = 0; i < c.param_count; ++i)
|
||||
{
|
||||
bg.dispatch([&, i, rng = rng]() mutable {
|
||||
controller cc = c;
|
||||
cc.params()[i] += eps;
|
||||
scores[i] = eval_score(cc, rng);
|
||||
++dispatched;
|
||||
});
|
||||
}
|
||||
// bg.wait();
|
||||
while (dispatched.load() < scores.size());
|
||||
|
||||
for (auto & s : scores)
|
||||
s = (s - current_score) / eps;
|
||||
|
||||
float m = 0.f;
|
||||
for (auto s : scores)
|
||||
m = std::max(m, std::abs(s));
|
||||
|
||||
log::info() << "Max gradient: " << m;
|
||||
|
||||
for (std::size_t k = 0; k < 30; ++k)
|
||||
{
|
||||
auto cc = c;
|
||||
for (std::size_t i = 0; i < c.param_count; ++i)
|
||||
cc.params()[i] += optimize_amplitude * scores[i];
|
||||
float new_score = eval_score(cc, rng);
|
||||
|
||||
if (new_score > current_score)
|
||||
{
|
||||
log::info() << "k = " << k;
|
||||
log::info() << "Score change: " << current_score << " -> " << new_score;
|
||||
c = cc;
|
||||
break;
|
||||
}
|
||||
|
||||
for (auto & s : scores)
|
||||
s /= 2.f;
|
||||
}
|
||||
|
||||
++optimize_iterations;
|
||||
}
|
||||
|
||||
void animation_2d_app::do_test()
|
||||
{
|
||||
std::optional<geom::point<float, 2>> m;
|
||||
|
|
@ -1466,7 +1394,7 @@ void animation_2d_app::present()
|
|||
opts.x = gfx::painter::x_align::left;
|
||||
opts.y = gfx::painter::y_align::top;
|
||||
opts.scale = 2.f;
|
||||
painter.text({40.f, 40.f}, util::to_string(train_iterations, "/", max_train_iterations, " ", optimize_iterations, "/", max_optimize_iterations), opts);
|
||||
painter.text({40.f, 40.f}, util::to_string(train_iterations, "/", max_train_iterations), opts);
|
||||
painter.text({40.f, 64.f}, util::to_string("Best score: ", std::setprecision(10), best_score), opts);
|
||||
painter.text({40.f, 88.f}, util::to_string("Model: ", test_id, "/", population.size(), ", gen ", population[test_id].generation), opts);
|
||||
painter.text({40.f, 112.f}, util::to_string(util::pretty(test_clock.duration(), std::chrono::milliseconds{1})), opts);
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue