Skip to content

Commit

Permalink
sync: apply cooperative scheduling to sync::watch (tokio-rs#6846)
Browse files Browse the repository at this point in the history
  • Loading branch information
tglane authored Sep 26, 2024
1 parent 623928e commit c8af499
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 12 deletions.
43 changes: 42 additions & 1 deletion tokio/src/runtime/coop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,11 @@ cfg_rt! {
}

cfg_coop! {
use pin_project_lite::pin_project;
use std::cell::Cell;
use std::task::{Context, Poll};
use std::future::Future;
use std::pin::Pin;
use std::task::{ready, Context, Poll};

#[must_use]
pub(crate) struct RestoreOnPending(Cell<Budget>);
Expand Down Expand Up @@ -240,6 +243,44 @@ cfg_coop! {
self.0.is_none()
}
}

pin_project! {
/// Future wrapper to ensure cooperative scheduling.
///
/// When being polled `poll_proceed` is called before the inner future is polled to check
/// if the inner future has exceeded its budget. If the inner future resolves, this will
/// automatically call `RestoreOnPending::made_progress` before resolving this future with
/// the result of the inner one. If polling the inner future is pending, polling this future
/// type will also return a `Poll::Pending`.
#[must_use = "futures do nothing unless polled"]
pub(crate) struct Coop<F: Future> {
#[pin]
pub(crate) fut: F,
}
}

impl<F: Future> Future for Coop<F> {
type Output = F::Output;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let coop = ready!(poll_proceed(cx));
let me = self.project();
if let Poll::Ready(ret) = me.fut.poll(cx) {
coop.made_progress();
Poll::Ready(ret)
} else {
Poll::Pending
}
}
}

/// Run a future with a budget constraint for cooperative scheduling.
/// If the future exceeds its budget while being polled, control is yielded back to the
/// runtime.
#[inline]
pub(crate) fn cooperative<F: Future>(fut: F) -> Coop<F> {
Coop { fut }
}
}

#[cfg(all(test, not(loom)))]
Expand Down
33 changes: 22 additions & 11 deletions tokio/src/sync/watch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@
//! [`Sender::closed`]: crate::sync::watch::Sender::closed
//! [`Sender::subscribe()`]: crate::sync::watch::Sender::subscribe

use crate::runtime::coop::cooperative;
use crate::sync::notify::Notify;

use crate::loom::sync::atomic::AtomicUsize;
Expand Down Expand Up @@ -743,7 +744,7 @@ impl<T> Receiver<T> {
/// }
/// ```
pub async fn changed(&mut self) -> Result<(), error::RecvError> {
changed_impl(&self.shared, &mut self.version).await
cooperative(changed_impl(&self.shared, &mut self.version)).await
}

/// Waits for a value that satisfies the provided condition.
Expand Down Expand Up @@ -807,6 +808,13 @@ impl<T> Receiver<T> {
/// }
/// ```
pub async fn wait_for(
&mut self,
f: impl FnMut(&T) -> bool,
) -> Result<Ref<'_, T>, error::RecvError> {
cooperative(self.wait_for_inner(f)).await
}

async fn wait_for_inner(
&mut self,
mut f: impl FnMut(&T) -> bool,
) -> Result<Ref<'_, T>, error::RecvError> {
Expand Down Expand Up @@ -1224,19 +1232,22 @@ impl<T> Sender<T> {
/// }
/// ```
pub async fn closed(&self) {
crate::trace::async_trace_leaf().await;
cooperative(async {
crate::trace::async_trace_leaf().await;

while self.receiver_count() > 0 {
let notified = self.shared.notify_tx.notified();
while self.receiver_count() > 0 {
let notified = self.shared.notify_tx.notified();

if self.receiver_count() == 0 {
return;
}
if self.receiver_count() == 0 {
return;
}

notified.await;
// The channel could have been reopened in the meantime by calling
// `subscribe`, so we loop again.
}
notified.await;
// The channel could have been reopened in the meantime by calling
// `subscribe`, so we loop again.
}
})
.await;
}

/// Creates a new [`Receiver`] connected to this `Sender`.
Expand Down
82 changes: 82 additions & 0 deletions tokio/tests/sync_watch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -368,3 +368,85 @@ async fn receiver_is_notified_when_last_sender_is_dropped() {

assert!(t.is_woken());
}

#[tokio::test]
async fn receiver_changed_is_cooperative() {
let (tx, mut rx) = watch::channel(());

drop(tx);

tokio::select! {
biased;
_ = async {
loop {
assert!(rx.changed().await.is_err());
}
} => {},
_ = tokio::task::yield_now() => {},
}
}

#[tokio::test]
async fn receiver_changed_is_cooperative_ok() {
let (tx, mut rx) = watch::channel(());

tokio::select! {
biased;
_ = async {
loop {
assert!(tx.send(()).is_ok());
assert!(rx.changed().await.is_ok());
}
} => {},
_ = tokio::task::yield_now() => {},
}
}

#[tokio::test]
async fn receiver_wait_for_is_cooperative() {
let (tx, mut rx) = watch::channel(0);

drop(tx);

tokio::select! {
biased;
_ = async {
loop {
assert!(rx.wait_for(|val| *val == 1).await.is_err());
}
} => {},
_ = tokio::task::yield_now() => {},
}
}

#[tokio::test]
async fn receiver_wait_for_is_cooperative_ok() {
let (tx, mut rx) = watch::channel(0);

tokio::select! {
biased;
_ = async {
loop {
assert!(tx.send(1).is_ok());
assert!(rx.wait_for(|val| *val == 1).await.is_ok());
}
} => {},
_ = tokio::task::yield_now() => {},
}
}

#[tokio::test]
async fn sender_closed_is_cooperative() {
let (tx, rx) = watch::channel(());

drop(rx);

tokio::select! {
_ = async {
loop {
tx.closed().await;
}
} => {},
_ = tokio::task::yield_now() => {},
}
}

0 comments on commit c8af499

Please sign in to comment.