diff --git a/tokio/src/runtime/coop.rs b/tokio/src/runtime/coop.rs index aaca8b6baa2..c01e5e3de8b 100644 --- a/tokio/src/runtime/coop.rs +++ b/tokio/src/runtime/coop.rs @@ -135,8 +135,11 @@ cfg_rt! { } cfg_coop! { + use pin_project_lite::pin_project; use std::cell::Cell; - use std::task::{Context, Poll}; + use std::future::Future; + use std::pin::Pin; + use std::task::{ready, Context, Poll}; #[must_use] pub(crate) struct RestoreOnPending(Cell); @@ -240,6 +243,44 @@ cfg_coop! { self.0.is_none() } } + + pin_project! { + /// Future wrapper to ensure cooperative scheduling. + /// + /// When being polled `poll_proceed` is called before the inner future is polled to check + /// if the inner future has exceeded its budget. If the inner future resolves, this will + /// automatically call `RestoreOnPending::made_progress` before resolving this future with + /// the result of the inner one. If polling the inner future is pending, polling this future + /// type will also return a `Poll::Pending`. + #[must_use = "futures do nothing unless polled"] + pub(crate) struct Coop { + #[pin] + pub(crate) fut: F, + } + } + + impl Future for Coop { + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let coop = ready!(poll_proceed(cx)); + let me = self.project(); + if let Poll::Ready(ret) = me.fut.poll(cx) { + coop.made_progress(); + Poll::Ready(ret) + } else { + Poll::Pending + } + } + } + + /// Run a future with a budget constraint for cooperative scheduling. + /// If the future exceeds its budget while being polled, control is yielded back to the + /// runtime. + #[inline] + pub(crate) fn cooperative(fut: F) -> Coop { + Coop { fut } + } } #[cfg(all(test, not(loom)))] diff --git a/tokio/src/sync/watch.rs b/tokio/src/sync/watch.rs index 490b9e4df88..af72e30dd32 100644 --- a/tokio/src/sync/watch.rs +++ b/tokio/src/sync/watch.rs @@ -111,6 +111,7 @@ //! [`Sender::closed`]: crate::sync::watch::Sender::closed //! [`Sender::subscribe()`]: crate::sync::watch::Sender::subscribe +use crate::runtime::coop::cooperative; use crate::sync::notify::Notify; use crate::loom::sync::atomic::AtomicUsize; @@ -743,7 +744,7 @@ impl Receiver { /// } /// ``` pub async fn changed(&mut self) -> Result<(), error::RecvError> { - changed_impl(&self.shared, &mut self.version).await + cooperative(changed_impl(&self.shared, &mut self.version)).await } /// Waits for a value that satisfies the provided condition. @@ -807,6 +808,13 @@ impl Receiver { /// } /// ``` pub async fn wait_for( + &mut self, + f: impl FnMut(&T) -> bool, + ) -> Result, error::RecvError> { + cooperative(self.wait_for_inner(f)).await + } + + async fn wait_for_inner( &mut self, mut f: impl FnMut(&T) -> bool, ) -> Result, error::RecvError> { @@ -1224,19 +1232,22 @@ impl Sender { /// } /// ``` pub async fn closed(&self) { - crate::trace::async_trace_leaf().await; + cooperative(async { + crate::trace::async_trace_leaf().await; - while self.receiver_count() > 0 { - let notified = self.shared.notify_tx.notified(); + while self.receiver_count() > 0 { + let notified = self.shared.notify_tx.notified(); - if self.receiver_count() == 0 { - return; - } + if self.receiver_count() == 0 { + return; + } - notified.await; - // The channel could have been reopened in the meantime by calling - // `subscribe`, so we loop again. - } + notified.await; + // The channel could have been reopened in the meantime by calling + // `subscribe`, so we loop again. + } + }) + .await; } /// Creates a new [`Receiver`] connected to this `Sender`. diff --git a/tokio/tests/sync_watch.rs b/tokio/tests/sync_watch.rs index 17f0c81087a..4418f88e57b 100644 --- a/tokio/tests/sync_watch.rs +++ b/tokio/tests/sync_watch.rs @@ -368,3 +368,85 @@ async fn receiver_is_notified_when_last_sender_is_dropped() { assert!(t.is_woken()); } + +#[tokio::test] +async fn receiver_changed_is_cooperative() { + let (tx, mut rx) = watch::channel(()); + + drop(tx); + + tokio::select! { + biased; + _ = async { + loop { + assert!(rx.changed().await.is_err()); + } + } => {}, + _ = tokio::task::yield_now() => {}, + } +} + +#[tokio::test] +async fn receiver_changed_is_cooperative_ok() { + let (tx, mut rx) = watch::channel(()); + + tokio::select! { + biased; + _ = async { + loop { + assert!(tx.send(()).is_ok()); + assert!(rx.changed().await.is_ok()); + } + } => {}, + _ = tokio::task::yield_now() => {}, + } +} + +#[tokio::test] +async fn receiver_wait_for_is_cooperative() { + let (tx, mut rx) = watch::channel(0); + + drop(tx); + + tokio::select! { + biased; + _ = async { + loop { + assert!(rx.wait_for(|val| *val == 1).await.is_err()); + } + } => {}, + _ = tokio::task::yield_now() => {}, + } +} + +#[tokio::test] +async fn receiver_wait_for_is_cooperative_ok() { + let (tx, mut rx) = watch::channel(0); + + tokio::select! { + biased; + _ = async { + loop { + assert!(tx.send(1).is_ok()); + assert!(rx.wait_for(|val| *val == 1).await.is_ok()); + } + } => {}, + _ = tokio::task::yield_now() => {}, + } +} + +#[tokio::test] +async fn sender_closed_is_cooperative() { + let (tx, rx) = watch::channel(()); + + drop(rx); + + tokio::select! { + _ = async { + loop { + tx.closed().await; + } + } => {}, + _ = tokio::task::yield_now() => {}, + } +}