Skip to content

Commit 4b8d1cd

Browse files
committed
sync::broadcast: don't lock in channel()
1 parent f5d2b5a commit 4b8d1cd

File tree

1 file changed

+38
-18
lines changed

1 file changed

+38
-18
lines changed

tokio/src/sync/broadcast.rs

+38-18
Original file line numberDiff line numberDiff line change
@@ -445,8 +445,12 @@ const MAX_RECEIVERS: usize = usize::MAX >> 2;
445445
/// than `usize::MAX / 2`.
446446
#[track_caller]
447447
pub fn channel<T: Clone>(capacity: usize) -> (Sender<T>, Receiver<T>) {
448-
let tx = Sender::new(capacity);
449-
let rx = tx.subscribe();
448+
// SAFETY: In the line below we are creating one extra receiver, so there will be 1 in total.
449+
let tx = unsafe { Sender::new_with_receiver_count(1, capacity) };
450+
let rx = Receiver {
451+
shared: tx.shared.clone(),
452+
next: 0,
453+
};
450454
(tx, rx)
451455
}
452456

@@ -464,9 +468,29 @@ impl<T> Sender<T> {
464468
/// [`broadcast`]: crate::sync::broadcast
465469
/// [`broadcast::channel`]: crate::sync::broadcast
466470
#[track_caller]
467-
pub fn new(mut capacity: usize) -> Self {
468-
assert!(capacity > 0, "capacity is empty");
469-
assert!(capacity <= usize::MAX >> 1, "requested capacity too large");
471+
pub fn new(capacity: usize) -> Self {
472+
// SAFETY: We don't create extra receivers, so there are 0.
473+
unsafe { Self::new_with_receiver_count(0, capacity) }
474+
}
475+
476+
/// Creates the sending-half of the [`broadcast`](self) channel, and provide the receiver
477+
/// count.
478+
///
479+
/// See the documentation of [`broadcast::channel`](self::channel) for more errors when
480+
/// calling this function.
481+
///
482+
/// # Safety:
483+
///
484+
/// The caller must ensure that the amount of receivers for this Sender is correct before
485+
/// the channel functionalities are used, the count is zero by default, as this function
486+
/// does not create any receivers by itself.
487+
#[track_caller]
488+
unsafe fn new_with_receiver_count(receiver_count: usize, mut capacity: usize) -> Self {
489+
assert!(capacity > 0, "broadcast channel capacity cannot be zero");
490+
assert!(
491+
capacity <= usize::MAX >> 1,
492+
"broadcast channel capacity exceeded `usize::MAX / 2`"
493+
);
470494

471495
// Round to a power of two
472496
capacity = capacity.next_power_of_two();
@@ -486,7 +510,7 @@ impl<T> Sender<T> {
486510
mask: capacity - 1,
487511
tail: Mutex::new(Tail {
488512
pos: 0,
489-
rx_cnt: 0,
513+
rx_cnt: receiver_count,
490514
closed: false,
491515
waiters: LinkedList::new(),
492516
}),
@@ -1383,37 +1407,33 @@ mod tests {
13831407

13841408
#[test]
13851409
fn receiver_count_on_sender_constructor() {
1386-
let count_of = |sender: &Sender<i32>| sender.shared.tail.lock().rx_cnt;
1387-
13881410
let sender = Sender::<i32>::new(16);
1389-
assert_eq!(count_of(&sender), 0);
1411+
assert_eq!(sender.receiver_count(), 0);
13901412

13911413
let rx_1 = sender.subscribe();
1392-
assert_eq!(count_of(&sender), 1);
1414+
assert_eq!(sender.receiver_count(), 1);
13931415

13941416
let rx_2 = rx_1.resubscribe();
1395-
assert_eq!(count_of(&sender), 2);
1417+
assert_eq!(sender.receiver_count(), 2);
13961418

13971419
let rx_3 = sender.subscribe();
1398-
assert_eq!(count_of(&sender), 3);
1420+
assert_eq!(sender.receiver_count(), 3);
13991421

14001422
drop(rx_3);
14011423
drop(rx_1);
1402-
assert_eq!(count_of(&sender), 1);
1424+
assert_eq!(sender.receiver_count(), 1);
14031425

14041426
drop(rx_2);
1405-
assert_eq!(count_of(&sender), 0);
1427+
assert_eq!(sender.receiver_count(), 0);
14061428
}
14071429

14081430
#[cfg(not(loom))]
14091431
#[test]
14101432
fn receiver_count_on_channel_constructor() {
1411-
let count_of = |sender: &Sender<i32>| sender.shared.tail.lock().rx_cnt;
1412-
14131433
let (sender, rx) = channel::<i32>(16);
1414-
assert_eq!(count_of(&sender), 1);
1434+
assert_eq!(sender.receiver_count(), 1);
14151435

14161436
let _rx_2 = rx.resubscribe();
1417-
assert_eq!(count_of(&sender), 2);
1437+
assert_eq!(sender.receiver_count(), 2);
14181438
}
14191439
}

0 commit comments

Comments
 (0)