Add basic behavior_tree implementation

This commit is contained in:
Nikita Lisitsa 2021-10-16 13:26:37 +03:00
parent d3b2790f97
commit b8d03d41a4

View file

@ -0,0 +1,459 @@
#pragma once
#include <variant>
#include <tuple>
#include <memory>
#include <vector>
namespace psemek::util
{
template <typename Time, typename Event, typename ... Args>
struct behavior_tree
{
struct running
{};
struct finished
{
bool result;
};
struct suspended
{
Time duration;
};
using status = std::variant<running, finished, suspended>;
struct any_node_base
{
virtual void start(Args ...) = 0;
virtual status update(Time, Args ...) = 0;
virtual bool event(Event const &) = 0;
virtual ~any_node_base() {}
};
template <typename Node>
struct any_node_impl
: any_node_base
{
Node node;
any_node_impl(Node node)
: node(std::move(node))
{}
void start(Args ... args) override
{
node.start(args...);
}
status update(Time dt, Args ... args) override
{
return node.update(dt, args...);
}
bool event(Event const & e) override
{
return node.event(e);
}
};
struct any_node
{
template <typename Node>
any_node(Node node)
: impl_(new any_node_impl<Node>{std::move(node)})
{}
void start(Args ... args)
{
impl_->start(args...);
}
status update(Time dt, Args ... args)
{
return impl_->update(dt, args...);
}
bool event(Event const & e)
{
return impl_->event(e);
}
private:
std::unique_ptr<any_node_base> impl_;
};
struct updater
{
struct node_data
{
any_node node;
std::tuple<Args...> args;
Time duration;
Time elapsed;
};
template <typename Node>
void add(Node node, Args ... args)
{
new_.push_back(node_data{std::move(node), {args...}, 0, 0});
}
struct update_statistics
{
int active_count = 0;
int suspended_count = 0;
};
update_statistics update(Time dt)
{
activate();
fill_dt(dt);
wake_up(dt);
update_statistics result;
result.active_count = active_.size();
result.suspended_count = suspended_.size();
do_update();
return result;
}
private:
std::vector<node_data> new_;
std::vector<node_data> active_;
std::vector<node_data> suspended_;
void fill_dt(Time dt)
{
for (auto & n : active_)
n.elapsed = dt;
}
void activate()
{
for (auto & n : new_)
{
std::apply([&](Args ... args){ n.node.start(args...); }, n.args);
active_.push_back(std::move(n));
}
new_.clear();
}
void wake_up(Time dt)
{
std::vector<node_data> new_suspended;
for (auto & n : suspended_)
{
n.elapsed += dt;
if (n.elapsed >= n.duration)
active_.push_back(std::move(n));
else
new_suspended.push_back(std::move(n));
}
suspended_ = std::move(new_suspended);
}
void do_update()
{
std::vector<node_data> new_active;
for (auto & n : active_)
{
auto result = std::apply([&](Args ... args){ return n.node.update(n.elapsed, args...); }, n.args);
if (auto s = std::get_if<finished>(&result))
{
new_.push_back(std::move(n));
}
else if (auto s = std::get_if<suspended>(&result))
{
n.elapsed = 0;
n.duration = s->duration;
suspended_.push_back(std::move(n));
}
else
new_active.push_back(std::move(n));
}
active_ = std::move(new_active);
}
};
template <typename TimeFn>
struct wait
{
TimeFn timeFn;
wait(TimeFn timeFn)
: timeFn(std::move(timeFn))
{}
Time remaining_time = Time{0};
void start(Args ... args)
{
remaining_time = timeFn(args...);
}
status update(Time dt, Args ... args)
{
remaining_time -= dt;
if (remaining_time <= 0)
return finished{true};
return suspended{remaining_time};
}
bool event(Event const &)
{
return false;
}
};
template <typename CondFn>
struct condition
{
CondFn condFn;
condition(CondFn condFn)
: condFn(std::move(condFn))
{}
void start(Args ...)
{}
status update(Time dt, Args ... args)
{
return finished{condFn(args...)};
}
bool event(Event const &)
{
return false;
}
};
template <typename Child>
struct success
{
Child child;
success(Child child)
: child(std::move(child))
{}
void start(Args ... args)
{
child.start(args...);
}
status update(Time dt, Args ... args)
{
auto child_status = child.update(dt, std::forward<Args>(args)...);
if (child_status.index() == 1)
return finished{true};
return child_status;
}
bool event(Event const & e)
{
return child.event(e);
}
};
template <typename Child>
struct failure
{
Child child;
failure(Child child)
: child(std::move(child))
{}
void start(Args ... args)
{
child.start(args...);
}
status update(Time dt, Args ... args)
{
auto child_status = child.update(dt, std::forward<Args>(args)...);
if (child_status.index() == 1)
return finished{false};
return child_status;
}
bool event(Event const & e)
{
return child.event(e);
}
};
template <typename Child>
struct repeat
{
Child child;
repeat(Child child)
: child(std::move(child))
{}
void start(Args ... args)
{
child.start(args...);
}
status update(Time dt, Args ... args)
{
auto result = child.update(dt, std::forward<Args>(args)...);
if (auto f = std::get_if<finished>(&result))
{
if (!(f->result))
return finished{true};
else
{
start(args...);
return running{};
}
}
return result;
}
bool event(Event const & e)
{
return child.event(e);
}
};
template <typename Child>
struct retry
{
Child child;
retry(Child child)
: child(std::move(child))
{}
void start(Args ... args)
{
child.start(args...);
}
status update(Time dt, Args ... args)
{
auto result = child.update(dt, std::forward<Args>(args)...);
if (auto f = std::get_if<finished>(&result))
{
if (f->result)
return finished{true};
else
{
start(args...);
return running{};
}
}
return result;
}
bool event(Event const & e)
{
return child.event(e);
}
};
template <typename ... Children>
struct sequence
{
std::tuple<Children...> children;
bool current_started = false;
size_t current = 0;
sequence(Children ... children)
: children{std::move(children)...}
{}
void start(Args ...)
{
if constexpr (sizeof...(Children) == 0)
{
current = 1;
}
else
{
current = 0;
current_started = false;
}
}
status update(Time dt, Args ... args)
{
return update_impl<0>(dt, args...);
}
bool event(Event const & e)
{
return event_impl<0>(e);
}
private:
template <size_t I>
status update_impl(Time dt, Args ... args)
{
if constexpr (I == sizeof...(Children))
{
return finished{true};
}
else
{
if (current != I)
{
return update_impl<I + 1>(dt, args...);
}
else
{
if (!current_started)
{
std::get<I>(children).start(args...);
current_started = true;
}
auto result = std::get<I>(children).update(dt, args...);
if (auto f = std::get_if<finished>(&result))
{
if (f->result)
{
current_started = false;
++current;
return update_impl<I + 1>(dt, args...);
}
}
return result;
}
}
}
template <size_t I>
bool event_impl(Event const & e)
{
if constexpr (I == sizeof...(Children))
{
return false;
}
else
{
if (current == I)
{
return std::get<I>(children).event(e);
}
else
{
return event_impl<I + 1>(e);
}
}
}
};
};
}