diff --git a/libs/async/include/psemek/async/executor.hpp b/libs/async/include/psemek/async/executor.hpp index 8a940c90..74fab0a5 100644 --- a/libs/async/include/psemek/async/executor.hpp +++ b/libs/async/include/psemek/async/executor.hpp @@ -81,36 +81,40 @@ namespace psemek::async auto executor::dispatch(F && f, Args && ... args) { using R = decltype(f()); - auto state = std::make_shared>(false); - post(detail::wrap_task(state, std::forward(f), std::forward(args)...)); - return future(state); + packaged_task t([f = std::forward(f), ... args = std::forward(args)]() mutable { return std::forward(f)(std::forward(args)...); }); + auto fut = t.get_future(); + post(std::move(t)); + return fut; } template - auto executor::dispatch(auto_cancel_tag, F && f, Args && ... args) + auto executor::dispatch(auto_cancel_tag tag, F && f, Args && ... args) { using R = decltype(f()); - auto state = std::make_shared>(true); - post(detail::wrap_task(state, std::forward(f), std::forward(args)...)); - return future(state); + packaged_task t(tag, [f = std::forward(f), ... args = std::forward(args)]() mutable { return std::forward(f)(std::forward(args)...); }); + auto fut = t.get_future(); + post(std::move(t)); + return fut; } template auto executor::dispatch_at(TimePoint time, F && f, Args && ... args) { using R = decltype(f()); - auto state = std::make_shared>(false); - post_at(std::chrono::time_point_cast(time), detail::wrap_task(state, std::forward(f), std::forward(args)...)); - return future(state); + packaged_task t([f = std::forward(f), ... args = std::forward(args)]() mutable { return std::forward(f)(std::forward(args)...); }); + auto fut = t.get_future(); + post_at(time, std::move(t)); + return fut; } template - auto executor::dispatch_at(TimePoint time, auto_cancel_tag, F && f, Args && ... args) + auto executor::dispatch_at(TimePoint time, auto_cancel_tag tag, F && f, Args && ... args) { using R = decltype(f()); - auto state = std::make_shared>(true); - post_at(std::chrono::time_point_cast(time), detail::wrap_task(state, std::forward(f), std::forward(args)...)); - return future(state); + packaged_task t(tag, [f = std::forward(f), ... args = std::forward(args)]() mutable { return std::forward(f)(std::forward(args)...); }); + auto fur = t.get_future(); + post_at(time, std::move(t)); + return fur; } template diff --git a/libs/async/include/psemek/async/future.hpp b/libs/async/include/psemek/async/future.hpp index 2590f981..782ecfb3 100644 --- a/libs/async/include/psemek/async/future.hpp +++ b/libs/async/include/psemek/async/future.hpp @@ -1,16 +1,30 @@ #pragma once +#include + #include #include #include #include #include +#include namespace psemek::async { namespace detail { + template + struct get_return_type + { + using type = T &; + }; + + template <> + struct get_return_type + { + using type = void; + }; template struct task_value_container @@ -31,7 +45,7 @@ namespace psemek::async bool const auto_cancel; std::mutex value_mutex; - task_value_container::type value; + typename task_value_container::type value{}; std::exception_ptr exception; std::condition_variable value_cv; @@ -41,38 +55,6 @@ namespace psemek::async {} }; - template - auto wrap_task(std::shared_ptr> state, F && f, Args && ... args) - { - return [state, f = std::forward(f), ... args = std::forward(args)]() mutable { - if (state->canceled) return; - - try - { - if constexpr (std::is_same_v) - { - std::forward(f)(std::forward(args)...); - std::lock_guard lock{state->value_mutex}; - state->value = true; - state->value_cv.notify_all(); - } - else - { - auto value = std::forward(f)(std::forward(args)...); - std::lock_guard lock{state->value_mutex}; - state->value = std::move(value); - state->value_cv.notify_all(); - } - } - catch(...) - { - std::lock_guard lock{state->value_mutex}; - state->exception = std::current_exception(); - state->value_cv.notify_all(); - } - }; - } - } struct empty_future_error @@ -124,7 +106,7 @@ namespace psemek::async bool wait() const { - if (!state_) return false; + if (!state_) throw empty_future_error{}; std::unique_lock lock(state_->value_mutex); state_->value_cv.wait(lock, [this]{ return has_value_unsafe(); }); return true; @@ -133,7 +115,7 @@ namespace psemek::async template bool wait_for(Duration period) const { - if (!state_) return false; + if (!state_) throw empty_future_error{}; std::unique_lock lock(state_->value_mutex); state_->value_cv.wait_for(lock, period, [this]{ return has_value_unsafe(); }); return has_value_unsafe(); @@ -142,26 +124,25 @@ namespace psemek::async template bool wait_until(TimePoint time) const { - if (!state_) return false; + if (!state_) throw empty_future_error{}; std::unique_lock lock(state_->value_mutex); state_->value_cv.wait_until(lock, time, [this]{ return has_value_unsafe(); }); return has_value_unsafe(); } - T get() + detail::get_return_type::type get() { if (!state_) throw empty_future_error{}; if (state_->canceled) throw canceled_task_error{}; wait(); - std::lock_guard lock{state_->value_mutex}; if (state_->value) { if constexpr (std::is_same_v) return; else - return std::move(*(state_->value)); + return *(state_->value); } else std::rethrow_exception(state_->exception); @@ -187,4 +168,169 @@ namespace psemek::async } }; + struct empty_promise_error + : std::exception + { + char const * what() const noexcept { return "promise is empty"; } + }; + + struct satisfied_promise_error + : std::exception + { + char const * what() const noexcept { return "promise already contains a value or exception"; } + }; + + template + struct promise + { + promise(std::nullptr_t) + {} + + promise() + : state_{std::make_shared>(false)} + {} + + promise(auto_cancel_tag tag) + : state_{std::make_shared>(true)} + {} + + explicit operator bool() const + { + return static_cast(state_); + } + + void set_value(T const & value) + { + if (!state_) throw empty_promise_error{}; + std::lock_guard lock{state_->value_mutex}; + if (state_->value || state_->exception) throw satisfied_promise_error{}; + state_->value = value; + } + + void set_value(T && value) + { + if (!state_) throw empty_promise_error{}; + std::lock_guard lock{state_->value_mutex}; + if (state_->value || state_->exception) throw satisfied_promise_error{}; + state_->value = std::move(value); + } + + void set_exception(std::exception_ptr e) + { + if (!state_) throw empty_promise_error{}; + std::lock_guard lock{state_->value_mutex}; + if (state_->value || state_->exception) throw satisfied_promise_error{}; + state_->exception = std::move(e); + } + + future get_future() const + { + return future(state_); + } + + private: + std::shared_ptr> state_; + }; + + template <> + struct promise + { + promise(std::nullptr_t) + {} + + promise() + : state_{std::make_shared>(false)} + {} + + promise(auto_cancel_tag) + : state_{std::make_shared>(true)} + {} + + explicit operator bool() const + { + return static_cast(state_); + } + + void set_value() + { + if (!state_) throw empty_promise_error{}; + std::lock_guard lock{state_->value_mutex}; + if (state_->value || state_->exception) throw satisfied_promise_error{}; + state_->value = true; + } + + void set_exception(std::exception_ptr e) + { + if (!state_) throw empty_promise_error{}; + std::lock_guard lock{state_->value_mutex}; + if (state_->value || state_->exception) throw satisfied_promise_error{}; + state_->exception = std::move(e); + } + + future get_future() const + { + return future(state_); + } + + private: + std::shared_ptr> state_; + }; + + template + struct packaged_task; + + template + struct packaged_task + { + packaged_task() + : promise_(nullptr) + {} + + template + packaged_task(F && f) + : func_(std::forward(f)) + {} + + template + packaged_task(auto_cancel_tag tag, F && f) + : promise_(tag) + , func_(std::forward(f)) + {} + + explicit operator bool() const + { + return static_cast(promise_); + } + + future get_future() const + { + return promise_.get_future(); + } + + template + void operator() (Args1 && ... args) + { + try + { + if constexpr (std::is_same_v) + { + func_(std::forward(args)...); + promise_.set_value(); + } + else + { + promise_.set_value(func_(std::forward(args)...)); + } + } + catch(...) + { + promise_.set_exception(std::current_exception()); + } + } + + private: + promise promise_; + util::function func_; + }; + }