Fix task cancelling

This commit is contained in:
Nikita Lisitsa 2021-03-04 20:30:12 +03:00
parent 8091183375
commit 104ecb528a

View file

@ -50,24 +50,22 @@ namespace psemek::async
using type = util::function<void()>; using type = util::function<void()>;
}; };
struct cancel_token
{};
template <typename T> template <typename T>
struct task_state struct task_state
{ {
std::atomic<bool> canceled = false;
bool const auto_cancel;
std::mutex value_mutex; std::mutex value_mutex;
typename task_value_container<T>::type value{}; typename task_value_container<T>::type value{};
std::exception_ptr exception; std::exception_ptr exception;
std::condition_variable value_cv; std::condition_variable value_cv;
std::shared_ptr<cancel_token> shared_cancel;
std::weak_ptr<cancel_token> weak_cancel;
std::mutex then_mutex; std::mutex then_mutex;
typename then_function<T>::type then_func; typename then_function<T>::type then_func;
task_state(bool auto_cancel)
: auto_cancel{auto_cancel}
{}
}; };
} }
@ -95,8 +93,9 @@ namespace psemek::async
future() = default; future() = default;
future(future&&) = default; future(future&&) = default;
future(std::shared_ptr<detail::task_state<T>> state) future(std::shared_ptr<detail::task_state<T>> state, std::shared_ptr<detail::cancel_token> cancel)
: state_(std::move(state)) : state_(std::move(state))
, cancel_(std::move(cancel))
{} {}
future & operator = (future&&) = default; future & operator = (future&&) = default;
@ -108,10 +107,8 @@ namespace psemek::async
void reset() void reset()
{ {
if (state_ && state_->auto_cancel)
cancel();
state_.reset(); state_.reset();
cancel_.reset();
} }
explicit operator bool() const explicit operator bool() const
@ -149,7 +146,7 @@ namespace psemek::async
{ {
if (!state_) if (!state_)
throw empty_future_error{}; throw empty_future_error{};
if (state_->canceled) if (!state_->weak_cancel.lock())
throw canceled_task_error{}; throw canceled_task_error{};
wait(); wait();
if (state_->value) if (state_->value)
@ -165,8 +162,9 @@ namespace psemek::async
void cancel() void cancel()
{ {
if (state_) if (!state_) throw empty_future_error{};
state_->canceled = true;
cancel_.reset();
} }
bool ready() const bool ready() const
@ -179,6 +177,7 @@ namespace psemek::async
private: private:
std::shared_ptr<detail::task_state<T>> state_; std::shared_ptr<detail::task_state<T>> state_;
std::shared_ptr<detail::cancel_token> cancel_;
bool has_value_unsafe() const bool has_value_unsafe() const
{ {
@ -205,12 +204,26 @@ namespace psemek::async
{} {}
promise() promise()
: state_{std::make_shared<detail::task_state<T>>(false)} : state_{std::make_shared<detail::task_state<T>>()}
{} , cancel_{std::make_shared<detail::cancel_token>()}
{
state_->shared_cancel = cancel_;
state_->weak_cancel = cancel_;
}
promise(auto_cancel_tag tag) promise(auto_cancel_tag)
: state_{std::make_shared<detail::task_state<T>>(true)} : state_{std::make_shared<detail::task_state<T>>()}
{} , cancel_{std::make_shared<detail::cancel_token>()}
{
state_->weak_cancel = cancel_;
}
promise(std::shared_ptr<detail::cancel_token> cancel)
: state_{std::make_shared<detail::task_state<T>>()}
, cancel_{cancel}
{
state_->weak_cancel = cancel_;
}
explicit operator bool() const explicit operator bool() const
{ {
@ -249,13 +262,19 @@ namespace psemek::async
state_->exception = std::move(e); state_->exception = std::move(e);
} }
future<T> get_future() const future<T> get_future()
{ {
return future<T>(state_); return future<T>(state_, std::move(cancel_));
}
bool canceled() const
{
return !state_->weak_cancel.lock();
} }
private: private:
std::shared_ptr<detail::task_state<T>> state_; std::shared_ptr<detail::task_state<T>> state_;
std::shared_ptr<detail::cancel_token> cancel_;
}; };
template <> template <>
@ -265,12 +284,26 @@ namespace psemek::async
{} {}
promise() promise()
: state_{std::make_shared<detail::task_state<void>>(false)} : state_{std::make_shared<detail::task_state<void>>()}
{} , cancel_{std::make_shared<detail::cancel_token>()}
{
state_->shared_cancel = cancel_;
state_->weak_cancel = cancel_;
}
promise(auto_cancel_tag) promise(auto_cancel_tag)
: state_{std::make_shared<detail::task_state<void>>(true)} : state_{std::make_shared<detail::task_state<void>>()}
{} , cancel_{std::make_shared<detail::cancel_token>()}
{
state_->weak_cancel = cancel_;
}
promise(std::shared_ptr<detail::cancel_token> cancel)
: state_{std::make_shared<detail::task_state<void>>()}
, cancel_{cancel}
{
state_->weak_cancel = cancel_;
}
explicit operator bool() const explicit operator bool() const
{ {
@ -297,13 +330,19 @@ namespace psemek::async
state_->exception = std::move(e); state_->exception = std::move(e);
} }
future<void> get_future() const future<void> get_future()
{ {
return future<void>(state_); return future<void>(state_, std::move(cancel_));
}
bool canceled() const
{
return !state_->weak_cancel.lock();
} }
private: private:
std::shared_ptr<detail::task_state<void>> state_; std::shared_ptr<detail::task_state<void>> state_;
std::shared_ptr<detail::cancel_token> cancel_;
}; };
template <typename T> template <typename T>
@ -335,12 +374,18 @@ namespace psemek::async
, func_(std::forward<F>(f)) , func_(std::forward<F>(f))
{} {}
template <typename F>
packaged_task(std::shared_ptr<detail::cancel_token> cancel, F && f)
: promise_(cancel)
, func_(std::forward<F>(f))
{}
explicit operator bool() const explicit operator bool() const
{ {
return static_cast<bool>(promise_); return static_cast<bool>(promise_);
} }
future<R> get_future() const future<R> get_future()
{ {
return promise_.get_future(); return promise_.get_future();
} }
@ -348,6 +393,9 @@ namespace psemek::async
template <typename ... Args1> template <typename ... Args1>
void operator() (Args1 && ... args) void operator() (Args1 && ... args)
{ {
if (promise_.canceled())
return;
try try
{ {
if constexpr (std::is_same_v<R, void>) if constexpr (std::is_same_v<R, void>)
@ -427,7 +475,7 @@ namespace psemek::async
if constexpr (std::is_same_v<T, void>) if constexpr (std::is_same_v<T, void>)
{ {
using R = decltype(f()); using R = decltype(f());
packaged_task<R()> t(std::forward<F>(f)); packaged_task<R()> t(cancel_, std::forward<F>(f));
auto fut = t.get_future(); auto fut = t.get_future();
std::lock_guard lock{state_->then_mutex}; std::lock_guard lock{state_->then_mutex};
@ -437,7 +485,7 @@ namespace psemek::async
else else
{ {
using R = decltype(f(*(state_->value))); using R = decltype(f(*(state_->value)));
packaged_task<R(T &)> t(std::forward<F>(f)); packaged_task<R(T &)> t(cancel_, std::forward<F>(f));
auto fut = t.get_future(); auto fut = t.get_future();
std::lock_guard lock{state_->then_mutex}; std::lock_guard lock{state_->then_mutex};