Fixes in neural animation example
This commit is contained in:
parent
25e0480805
commit
63e2d226d2
1 changed files with 9 additions and 9 deletions
|
|
@ -673,7 +673,7 @@ struct animation_2d_app
|
|||
} mode = mode::train;
|
||||
|
||||
std::vector<controller> population;
|
||||
std::size_t const population_size = 16;
|
||||
std::size_t const population_size = 256;
|
||||
std::size_t const max_train_frames = 10.f / physics.dt;
|
||||
std::size_t const max_train_variations = 1;
|
||||
float const position_variation_amplitude = 0.f;
|
||||
|
|
@ -683,13 +683,12 @@ struct animation_2d_app
|
|||
std::size_t const max_train_iterations = 1024*8;
|
||||
float const randomize_amplitude = 10.f;
|
||||
static constexpr auto mutation_amplitude = [](float t){ return 10.f * geom::lerp(1.f, 0.01f, t); };
|
||||
// static constexpr auto mutation_amplitude = [](float){ return 0.00001f; };
|
||||
|
||||
float best_score = 0.f;
|
||||
bool const warm_start = false;
|
||||
bool const enable_testing = true;
|
||||
bool testing_control = true;
|
||||
std::string const cache_location = "/home/lisyarus/runner";
|
||||
std::string const cache_location = "./runner";
|
||||
|
||||
std::size_t test_id = 0;
|
||||
std::vector<float> test_speeds;
|
||||
|
|
@ -1201,17 +1200,18 @@ void animation_2d_app::do_train()
|
|||
}
|
||||
|
||||
std::vector<std::pair<float, std::size_t>> scores(population.size());
|
||||
std::vector<psemek::async::future<void>> futures;
|
||||
std::atomic<std::size_t> finished_count{0};
|
||||
for (std::size_t i = 0; i < population.size(); ++i)
|
||||
{
|
||||
futures.push_back(bg.dispatch([&, i, rng = rng]() mutable {
|
||||
bg.dispatch([&, i, rng = rng]() mutable {
|
||||
scores[i] = {eval_score(population[i], rng), i};
|
||||
}));
|
||||
++finished_count;
|
||||
});
|
||||
}
|
||||
// bg.wait();
|
||||
// while (trained.load() != population.size());
|
||||
bg.wait();
|
||||
|
||||
bg.wait_all(futures.begin(), futures.end()).get();
|
||||
if (finished_count.load() != population.size())
|
||||
throw std::runtime_error("bg.wait() didn't wait for all tasks to finish");
|
||||
|
||||
std::sort(scores.begin(), scores.end(), [](auto const & p1, auto const & p2){ return p1.first > p2.first; });
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue