diff --git a/tokio/src/sync/mod.rs b/tokio/src/sync/mod.rs index 3a6c8b8270c..7b9fa517560 100644 --- a/tokio/src/sync/mod.rs +++ b/tokio/src/sync/mod.rs @@ -449,7 +449,7 @@ cfg_sync! { /// Named future types. pub mod futures { - pub use super::notify::Notified; + pub use super::notify::{Notified, OwnedNotified}; } mod barrier; diff --git a/tokio/src/sync/notify.rs b/tokio/src/sync/notify.rs index dbdb9b15609..d460797936d 100644 --- a/tokio/src/sync/notify.rs +++ b/tokio/src/sync/notify.rs @@ -17,6 +17,7 @@ use std::panic::{RefUnwindSafe, UnwindSafe}; use std::pin::Pin; use std::ptr::NonNull; use std::sync::atomic::Ordering::{self, Acquire, Relaxed, Release, SeqCst}; +use std::sync::Arc; use std::task::{Context, Poll, Waker}; type WaitList = LinkedList::Target>; @@ -397,6 +398,38 @@ pub struct Notified<'a> { unsafe impl<'a> Send for Notified<'a> {} unsafe impl<'a> Sync for Notified<'a> {} +/// Future returned from [`Notify::notified_owned()`]. +/// +/// This future is fused, so once it has completed, any future calls to poll +/// will immediately return `Poll::Ready`. +#[derive(Debug)] +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct OwnedNotified { + /// The `Notify` being received on. + notify: Arc, + + /// The current state of the receiving process. + state: State, + + /// Number of calls to `notify_waiters` at the time of creation. + notify_waiters_calls: usize, + + /// Entry in the waiter `LinkedList`. + waiter: Waiter, +} + +unsafe impl Sync for OwnedNotified {} + +/// A custom `project` implementation is used in place of `pin-project-lite` +/// as a custom drop for [`Notified`] and [`OwnedNotified`] implementation +/// is needed. +struct NotifiedProject<'a> { + notify: &'a Notify, + state: &'a mut State, + notify_waiters_calls: &'a usize, + waiter: &'a Waiter, +} + #[derive(Debug)] enum State { Init, @@ -541,6 +574,53 @@ impl Notify { } } + /// Wait for a notification with an owned `Future`. + /// + /// Unlike [`Self::notified`] which returns a future tied to the `Notify`'s + /// lifetime, `notified_owned` creates a self-contained future that owns its + /// notification state, making it safe to move between threads. + /// + /// See [`Self::notified`] for more details. + /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute notifications in the order + /// they were requested. Cancelling a call to `notified_owned` makes you lose your + /// place in the queue. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::Notify; + /// + /// #[tokio::main] + /// async fn main() { + /// let notify = Arc::new(Notify::new()); + /// + /// for _ in 0..10 { + /// let notified = notify.clone().notified_owned(); + /// tokio::spawn(async move { + /// notified.await; + /// println!("received notification"); + /// }); + /// } + /// + /// println!("sending notification"); + /// notify.notify_waiters(); + /// } + /// ``` + pub fn notified_owned(self: Arc) -> OwnedNotified { + // we load the number of times notify_waiters + // was called and store that in the future. + let state = self.state.load(SeqCst); + OwnedNotified { + notify: self, + state: State::Init, + notify_waiters_calls: get_num_notify_waiters_calls(state), + waiter: Waiter::new(), + } + } /// Notifies the first waiting task. /// /// If a task is currently waiting, that task is notified. Otherwise, a @@ -911,9 +991,62 @@ impl Notified<'_> { self.poll_notified(None).is_ready() } + fn project(self: Pin<&mut Self>) -> NotifiedProject<'_> { + unsafe { + // Safety: `notify`, `state` and `notify_waiters_calls` are `Unpin`. + + is_unpin::<&Notify>(); + is_unpin::(); + is_unpin::(); + + let me = self.get_unchecked_mut(); + NotifiedProject { + notify: me.notify, + state: &mut me.state, + notify_waiters_calls: &me.notify_waiters_calls, + waiter: &me.waiter, + } + } + } + + fn poll_notified(self: Pin<&mut Self>, waker: Option<&Waker>) -> Poll<()> { + self.project().poll_notified(waker) + } +} + +impl Future for Notified<'_> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + self.poll_notified(Some(cx.waker())) + } +} + +impl Drop for Notified<'_> { + fn drop(&mut self) { + // Safety: The type only transitions to a "Waiting" state when pinned. + unsafe { Pin::new_unchecked(self) } + .project() + .drop_notified(); + } +} + +// ===== impl OwnedNotified ===== + +impl OwnedNotified { + /// Adds this future to the list of futures that are ready to receive + /// wakeups from calls to [`notify_one`]. + /// + /// See [`Notified::enable`] for more details. + /// + /// [`notify_one`]: Notify::notify_one() + pub fn enable(self: Pin<&mut Self>) -> bool { + self.poll_notified(None).is_ready() + } + /// A custom `project` implementation is used in place of `pin-project-lite` /// as a custom drop implementation is needed. - fn project(self: Pin<&mut Self>) -> (&Notify, &mut State, &usize, &Waiter) { + fn project(self: Pin<&mut Self>) -> NotifiedProject<'_> { unsafe { // Safety: `notify`, `state` and `notify_waiters_calls` are `Unpin`. @@ -922,17 +1055,47 @@ impl Notified<'_> { is_unpin::(); let me = self.get_unchecked_mut(); - ( - me.notify, - &mut me.state, - &me.notify_waiters_calls, - &me.waiter, - ) + NotifiedProject { + notify: &me.notify, + state: &mut me.state, + notify_waiters_calls: &me.notify_waiters_calls, + waiter: &me.waiter, + } } } fn poll_notified(self: Pin<&mut Self>, waker: Option<&Waker>) -> Poll<()> { - let (notify, state, notify_waiters_calls, waiter) = self.project(); + self.project().poll_notified(waker) + } +} + +impl Future for OwnedNotified { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + self.poll_notified(Some(cx.waker())) + } +} + +impl Drop for OwnedNotified { + fn drop(&mut self) { + // Safety: The type only transitions to a "Waiting" state when pinned. + unsafe { Pin::new_unchecked(self) } + .project() + .drop_notified(); + } +} + +// ===== impl NotifiedProject ===== + +impl NotifiedProject<'_> { + fn poll_notified(self, waker: Option<&Waker>) -> Poll<()> { + let NotifiedProject { + notify, + state, + notify_waiters_calls, + waiter, + } = self; 'outer_loop: loop { match *state { @@ -1143,20 +1306,14 @@ impl Notified<'_> { } } } -} - -impl Future for Notified<'_> { - type Output = (); - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - self.poll_notified(Some(cx.waker())) - } -} - -impl Drop for Notified<'_> { - fn drop(&mut self) { - // Safety: The type only transitions to a "Waiting" state when pinned. - let (notify, state, _, waiter) = unsafe { Pin::new_unchecked(self).project() }; + fn drop_notified(self) { + let NotifiedProject { + notify, + state, + waiter, + .. + } = self; // This is where we ensure safety. The `Notified` value is being // dropped, which means we must ensure that the waiter entry is no diff --git a/tokio/tests/async_send_sync.rs b/tokio/tests/async_send_sync.rs index c9cedc38b02..5585215546d 100644 --- a/tokio/tests/async_send_sync.rs +++ b/tokio/tests/async_send_sync.rs @@ -398,6 +398,7 @@ assert_value!(tokio::sync::broadcast::WeakSender: !Send & !Sync & Unpin); assert_value!(tokio::sync::broadcast::WeakSender: Send & Sync & Unpin); assert_value!(tokio::sync::broadcast::WeakSender: Send & Sync & Unpin); assert_value!(tokio::sync::futures::Notified<'_>: Send & Sync & !Unpin); +assert_value!(tokio::sync::futures::OwnedNotified: Send & Sync & !Unpin); assert_value!(tokio::sync::mpsc::OwnedPermit: !Send & !Sync & Unpin); assert_value!(tokio::sync::mpsc::OwnedPermit: Send & Sync & Unpin); assert_value!(tokio::sync::mpsc::OwnedPermit: Send & Sync & Unpin); diff --git a/tokio/tests/sync_notify_owned.rs b/tokio/tests/sync_notify_owned.rs new file mode 100644 index 00000000000..06a0f6ade57 --- /dev/null +++ b/tokio/tests/sync_notify_owned.rs @@ -0,0 +1,304 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "sync")] + +#[cfg(all(target_family = "wasm", not(target_os = "wasi")))] +use wasm_bindgen_test::wasm_bindgen_test as test; + +use std::sync::Arc; +use tokio::sync::Notify; +use tokio_test::task::spawn; +use tokio_test::*; + +#[allow(unused)] +trait AssertSend: Send + Sync {} +impl AssertSend for Notify {} + +#[test] +fn notify_notified_one() { + let notify = Arc::new(Notify::new()); + let mut notified = spawn(async { notify.clone().notified_owned().await }); + + notify.notify_one(); + assert_ready!(notified.poll()); +} + +#[test] +fn notify_multi_notified_one() { + let notify = Arc::new(Notify::new()); + let mut notified1 = spawn(async { notify.clone().notified_owned().await }); + let mut notified2 = spawn(async { notify.clone().notified_owned().await }); + + // add two waiters into the queue + assert_pending!(notified1.poll()); + assert_pending!(notified2.poll()); + + // should wakeup the first one + notify.notify_one(); + assert_ready!(notified1.poll()); + assert_pending!(notified2.poll()); +} + +#[test] +fn notify_multi_notified_last() { + let notify = Arc::new(Notify::new()); + let mut notified1 = spawn(async { notify.clone().notified_owned().await }); + let mut notified2 = spawn(async { notify.clone().notified_owned().await }); + + // add two waiters into the queue + assert_pending!(notified1.poll()); + assert_pending!(notified2.poll()); + + // should wakeup the last one + notify.notify_last(); + assert_pending!(notified1.poll()); + assert_ready!(notified2.poll()); +} + +#[test] +fn notified_one_notify() { + let notify = Arc::new(Notify::new()); + let mut notified = spawn(async { notify.clone().notified_owned().await }); + + assert_pending!(notified.poll()); + + notify.notify_one(); + assert!(notified.is_woken()); + assert_ready!(notified.poll()); +} + +#[test] +fn notified_multi_notify() { + let notify = Arc::new(Notify::new()); + let mut notified1 = spawn(async { notify.clone().notified_owned().await }); + let mut notified2 = spawn(async { notify.clone().notified_owned().await }); + + assert_pending!(notified1.poll()); + assert_pending!(notified2.poll()); + + notify.notify_one(); + assert!(notified1.is_woken()); + assert!(!notified2.is_woken()); + + assert_ready!(notified1.poll()); + assert_pending!(notified2.poll()); +} + +#[test] +fn notify_notified_multi() { + let notify = Arc::new(Notify::new()); + + notify.notify_one(); + + let mut notified1 = spawn(async { notify.clone().notified_owned().await }); + let mut notified2 = spawn(async { notify.clone().notified_owned().await }); + + assert_ready!(notified1.poll()); + assert_pending!(notified2.poll()); + + notify.notify_one(); + + assert!(notified2.is_woken()); + assert_ready!(notified2.poll()); +} + +#[test] +fn notified_drop_notified_notify() { + let notify = Arc::new(Notify::new()); + let mut notified1 = spawn(async { notify.clone().notified_owned().await }); + let mut notified2 = spawn(async { notify.clone().notified_owned().await }); + + assert_pending!(notified1.poll()); + + drop(notified1); + + assert_pending!(notified2.poll()); + + notify.notify_one(); + assert!(notified2.is_woken()); + assert_ready!(notified2.poll()); +} + +#[test] +fn notified_multi_notify_drop_one() { + let notify = Arc::new(Notify::new()); + let mut notified1 = spawn(async { notify.clone().notified_owned().await }); + let mut notified2 = spawn(async { notify.clone().notified_owned().await }); + + assert_pending!(notified1.poll()); + assert_pending!(notified2.poll()); + + notify.notify_one(); + + assert!(notified1.is_woken()); + assert!(!notified2.is_woken()); + + drop(notified1); + + assert!(notified2.is_woken()); + assert_ready!(notified2.poll()); +} + +#[test] +fn notified_multi_notify_one_drop() { + let notify = Arc::new(Notify::new()); + let mut notified1 = spawn(async { notify.clone().notified_owned().await }); + let mut notified2 = spawn(async { notify.clone().notified_owned().await }); + let mut notified3 = spawn(async { notify.clone().notified_owned().await }); + + // add waiters by order of poll execution + assert_pending!(notified1.poll()); + assert_pending!(notified2.poll()); + assert_pending!(notified3.poll()); + + // by default fifo + notify.notify_one(); + + drop(notified1); + + // next waiter should be the one to be to woken up + assert_ready!(notified2.poll()); + assert_pending!(notified3.poll()); +} + +#[test] +fn notified_multi_notify_last_drop() { + let notify = Arc::new(Notify::new()); + let mut notified1 = spawn(async { notify.clone().notified_owned().await }); + let mut notified2 = spawn(async { notify.clone().notified_owned().await }); + let mut notified3 = spawn(async { notify.clone().notified_owned().await }); + + // add waiters by order of poll execution + assert_pending!(notified1.poll()); + assert_pending!(notified2.poll()); + assert_pending!(notified3.poll()); + + notify.notify_last(); + + drop(notified3); + + // latest waiter added should be the one to woken up + assert_ready!(notified2.poll()); + assert_pending!(notified1.poll()); +} + +#[test] +fn notify_in_drop_after_wake() { + use futures::task::ArcWake; + use std::future::Future; + use std::sync::Arc; + + let notify = Arc::new(Notify::new()); + + struct NotifyOnDrop(Arc); + + impl ArcWake for NotifyOnDrop { + fn wake_by_ref(_arc_self: &Arc) {} + } + + impl Drop for NotifyOnDrop { + fn drop(&mut self) { + self.0.notify_waiters(); + } + } + + let mut fut = Box::pin(async { + notify.clone().notified_owned().await; + }); + + { + let waker = futures::task::waker(Arc::new(NotifyOnDrop(notify.clone()))); + let mut cx = std::task::Context::from_waker(&waker); + assert!(fut.as_mut().poll(&mut cx).is_pending()); + } + + // Now, notifying **should not** deadlock + notify.notify_waiters(); +} + +#[test] +fn notify_one_after_dropped_all() { + let notify = Arc::new(Notify::new()); + let mut notified1 = spawn(async { notify.clone().notified_owned().await }); + + assert_pending!(notified1.poll()); + + notify.notify_waiters(); + notify.notify_one(); + + drop(notified1); + + let mut notified2 = spawn(async { notify.clone().notified_owned().await }); + + assert_ready!(notified2.poll()); +} + +#[test] +fn test_notify_one_not_enabled() { + let notify = Arc::new(Notify::new()); + let mut future = spawn(notify.clone().notified_owned()); + + notify.notify_one(); + assert_ready!(future.poll()); +} + +#[test] +fn test_notify_one_after_enable() { + let notify = Arc::new(Notify::new()); + let mut future = spawn(notify.clone().notified_owned()); + + future.enter(|_, fut| assert!(!fut.enable())); + + notify.notify_one(); + assert_ready!(future.poll()); + future.enter(|_, fut| assert!(fut.enable())); +} + +#[test] +fn test_poll_after_enable() { + let notify = Arc::new(Notify::new()); + let mut future = spawn(notify.clone().notified_owned()); + + future.enter(|_, fut| assert!(!fut.enable())); + assert_pending!(future.poll()); +} + +#[test] +fn test_enable_after_poll() { + let notify = Arc::new(Notify::new()); + let mut future = spawn(notify.clone().notified_owned()); + + assert_pending!(future.poll()); + future.enter(|_, fut| assert!(!fut.enable())); +} + +#[test] +fn test_enable_consumes_permit() { + let notify = Arc::new(Notify::new()); + + // Add a permit. + notify.notify_one(); + + let mut future1 = spawn(notify.clone().notified_owned()); + future1.enter(|_, fut| assert!(fut.enable())); + + let mut future2 = spawn(notify.clone().notified_owned()); + future2.enter(|_, fut| assert!(!fut.enable())); +} + +#[test] +fn test_waker_update() { + use futures::task::noop_waker; + use std::future::Future; + use std::task::Context; + + let notify = Arc::new(Notify::new()); + let mut future = spawn(notify.clone().notified_owned()); + + let noop = noop_waker(); + future.enter(|_, fut| assert_pending!(fut.poll(&mut Context::from_waker(&noop)))); + + assert_pending!(future.poll()); + notify.notify_one(); + + assert!(future.is_woken()); +}