diff --git a/tokio/src/sync/batch_semaphore.rs b/tokio/src/sync/batch_semaphore.rs index 07120d63411..3e84bf46f72 100644 --- a/tokio/src/sync/batch_semaphore.rs +++ b/tokio/src/sync/batch_semaphore.rs @@ -58,6 +58,21 @@ pub enum TryAcquireError { /// The semaphore has no available permits. NoPermits, } + +/// Error returned from the [`Semaphore::try_reacquire`] function. +/// +/// [`Semaphore::try_acquire`]: crate::sync::Semaphore::try_acquire +#[derive(Debug, PartialEq, Eq)] +pub enum TryUpgradeError { + /// The semaphore has been [closed] and cannot issue new permits. + /// + /// [closed]: crate::sync::Semaphore::close + Closed, + + /// The semaphore has changed + Used, +} + /// Error returned from the [`Semaphore::acquire`] function. /// /// An `acquire` operation can only fail if the semaphore has been @@ -294,6 +309,36 @@ impl Semaphore { } } + pub(crate) fn try_upgrade(&self, num_permits: usize) -> Result<(), TryUpgradeError> { + assert!( + num_permits <= Self::MAX_PERMITS, + "a semaphore may not have more than MAX_PERMITS permits ({})", + Self::MAX_PERMITS + ); + let num_permits = num_permits << Self::PERMIT_SHIFT; + + let curr = self.permits.load(Acquire); + + + // Has the semaphore closed? + if curr & Self::CLOSED == Self::CLOSED { + return Err(TryUpgradeError::Closed); + } + + // Are there enough permits remaining? + if curr < num_permits{ + return Err(TryUpgradeError::Used); + } + + match self.permits.compare_exchange(curr, 0, AcqRel, Acquire) { + Ok(_) => { + return Ok(()); + } + Err(_actual) => return Err(TryUpgradeError::Used), + } + + } + pub(crate) fn acquire(&self, num_permits: usize) -> Acquire<'_> { Acquire::new(self, num_permits) } diff --git a/tokio/src/sync/rwlock.rs b/tokio/src/sync/rwlock.rs index ff02c7971d6..3340856df95 100644 --- a/tokio/src/sync/rwlock.rs +++ b/tokio/src/sync/rwlock.rs @@ -1,4 +1,4 @@ -use crate::sync::batch_semaphore::{Semaphore, TryAcquireError}; +use crate::sync::batch_semaphore::{Semaphore, TryAcquireError,TryUpgradeError}; use crate::sync::mutex::TryLockError; #[cfg(all(tokio_unstable, feature = "tracing"))] use crate::util::trace; @@ -1055,6 +1055,39 @@ impl RwLock { Ok(guard) } + pub fn try_upgrade<'a>(&'a self,reader:RwLockReadGuard<'a,T>) -> Result, RwLockReadGuard<'a,T>> { + match self.s.try_upgrade(self.mr as usize - 1) { + Ok(permit) => permit, + Err(TryUpgradeError::Used) => return Err(reader), + Err(TryUpgradeError::Closed) => unreachable!(), + } + + //readers permit already acounted for by upgrade + std::mem::forget(reader); + + let guard = RwLockWriteGuard { + permits_acquired: self.mr, + s: &self.s, + data: self.c.get(), + marker: marker::PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: self.resource_span.clone(), + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = true, + write_locked.op = "override", + ) + }); + + Ok(guard) + } + + + /// Returns a mutable reference to the underlying data. /// /// Since this call borrows the `RwLock` mutably, no actual locking needs to