diff --git a/include/coro/condition_variable.hpp b/include/coro/condition_variable.hpp index d2db9a38..7ccbf95c 100644 --- a/include/coro/condition_variable.hpp +++ b/include/coro/condition_variable.hpp @@ -6,7 +6,7 @@ #include "coro/task.hpp" #include "coro/event.hpp" #include "coro/mutex.hpp" -#include "coro/when_any.hpp" +#include "coro/when_all.hpp" #include #include @@ -39,7 +39,7 @@ class condition_variable struct awaiter_base { - awaiter_base(coro::condition_variable& cv, coro::scoped_lock& l); + awaiter_base(coro::condition_variable& cv, coro::scoped_lock& l, coro::scoped_lock notify_lock); virtual ~awaiter_base() = default; awaiter_base(const awaiter_base&) = delete; @@ -55,6 +55,8 @@ class condition_variable coro::condition_variable& m_condition_variable; /// @brief The lock that the wait() was called with. coro::scoped_lock& m_lock; + /// @brief Lock used for mutual exclusion between notifiers and waiters + coro::scoped_lock m_notify_lock; /// @brief Each awaiter type defines its own notify behavior. /// @return The status of if the waiter's notify result. @@ -63,7 +65,7 @@ class condition_variable struct awaiter : public awaiter_base { - awaiter(coro::condition_variable& cv, coro::scoped_lock& l) noexcept; + awaiter(coro::condition_variable& cv, coro::scoped_lock& l, coro::scoped_lock notify_lock) noexcept; ~awaiter() override = default; awaiter(const awaiter&) = delete; @@ -83,6 +85,7 @@ class condition_variable awaiter_with_predicate( coro::condition_variable& cv, coro::scoped_lock& l, + coro::scoped_lock notify_lock, predicate_type p ) noexcept; ~awaiter_with_predicate() override = default; @@ -109,6 +112,7 @@ class condition_variable awaiter_with_predicate_stop_token( coro::condition_variable& cv, coro::scoped_lock& l, + coro::scoped_lock notify_lock, predicate_type p, std::stop_token stop_token ) noexcept; @@ -183,6 +187,7 @@ class condition_variable awaiter_with_wait_hook( coro::condition_variable& cv, coro::scoped_lock& l, + coro::scoped_lock notify_lock, controller_data& data ) noexcept; ~awaiter_with_wait_hook() override = default; @@ -199,11 +204,12 @@ class condition_variable std::unique_ptr& executor, coro::condition_variable& cv, coro::scoped_lock& l, + coro::scoped_lock notify_lock, const std::chrono::nanoseconds wait_for, std::optional predicate = std::nullopt, std::optional stop_token = std::nullopt ) noexcept - : awaiter_base(cv, l), + : awaiter_base(cv, l, std::move(notify_lock)), m_executor(executor), m_wait_for(wait_for), m_predicate(std::move(predicate)), @@ -277,8 +283,9 @@ class condition_variable { controller_data data{m_status, m_predicate_result, std::move(m_predicate), std::move(m_stop_token)}; // We enqueue the hook_task since we can make it live until the notify occurs and will properly resume the actual coroutine only once. - awaiter_with_wait_hook hook_task{m_condition_variable, m_lock, data}; + awaiter_with_wait_hook hook_task{m_condition_variable, m_lock, std::move(m_notify_lock), data}; detail::awaiter_list_push(m_condition_variable.m_awaiters, static_cast(&hook_task)); + hook_task.m_notify_lock.unlock(); m_lock.m_mutex->unlock(); // Unlock the actual lock now that we are setup, not the fake hook task. co_await coro::when_all(make_on_notify_callback_task(data), make_timeout_task(data)); @@ -378,16 +385,16 @@ class condition_variable /** * @brief Notifies all waiters and resumes them on the given executor. Note that each waiter must be notified synchronously so - * this is useful if the task is long lived and can be immediately parallelized after the condition is ready. This does not - * need to be co_await'ed like `notify_all()` since this will execute the notify on the given executor. + * this is useful if the task is long lived and can be immediately parallelized after the condition is ready. * * @tparam executor_type The type of executor that the waiters will be resumed on. * @param executor The executor that each waiter will be resumed on. - * @return void + * @return void coroutine to be awaited. */ template - auto notify_all(std::unique_ptr& executor) -> void + auto notify_all(std::unique_ptr& executor) -> coro::task { + co_await m_notify_mutex.scoped_lock(); auto* waiter = detail::awaiter_list_pop_all(m_awaiters); while (waiter != nullptr) @@ -400,7 +407,7 @@ class condition_variable waiter = next; } - return; + co_return; } @@ -408,23 +415,23 @@ class condition_variable * @brief Waits until notified. * * @param lock A lock that must be locked by the caller. - * @return awaiter + * @return void */ [[nodiscard]] auto wait( coro::scoped_lock& lock - ) -> awaiter; + ) -> coro::task; /** * @brief Waits until notified but only wakes up if the predicate passes. * * @param lock A lock that must be locked by the caller. * @param predicate The predicate to check whether the waiting can be completed. - * @return awaiter_with_predicate + * @return void */ [[nodiscard]] auto wait( coro::scoped_lock& lock, predicate_type predicate - ) -> awaiter_with_predicate; + ) -> coro::task; #ifndef EMSCRIPTEN /** @@ -433,13 +440,13 @@ class condition_variable * @param lock A lock which must be locked by the caller. * @param stop_token A stop token to register interruption for. * @param predicate The predicate to check whether the waiting can be completed. - * @return awaiter_with_predicate_stop_token The final predicate call result. + * @return bool The final predicate call result. */ [[nodiscard]] auto wait( coro::scoped_lock& lock, std::stop_token stop_token, predicate_type predicate - ) -> awaiter_with_predicate_stop_token; + ) -> coro::task; #endif #ifdef LIBCORO_FEATURE_NETWORKING @@ -449,9 +456,15 @@ class condition_variable std::unique_ptr& executor, coro::scoped_lock& lock, const std::chrono::duration wait_for - ) -> awaiter_with_wait + ) -> coro::task { - return awaiter_with_wait{executor, *this, lock, std::chrono::duration_cast(wait_for)}; + auto notify_lock = co_await m_notify_mutex.scoped_lock(); + co_return co_await awaiter_with_wait{ + executor, + *this, + lock, + std::move(notify_lock), + std::chrono::duration_cast(wait_for)}; } template @@ -460,9 +473,16 @@ class condition_variable coro::scoped_lock& lock, const std::chrono::duration wait_for, predicate_type predicate - ) -> awaiter_with_wait + ) -> coro::task { - return awaiter_with_wait{executor, *this, lock, std::chrono::duration_cast(wait_for), std::move(predicate)}; + auto notify_lock = co_await m_notify_mutex.scoped_lock(); + co_return co_await awaiter_with_wait{ + executor, + *this, + lock, + std::move(notify_lock), + std::chrono::duration_cast(wait_for), + std::move(predicate)}; } template @@ -472,9 +492,17 @@ class condition_variable std::stop_token stop_token, const std::chrono::duration wait_for, predicate_type predicate - ) -> awaiter_with_wait + ) -> coro::task { - return awaiter_with_wait{executor, *this, lock, std::chrono::duration_cast(wait_for), std::move(predicate), std::move(stop_token)}; + auto notify_lock = co_await m_notify_mutex.scoped_lock(); + co_return co_await awaiter_with_wait{ + executor, + *this, + lock, + std::move(notify_lock), + std::chrono::duration_cast(wait_for), + std::move(predicate), + std::move(stop_token)}; } template @@ -482,11 +510,17 @@ class condition_variable std::unique_ptr& executor, coro::scoped_lock& lock, const std::chrono::time_point wait_until_time - ) -> awaiter_with_wait + ) -> coro::task { auto now = std::chrono::time_point::clock::now(); auto wait_for = (now < wait_until_time) ? (wait_until_time - now) : std::chrono::nanoseconds{1}; - return awaiter_with_wait{executor, *this, lock, std::chrono::duration_cast(wait_for)}; + auto notify_lock = co_await m_notify_mutex.scoped_lock(); + co_return co_await awaiter_with_wait{ + executor, + *this, + lock, + std::move(notify_lock), + std::chrono::duration_cast(wait_for)}; } template @@ -495,11 +529,18 @@ class condition_variable coro::scoped_lock& lock, const std::chrono::time_point wait_until_time, predicate_type predicate - ) -> awaiter_with_wait + ) -> coro::task { auto now = std::chrono::time_point::clock::now(); auto wait_for = (now < wait_until_time) ? (wait_until_time - now) : std::chrono::nanoseconds{1}; - return awaiter_with_wait{executor, *this, lock, std::chrono::duration_cast(wait_for), std::move(predicate)}; + auto notify_lock = co_await m_notify_mutex.scoped_lock(); + co_return co_await awaiter_with_wait{ + executor, + *this, + lock, + std::move(notify_lock), + std::chrono::duration_cast(wait_for), + std::move(predicate)}; } template @@ -509,20 +550,31 @@ class condition_variable std::stop_token stop_token, const std::chrono::time_point wait_until_time, predicate_type predicate - ) -> awaiter_with_wait + ) -> coro::task { auto now = std::chrono::time_point::clock::now(); auto wait_for = (now < wait_until_time) ? (wait_until_time - now) : std::chrono::nanoseconds{1}; - return awaiter_with_wait{executor, *this, lock, std::chrono::duration_cast(wait_for), std::move(predicate), std::move(stop_token)}; + auto notify_lock = co_await m_notify_mutex.scoped_lock(); + co_return co_await awaiter_with_wait{ + executor, + *this, + lock, + std::move(notify_lock), + std::chrono::duration_cast(wait_for), + std::move(predicate), + std::move(stop_token)}; } #endif private: /// @brief The list of waiters. std::atomic m_awaiters{nullptr}; + /// @brief mutual exclusion for notification/arrival + coro::mutex m_notify_mutex; auto make_notify_all_executor_individual_task(awaiter_base* waiter) -> coro::task { + // Precondition: m_notify_mutex is held. switch (co_await waiter->on_notify()) { case notify_status_t::not_ready: diff --git a/src/condition_variable.cpp b/src/condition_variable.cpp index aa641d46..2f40bcc6 100644 --- a/src/condition_variable.cpp +++ b/src/condition_variable.cpp @@ -1,25 +1,20 @@ #include "coro/condition_variable.hpp" -#include "coro/sync_wait.hpp" namespace coro { condition_variable::awaiter_base::awaiter_base( - coro::condition_variable& cv, - coro::scoped_lock& l) + coro::condition_variable& cv, coro::scoped_lock& l, coro::scoped_lock notify_lock) : m_condition_variable(cv), - m_lock(l) + m_lock(l), + m_notify_lock(std::move(notify_lock)) { - } condition_variable::awaiter::awaiter( - coro::condition_variable& cv, - coro::scoped_lock& l -) noexcept - : awaiter_base(cv, l) + coro::condition_variable& cv, coro::scoped_lock& l, coro::scoped_lock notify_lock) noexcept + : awaiter_base(cv, l, std::move(notify_lock)) { - } auto condition_variable::awaiter::await_ready() const noexcept -> bool @@ -31,6 +26,7 @@ auto condition_variable::awaiter::await_suspend(std::coroutine_handle<> awaiting { m_awaiting_coroutine = awaiting_coroutine; coro::detail::awaiter_list_push(m_condition_variable.m_awaiters, static_cast(this)); + m_notify_lock.unlock(); m_lock.m_mutex->unlock(); return true; } @@ -44,23 +40,26 @@ auto condition_variable::awaiter::on_notify() -> coro::task bool { return m_predicate(); } -auto condition_variable::awaiter_with_predicate::await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool +auto condition_variable::awaiter_with_predicate::await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept + -> bool { m_awaiting_coroutine = awaiting_coroutine; coro::detail::awaiter_list_push(m_condition_variable.m_awaiters, static_cast(this)); + m_notify_lock.unlock(); m_lock.m_mutex->unlock(); return true; } @@ -81,16 +80,15 @@ auto condition_variable::awaiter_with_predicate::on_notify() -> coro::task bool @@ -99,15 +97,18 @@ auto condition_variable::awaiter_with_predicate_stop_token::await_ready() noexce return m_predicate_result; } -auto condition_variable::awaiter_with_predicate_stop_token::await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool +auto condition_variable::awaiter_with_predicate_stop_token::await_suspend( + std::coroutine_handle<> awaiting_coroutine) noexcept -> bool { m_awaiting_coroutine = awaiting_coroutine; coro::detail::awaiter_list_push(m_condition_variable.m_awaiters, static_cast(this)); + m_notify_lock.unlock(); m_lock.m_mutex->unlock(); return true; } -auto condition_variable::awaiter_with_predicate_stop_token::on_notify() -> coro::task +auto condition_variable::awaiter_with_predicate_stop_token::on_notify() + -> coro::task { co_await m_lock.m_mutex->lock(); m_predicate_result = m_predicate(); @@ -128,28 +129,25 @@ auto condition_variable::awaiter_with_predicate_stop_token::on_notify() -> coro: #ifdef LIBCORO_FEATURE_NETWORKING condition_variable::controller_data::controller_data( - std::optional& status, - bool& predicate_result, + std::optional& status, + bool& predicate_result, std::optional predicate, - std::optional stop_token -) noexcept + std::optional stop_token) noexcept : m_status(status), m_predicate_result(predicate_result), m_predicate(std::move(predicate)), m_stop_token({std::move(stop_token)}) { - } condition_variable::awaiter_with_wait_hook::awaiter_with_wait_hook( - coro::condition_variable& cv, - coro::scoped_lock& l, - condition_variable::controller_data& data -) noexcept - : awaiter_base(cv, l), + coro::condition_variable& cv, + coro::scoped_lock& l, + coro::scoped_lock notify_lock, + condition_variable::controller_data& data) noexcept + : awaiter_base(cv, l, std::move(notify_lock)), m_data(data) { - } auto condition_variable::awaiter_with_wait_hook::on_notify() -> coro::task @@ -199,6 +197,7 @@ auto condition_variable::awaiter_with_wait_hook::on_notify() -> coro::task coro::task { // The loop is here in case there are *dead* awaiter_hook_tasks that need to be skipped. + co_await m_notify_mutex.scoped_lock(); while (true) { auto* waiter = detail::awaiter_list_pop(m_awaiters); @@ -225,6 +224,7 @@ auto condition_variable::notify_one() -> coro::task auto condition_variable::notify_all() -> coro::task { + co_await m_notify_mutex.scoped_lock(); auto* waiter = detail::awaiter_list_pop_all(m_awaiters); while (waiter != nullptr) @@ -235,7 +235,7 @@ auto condition_variable::notify_all() -> coro::task switch (co_await waiter->on_notify()) { case notify_status_t::not_ready: - // Re-enqueue since the predicate isn't ready and return since the notify has been satisfied. + // Re-enqueue since the predicate isn't ready. coro::detail::awaiter_list_push(m_awaiters, waiter); break; case notify_status_t::ready: @@ -250,28 +250,27 @@ auto condition_variable::notify_all() -> coro::task co_return; } -auto condition_variable::wait(coro::scoped_lock& lock) -> awaiter +auto condition_variable::wait(coro::scoped_lock& lock) -> coro::task { - return awaiter{*this, lock}; + auto notify_lock = co_await m_notify_mutex.scoped_lock(); + co_return co_await awaiter{*this, lock, std::move(notify_lock)}; } -auto condition_variable::wait( - coro::scoped_lock& lock, - condition_variable::predicate_type predicate -) -> awaiter_with_predicate +auto condition_variable::wait(coro::scoped_lock& lock, condition_variable::predicate_type predicate) -> coro::task { - return awaiter_with_predicate{*this, lock, std::move(predicate)}; + auto notify_lock = co_await m_notify_mutex.scoped_lock(); + co_return co_await awaiter_with_predicate{*this, lock, std::move(notify_lock), std::move(predicate)}; } #ifndef EMSCRIPTEN auto condition_variable::wait( - coro::scoped_lock& lock, - std::stop_token stop_token, - condition_variable::predicate_type predicate -) -> awaiter_with_predicate_stop_token + coro::scoped_lock& lock, std::stop_token stop_token, condition_variable::predicate_type predicate) + -> coro::task { - return awaiter_with_predicate_stop_token{*this, lock, std::move(predicate), std::move(stop_token)}; + auto notify_lock = co_await m_notify_mutex.scoped_lock(); + co_return co_await awaiter_with_predicate_stop_token{ + *this, lock, std::move(notify_lock), std::move(predicate), std::move(stop_token)}; } #endif diff --git a/test/test_condition_variable.cpp b/test/test_condition_variable.cpp index 422b934e..dd5e4f51 100644 --- a/test/test_condition_variable.cpp +++ b/test/test_condition_variable.cpp @@ -924,7 +924,7 @@ TEST_CASE("notify_all(executor)", "[condition_variable]") } std::cerr << "notify_all(s)\n"; - cv.notify_all(s); + co_await cv.notify_all(s); co_return 0; };