diff --git a/benches/Cargo.toml b/benches/Cargo.toml
index 1eea2e04489..c581055cf65 100644
--- a/benches/Cargo.toml
+++ b/benches/Cargo.toml
@@ -26,6 +26,11 @@ name = "spawn"
 path = "spawn.rs"
 harness = false
 
+[[bench]]
+name = "sync_broadcast"
+path = "sync_broadcast.rs"
+harness = false
+
 [[bench]]
 name = "sync_mpsc"
 path = "sync_mpsc.rs"
diff --git a/benches/sync_broadcast.rs b/benches/sync_broadcast.rs
new file mode 100644
index 00000000000..38a2141387b
--- /dev/null
+++ b/benches/sync_broadcast.rs
@@ -0,0 +1,82 @@
+use rand::{Rng, RngCore, SeedableRng};
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::Arc;
+use tokio::sync::{broadcast, Notify};
+
+use criterion::measurement::WallTime;
+use criterion::{black_box, criterion_group, criterion_main, BenchmarkGroup, Criterion};
+
+fn rt() -> tokio::runtime::Runtime {
+    tokio::runtime::Builder::new_multi_thread()
+        .worker_threads(6)
+        .build()
+        .unwrap()
+}
+
+fn do_work(rng: &mut impl RngCore) -> u32 {
+    use std::fmt::Write;
+    let mut message = String::new();
+    for i in 1..=10 {
+        let _ = write!(&mut message, " {i}={}", rng.gen::<f64>());
+    }
+    message
+        .as_bytes()
+        .iter()
+        .map(|&c| c as u32)
+        .fold(0, u32::wrapping_add)
+}
+
+fn contention_impl<const N_TASKS: usize>(g: &mut BenchmarkGroup<WallTime>) {
+    let rt = rt();
+
+    let (tx, _rx) = broadcast::channel::<usize>(1000);
+    let wg = Arc::new((AtomicUsize::new(0), Notify::new()));
+
+    for n in 0..N_TASKS {
+        let wg = wg.clone();
+        let mut rx = tx.subscribe();
+        let mut rng = rand::rngs::StdRng::seed_from_u64(n as u64);
+        rt.spawn(async move {
+            while let Ok(_) = rx.recv().await {
+                let r = do_work(&mut rng);
+                let _ = black_box(r);
+                if wg.0.fetch_sub(1, Ordering::Relaxed) == 1 {
+                    wg.1.notify_one();
+                }
+            }
+        });
+    }
+
+    const N_ITERS: usize = 100;
+
+    g.bench_function(N_TASKS.to_string(), |b| {
+        b.iter(|| {
+            rt.block_on({
+                let wg = wg.clone();
+                let tx = tx.clone();
+                async move {
+                    for i in 0..N_ITERS {
+                        assert_eq!(wg.0.fetch_add(N_TASKS, Ordering::Relaxed), 0);
+                        tx.send(i).unwrap();
+                        while wg.0.load(Ordering::Relaxed) > 0 {
+                            wg.1.notified().await;
+                        }
+                    }
+                }
+            })
+        })
+    });
+}
+
+fn bench_contention(c: &mut Criterion) {
+    let mut group = c.benchmark_group("contention");
+    contention_impl::<10>(&mut group);
+    contention_impl::<100>(&mut group);
+    contention_impl::<500>(&mut group);
+    contention_impl::<1000>(&mut group);
+    group.finish();
+}
+
+criterion_group!(contention, bench_contention);
+
+criterion_main!(contention);
diff --git a/tokio/src/sync/broadcast.rs b/tokio/src/sync/broadcast.rs
index 568a50bd59b..499e5296da4 100644
--- a/tokio/src/sync/broadcast.rs
+++ b/tokio/src/sync/broadcast.rs
@@ -117,7 +117,7 @@
 //! ```
 
 use crate::loom::cell::UnsafeCell;
-use crate::loom::sync::atomic::AtomicUsize;
+use crate::loom::sync::atomic::{AtomicBool, AtomicUsize};
 use crate::loom::sync::{Arc, Mutex, MutexGuard, RwLock, RwLockReadGuard};
 use crate::util::linked_list::{self, GuardedLinkedList, LinkedList};
 use crate::util::WakeList;
@@ -127,7 +127,7 @@
 use std::future::Future;
 use std::marker::PhantomPinned;
 use std::pin::Pin;
 use std::ptr::NonNull;
-use std::sync::atomic::Ordering::SeqCst;
+use std::sync::atomic::Ordering::{Acquire, Relaxed, Release, SeqCst};
 use std::task::{Context, Poll, Waker};
 use std::usize;
@@ -354,7 +354,7 @@ struct Slot<T> {
 /// An entry in the wait queue.
 struct Waiter {
     /// True if queued.
-    queued: bool,
+    queued: AtomicBool,
 
     /// Task waiting on the broadcast channel.
     waker: Option<Waker>,
@@ -369,7 +369,7 @@ struct Waiter {
 impl Waiter {
     fn new() -> Self {
         Self {
-            queued: false,
+            queued: AtomicBool::new(false),
             waker: None,
             pointers: linked_list::Pointers::new(),
             _p: PhantomPinned,
@@ -897,15 +897,22 @@ impl<T> Shared<T> {
         'outer: loop {
             while wakers.can_push() {
                 match list.pop_back_locked(&mut tail) {
-                    Some(mut waiter) => {
-                        // Safety: `tail` lock is still held.
-                        let waiter = unsafe { waiter.as_mut() };
-
-                        assert!(waiter.queued);
-                        waiter.queued = false;
-
-                        if let Some(waker) = waiter.waker.take() {
-                            wakers.push(waker);
+                    Some(waiter) => {
+                        unsafe {
+                            // Safety: accessing `waker` is safe because
+                            // the tail lock is held.
+                            if let Some(waker) = (*waiter.as_ptr()).waker.take() {
+                                wakers.push(waker);
+                            }
+
+                            // Safety: `queued` is atomic.
+                            let queued = &(*waiter.as_ptr()).queued;
+                            // `Relaxed` suffices because the tail lock is held.
+                            assert!(queued.load(Relaxed));
+                            // `Release` is needed to synchronize with `Recv::drop`.
+                            // It is critical to set this variable **after** waker
+                            // is extracted, otherwise we may data race with `Recv::drop`.
+                            queued.store(false, Release);
                         }
                     }
                     None => {
@@ -1104,8 +1111,13 @@ impl<T> Receiver<T> {
                                 }
                             }
 
-                            if !(*ptr).queued {
-                                (*ptr).queued = true;
+                            // If the waiter is not already queued, enqueue it.
+                            // `Relaxed` order suffices: we have synchronized with
+                            // all writers through the tail lock that we hold.
+                            if !(*ptr).queued.load(Relaxed) {
+                                // `Relaxed` order suffices: all the readers will
+                                // synchronize with this write through the tail lock.
+                                (*ptr).queued.store(true, Relaxed);
                                 tail.waiters.push_front(NonNull::new_unchecked(&mut *ptr));
                             }
                         });
@@ -1357,7 +1369,7 @@ impl<'a, T> Recv<'a, T> {
         Recv {
             receiver,
             waiter: UnsafeCell::new(Waiter {
-                queued: false,
+                queued: AtomicBool::new(false),
                 waker: None,
                 pointers: linked_list::Pointers::new(),
                 _p: PhantomPinned,
@@ -1402,22 +1414,37 @@ where
 
 impl<'a, T> Drop for Recv<'a, T> {
     fn drop(&mut self) {
-        // Acquire the tail lock. This is required for safety before accessing
-        // the waiter node.
-        let mut tail = self.receiver.shared.tail.lock();
-
-        // safety: tail lock is held
-        let queued = self.waiter.with(|ptr| unsafe { (*ptr).queued });
-
+        // Safety: `waiter.queued` is atomic.
+        // Acquire ordering is required to synchronize with
+        // `Shared::notify_rx` before we drop the object.
+        let queued = self
+            .waiter
+            .with(|ptr| unsafe { (*ptr).queued.load(Acquire) });
+
+        // If the waiter is queued, we need to unlink it from the waiters list.
+        // If not, no further synchronization is required, since the waiter
+        // is not in the list and, as such, is not shared with any other threads.
         if queued {
-            // Remove the node
-            //
-            // safety: tail lock is held and the wait node is verified to be in
-            // the list.
-            unsafe {
-                self.waiter.with_mut(|ptr| {
-                    tail.waiters.remove((&mut *ptr).into());
-                });
+            // Acquire the tail lock. This is required for safety before accessing
+            // the waiter node.
+            let mut tail = self.receiver.shared.tail.lock();
+
+            // Safety: tail lock is held.
+            // `Relaxed` order suffices because we hold the tail lock.
+            let queued = self
+                .waiter
+                .with_mut(|ptr| unsafe { (*ptr).queued.load(Relaxed) });
+
+            if queued {
+                // Remove the node
+                //
+                // safety: tail lock is held and the wait node is verified to be in
+                // the list.
+                unsafe {
+                    self.waiter.with_mut(|ptr| {
+                        tail.waiters.remove((&mut *ptr).into());
+                    });
+                }
             }
         }
     }
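Note (not part of the patch): the comments in `Shared::notify_rx` and `Recv::drop` above describe an acquire/release handshake on `Waiter::queued`. Below is a standalone, simplified sketch of that contract using hypothetical stand-in types (`Waiter`, `notify`, and `drop_fast_path` here are not tokio's; the `Mutex` stands in for the tail lock): the notifier takes the waker first and only then stores `queued = false` with `Release`, so a dropping receiver that observes `false` with `Acquire` knows the waker access has completed and may skip the lock.

use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering::{Acquire, Relaxed, Release};
use std::sync::Mutex;
use std::task::Waker;

// Hypothetical, simplified waiter: only the atomic flag and a waker slot.
struct Waiter {
    queued: AtomicBool,
    waker: Mutex<Option<Waker>>, // stand-in for the waker field guarded by the tail lock
}

// Notifier side (cf. `Shared::notify_rx`): extract the waker, then publish
// `queued = false` with `Release` so the store is ordered after the waker access.
fn notify(waiter: &Waiter) -> Option<Waker> {
    let waker = waiter.waker.lock().unwrap().take();
    debug_assert!(waiter.queued.load(Relaxed));
    waiter.queued.store(false, Release);
    waker
}

// Drop side (cf. `Recv::drop`): an `Acquire` load observing `false` proves the
// notifier is done with this waiter, so the list lock can be skipped entirely.
fn drop_fast_path(waiter: &Waiter) -> bool {
    !waiter.queued.load(Acquire)
}

fn main() {
    let waiter = Waiter {
        queued: AtomicBool::new(false),
        waker: Mutex::new(None),
    };
    // A waiter that was never queued takes the lock-free path on drop.
    assert!(drop_fast_path(&waiter));
    let _ = notify; // the notifier half is shown for symmetry only
}

The new benchmark itself should be runnable from the `benches` crate with `cargo bench --bench sync_broadcast`; criterion should report one `contention/<N_TASKS>` entry per task count exercised by `bench_contention`.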