Skip to content

Commit d769972

Browse files
committed
sync::broadcast: don't lock in channel()
1 parent f5d2b5a commit d769972

File tree

1 file changed

+37
-18
lines changed

1 file changed

+37
-18
lines changed

tokio/src/sync/broadcast.rs

+37-18
Original file line numberDiff line numberDiff line change
@@ -445,8 +445,12 @@ const MAX_RECEIVERS: usize = usize::MAX >> 2;
445445
/// than `usize::MAX / 2`.
446446
#[track_caller]
447447
pub fn channel<T: Clone>(capacity: usize) -> (Sender<T>, Receiver<T>) {
448-
let tx = Sender::new(capacity);
449-
let rx = tx.subscribe();
448+
// SAFETY: In the line below we are creating one extra receiver, so there will be 1 in total.
449+
let tx = unsafe { Sender::new_with_receiver_count(1, capacity) };
450+
let rx = Receiver {
451+
shared: tx.shared.clone(),
452+
next: 0,
453+
};
450454
(tx, rx)
451455
}
452456

@@ -464,9 +468,28 @@ impl<T> Sender<T> {
464468
/// [`broadcast`]: crate::sync::broadcast
465469
/// [`broadcast::channel`]: crate::sync::broadcast
466470
#[track_caller]
467-
pub fn new(mut capacity: usize) -> Self {
468-
assert!(capacity > 0, "capacity is empty");
469-
assert!(capacity <= usize::MAX >> 1, "requested capacity too large");
471+
pub fn new(capacity: usize) -> Self {
472+
// SAFETY: We don't create extra receivers, so there are 0.
473+
unsafe { Self::new_with_receiver_count(0, capacity) }
474+
}
475+
476+
/// Creates the sending-half of the [`broadcast`] channel, and provide the receiver count.
477+
///
478+
/// See the documentation of [`broadcast::channel`] for more errors when calling this
479+
/// function.
480+
///
481+
/// # Safety:
482+
///
483+
/// The caller must ensure that the amount of receivers for this Sender is correct before
484+
/// the channel functionalities are used, the count is zero by default, as this function
485+
/// does not create any receivers by itself.
486+
#[track_caller]
487+
unsafe fn new_with_receiver_count(receiver_count: usize, mut capacity: usize) -> Self {
488+
assert!(capacity > 0, "broadcast channel capacity cannot be zero");
489+
assert!(
490+
capacity <= usize::MAX >> 1,
491+
"broadcast channel capacity exceeded `usize::MAX / 2`"
492+
);
470493

471494
// Round to a power of two
472495
capacity = capacity.next_power_of_two();
@@ -486,7 +509,7 @@ impl<T> Sender<T> {
486509
mask: capacity - 1,
487510
tail: Mutex::new(Tail {
488511
pos: 0,
489-
rx_cnt: 0,
512+
rx_cnt: receiver_count,
490513
closed: false,
491514
waiters: LinkedList::new(),
492515
}),
@@ -1383,37 +1406,33 @@ mod tests {
13831406

13841407
#[test]
13851408
fn receiver_count_on_sender_constructor() {
1386-
let count_of = |sender: &Sender<i32>| sender.shared.tail.lock().rx_cnt;
1387-
13881409
let sender = Sender::<i32>::new(16);
1389-
assert_eq!(count_of(&sender), 0);
1410+
assert_eq!(sender.receiver_count(), 0);
13901411

13911412
let rx_1 = sender.subscribe();
1392-
assert_eq!(count_of(&sender), 1);
1413+
assert_eq!(sender.receiver_count(), 1);
13931414

13941415
let rx_2 = rx_1.resubscribe();
1395-
assert_eq!(count_of(&sender), 2);
1416+
assert_eq!(sender.receiver_count(), 2);
13961417

13971418
let rx_3 = sender.subscribe();
1398-
assert_eq!(count_of(&sender), 3);
1419+
assert_eq!(sender.receiver_count(), 3);
13991420

14001421
drop(rx_3);
14011422
drop(rx_1);
1402-
assert_eq!(count_of(&sender), 1);
1423+
assert_eq!(sender.receiver_count(), 1);
14031424

14041425
drop(rx_2);
1405-
assert_eq!(count_of(&sender), 0);
1426+
assert_eq!(sender.receiver_count(), 0);
14061427
}
14071428

14081429
#[cfg(not(loom))]
14091430
#[test]
14101431
fn receiver_count_on_channel_constructor() {
1411-
let count_of = |sender: &Sender<i32>| sender.shared.tail.lock().rx_cnt;
1412-
14131432
let (sender, rx) = channel::<i32>(16);
1414-
assert_eq!(count_of(&sender), 1);
1433+
assert_eq!(sender.receiver_count(), 1);
14151434

14161435
let _rx_2 = rx.resubscribe();
1417-
assert_eq!(count_of(&sender), 2);
1436+
assert_eq!(sender.receiver_count(), 2);
14181437
}
14191438
}

0 commit comments

Comments
 (0)