diff --git a/libs/async/include/psemek/async/future.hpp b/libs/async/include/psemek/async/future.hpp index d33d9139..f66702a7 100644 --- a/libs/async/include/psemek/async/future.hpp +++ b/libs/async/include/psemek/async/future.hpp @@ -50,24 +50,22 @@ namespace psemek::async using type = util::function; }; + struct cancel_token + {}; + template struct task_state { - std::atomic canceled = false; - bool const auto_cancel; - std::mutex value_mutex; typename task_value_container::type value{}; std::exception_ptr exception; - std::condition_variable value_cv; + std::shared_ptr shared_cancel; + std::weak_ptr weak_cancel; + std::mutex then_mutex; typename then_function::type then_func; - - task_state(bool auto_cancel) - : auto_cancel{auto_cancel} - {} }; } @@ -95,8 +93,9 @@ namespace psemek::async future() = default; future(future&&) = default; - future(std::shared_ptr> state) + future(std::shared_ptr> state, std::shared_ptr cancel) : state_(std::move(state)) + , cancel_(std::move(cancel)) {} future & operator = (future&&) = default; @@ -108,10 +107,8 @@ namespace psemek::async void reset() { - if (state_ && state_->auto_cancel) - cancel(); - state_.reset(); + cancel_.reset(); } explicit operator bool() const @@ -149,7 +146,7 @@ namespace psemek::async { if (!state_) throw empty_future_error{}; - if (state_->canceled) + if (!state_->weak_cancel.lock()) throw canceled_task_error{}; wait(); if (state_->value) @@ -165,8 +162,9 @@ namespace psemek::async void cancel() { - if (state_) - state_->canceled = true; + if (!state_) throw empty_future_error{}; + + cancel_.reset(); } bool ready() const @@ -179,6 +177,7 @@ namespace psemek::async private: std::shared_ptr> state_; + std::shared_ptr cancel_; bool has_value_unsafe() const { @@ -205,12 +204,26 @@ namespace psemek::async {} promise() - : state_{std::make_shared>(false)} - {} + : state_{std::make_shared>()} + , cancel_{std::make_shared()} + { + state_->shared_cancel = cancel_; + state_->weak_cancel = cancel_; + } - promise(auto_cancel_tag tag) - : state_{std::make_shared>(true)} - {} + promise(auto_cancel_tag) + : state_{std::make_shared>()} + , cancel_{std::make_shared()} + { + state_->weak_cancel = cancel_; + } + + promise(std::shared_ptr cancel) + : state_{std::make_shared>()} + , cancel_{cancel} + { + state_->weak_cancel = cancel_; + } explicit operator bool() const { @@ -249,13 +262,19 @@ namespace psemek::async state_->exception = std::move(e); } - future get_future() const + future get_future() { - return future(state_); + return future(state_, std::move(cancel_)); + } + + bool canceled() const + { + return !state_->weak_cancel.lock(); } private: std::shared_ptr> state_; + std::shared_ptr cancel_; }; template <> @@ -265,12 +284,26 @@ namespace psemek::async {} promise() - : state_{std::make_shared>(false)} - {} + : state_{std::make_shared>()} + , cancel_{std::make_shared()} + { + state_->shared_cancel = cancel_; + state_->weak_cancel = cancel_; + } promise(auto_cancel_tag) - : state_{std::make_shared>(true)} - {} + : state_{std::make_shared>()} + , cancel_{std::make_shared()} + { + state_->weak_cancel = cancel_; + } + + promise(std::shared_ptr cancel) + : state_{std::make_shared>()} + , cancel_{cancel} + { + state_->weak_cancel = cancel_; + } explicit operator bool() const { @@ -297,13 +330,19 @@ namespace psemek::async state_->exception = std::move(e); } - future get_future() const + future get_future() { - return future(state_); + return future(state_, std::move(cancel_)); + } + + bool canceled() const + { + return !state_->weak_cancel.lock(); } private: std::shared_ptr> state_; + std::shared_ptr cancel_; }; template @@ -335,12 +374,18 @@ namespace psemek::async , func_(std::forward(f)) {} + template + packaged_task(std::shared_ptr cancel, F && f) + : promise_(cancel) + , func_(std::forward(f)) + {} + explicit operator bool() const { return static_cast(promise_); } - future get_future() const + future get_future() { return promise_.get_future(); } @@ -348,6 +393,9 @@ namespace psemek::async template void operator() (Args1 && ... args) { + if (promise_.canceled()) + return; + try { if constexpr (std::is_same_v) @@ -427,7 +475,7 @@ namespace psemek::async if constexpr (std::is_same_v) { using R = decltype(f()); - packaged_task t(std::forward(f)); + packaged_task t(cancel_, std::forward(f)); auto fut = t.get_future(); std::lock_guard lock{state_->then_mutex}; @@ -437,7 +485,7 @@ namespace psemek::async else { using R = decltype(f(*(state_->value))); - packaged_task t(std::forward(f)); + packaged_task t(cancel_, std::forward(f)); auto fut = t.get_future(); std::lock_guard lock{state_->then_mutex};