Skip to content

Commit 10bd2f4

Browse files
committed
sync::watch: Fix changed handling on version overflow
1 parent 9d51b76 commit 10bd2f4

File tree

1 file changed

+16
-19
lines changed

1 file changed

+16
-19
lines changed

tokio/src/sync/watch.rs

+16-19
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,10 @@ mod state {
359359
use crate::loom::sync::atomic::AtomicUsize;
360360
use crate::loom::sync::atomic::Ordering::SeqCst;
361361

362-
const CLOSED: usize = 1;
362+
const CLOSED_BIT: usize = 1;
363+
364+
// Using 2 as the step size preserves the the `CLOSED_BIT`.
365+
const STEP_SIZE: usize = 2;
363366

364367
/// The version part of the state. The lowest bit is always zero.
365368
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
@@ -378,39 +381,34 @@ mod state {
378381
pub(super) struct AtomicState(AtomicUsize);
379382

380383
impl Version {
381-
/// Get the initial version when creating the channel.
382-
pub(super) fn initial() -> Self {
383-
// The initial version is 1 so that `mark_changed` can decrement by one.
384-
// (The value is 2 due to the closed bit.)
385-
Version(2)
386-
}
387-
388384
/// Decrements the version.
389385
pub(super) fn decrement(&mut self) {
390-
// Decrement by two to avoid touching the CLOSED bit.
391-
if self.0 >= 2 {
392-
self.0 -= 2;
393-
}
386+
// Using a wrapping decrement here is required to ensure that the
387+
// operation is consistent with `std::sync::atomic::AtomicUsize::fetch_add()`
388+
// which wraps on overflow.
389+
self.0 = self.0.wrapping_sub(STEP_SIZE);
394390
}
391+
392+
pub(super) const INITIAL: Self = Version(0);
395393
}
396394

397395
impl StateSnapshot {
398396
/// Extract the version from the state.
399397
pub(super) fn version(self) -> Version {
400-
Version(self.0 & !CLOSED)
398+
Version(self.0 & !CLOSED_BIT)
401399
}
402400

403401
/// Is the closed bit set?
404402
pub(super) fn is_closed(self) -> bool {
405-
(self.0 & CLOSED) == CLOSED
403+
(self.0 & CLOSED_BIT) == CLOSED_BIT
406404
}
407405
}
408406

409407
impl AtomicState {
410408
/// Create a new `AtomicState` that is not closed and which has the
411409
/// version set to `Version::initial()`.
412410
pub(super) fn new() -> Self {
413-
AtomicState(AtomicUsize::new(2))
411+
AtomicState(AtomicUsize::new(Version::INITIAL.0))
414412
}
415413

416414
/// Load the current value of the state.
@@ -420,13 +418,12 @@ mod state {
420418

421419
/// Increment the version counter.
422420
pub(super) fn increment_version(&self) {
423-
// Increment by two to avoid touching the CLOSED bit.
424-
self.0.fetch_add(2, SeqCst);
421+
self.0.fetch_add(STEP_SIZE, SeqCst);
425422
}
426423

427424
/// Set the closed bit in the state.
428425
pub(super) fn set_closed(&self) {
429-
self.0.fetch_or(CLOSED, SeqCst);
426+
self.0.fetch_or(CLOSED_BIT, SeqCst);
430427
}
431428
}
432429
}
@@ -482,7 +479,7 @@ pub fn channel<T>(init: T) -> (Sender<T>, Receiver<T>) {
482479

483480
let rx = Receiver {
484481
shared,
485-
version: Version::initial(),
482+
version: Version::INITIAL,
486483
};
487484

488485
(tx, rx)

0 commit comments

Comments
 (0)