Skip to content

Commit

Permalink
Eliminate timer wheel allocations by wrapping the wheel shards in an
Browse files Browse the repository at this point in the history
RwLock
  • Loading branch information
tglane committed Aug 15, 2024
1 parent 39c3c19 commit 302a757
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 34 deletions.
5 changes: 5 additions & 0 deletions tokio/src/loom/mocked.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ pub(crate) mod sync {
pub(crate) fn try_lock(&self) -> Option<MutexGuard<'_, T>> {
self.0.try_lock().ok()
}

#[inline]
pub(crate) fn get_mut(&mut self) -> &mut T {
self.0.get_mut().unwrap()
}
}
pub(crate) use loom::sync::*;

Expand Down
8 changes: 8 additions & 0 deletions tokio/src/loom/std/mutex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,12 @@ impl<T> Mutex<T> {
Err(TryLockError::WouldBlock) => None,
}
}

#[inline]
pub(crate) fn get_mut(&mut self) -> &mut T {
match self.0.get_mut() {
Ok(val) => val,
Err(p_err) => p_err.into_inner(),
}
}
}
112 changes: 78 additions & 34 deletions tokio/src/runtime/time/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use crate::util::WakeList;

use crate::loom::sync::atomic::AtomicU64;
use std::fmt;
use std::sync::RwLock;
use std::{num::NonZeroU64, ptr::NonNull};

struct AtomicOptionNonZeroU64(AtomicU64);
Expand Down Expand Up @@ -115,7 +116,7 @@ struct Inner {
next_wake: AtomicOptionNonZeroU64,

/// Sharded Timer wheels.
wheels: Box<[Mutex<wheel::Wheel>]>,
wheels: RwLock<ShardedWheel>,

/// True if the driver is being shutdown.
pub(super) is_shutdown: AtomicBool,
Expand All @@ -130,6 +131,9 @@ struct Inner {
did_wake: AtomicBool,
}

/// Wrapper around the sharded timer wheels.
struct ShardedWheel(Box<[Mutex<wheel::Wheel>]>);

// ===== impl Driver =====

impl Driver {
Expand All @@ -149,7 +153,7 @@ impl Driver {
time_source,
inner: Inner {
next_wake: AtomicOptionNonZeroU64::new(None),
wheels: wheels.into_boxed_slice(),
wheels: RwLock::new(ShardedWheel(wheels.into_boxed_slice())),
is_shutdown: AtomicBool::new(false),
#[cfg(feature = "test-util")]
did_wake: AtomicBool::new(false),
Expand Down Expand Up @@ -190,23 +194,28 @@ impl Driver {
assert!(!handle.is_shutdown());

// Finds out the min expiration time to park.
let locks = (0..rt_handle.time().inner.get_shard_size())
.map(|id| rt_handle.time().inner.lock_sharded_wheel(id))
.collect::<Vec<_>>();

let expiration_time = locks
.iter()
.filter_map(|lock| lock.next_expiration_time())
.min();

rt_handle
.time()
.inner
.next_wake
.store(next_wake_time(expiration_time));

// Safety: After updating the `next_wake`, we drop all the locks.
drop(locks);
let expiration_time = {
let mut wheels_lock = rt_handle
.time()
.inner
.wheels
.write()
.expect("Timer wheel shards poisened");
let expiration_time = (0..wheels_lock.get_shard_size())
.filter_map(|id| {
let wheel = wheels_lock.get_sharded_wheel(id);
wheel.next_expiration_time()
})
.min();

rt_handle
.time()
.inner
.next_wake
.store(next_wake_time(expiration_time));

expiration_time
};

match expiration_time {
Some(when) => {
Expand Down Expand Up @@ -312,7 +321,12 @@ impl Handle {
// Returns the next wakeup time of this shard.
pub(self) fn process_at_sharded_time(&self, id: u32, mut now: u64) -> Option<u64> {
let mut waker_list = WakeList::new();
let mut lock = self.inner.lock_sharded_wheel(id);
let wheels_lock = self
.inner
.wheels
.read()
.expect("Timer wheel shards poisened");
let mut lock = wheels_lock.lock_sharded_wheel(id);

if now < lock.elapsed() {
// Time went backwards! This normally shouldn't happen as the Rust language
Expand All @@ -337,7 +351,7 @@ impl Handle {

waker_list.wake_all();

lock = self.inner.lock_sharded_wheel(id);
lock = wheels_lock.lock_sharded_wheel(id);
}
}
}
Expand All @@ -360,7 +374,12 @@ impl Handle {
/// `add_entry` must not be called concurrently.
pub(self) unsafe fn clear_entry(&self, entry: NonNull<TimerShared>) {
unsafe {
let mut lock = self.inner.lock_sharded_wheel(entry.as_ref().shard_id());
let wheels_lock = self
.inner
.wheels
.read()
.expect("Timer wheel shards poisened");
let mut lock = wheels_lock.lock_sharded_wheel(entry.as_ref().shard_id());

if entry.as_ref().might_be_registered() {
lock.remove(entry);
Expand All @@ -383,7 +402,13 @@ impl Handle {
entry: NonNull<TimerShared>,
) {
let waker = unsafe {
let mut lock = self.inner.lock_sharded_wheel(entry.as_ref().shard_id());
let wheels_lock = self
.inner
.wheels
.read()
.expect("Timer wheel shards poisened");

let mut lock = wheels_lock.lock_sharded_wheel(entry.as_ref().shard_id());

// We may have raced with a firing/deregistration, so check before
// deregistering.
Expand Down Expand Up @@ -443,24 +468,17 @@ impl Handle {
// ===== impl Inner =====

impl Inner {
/// Locks the driver's sharded wheel structure.
pub(super) fn lock_sharded_wheel(
&self,
shard_id: u32,
) -> crate::loom::sync::MutexGuard<'_, Wheel> {
let index = shard_id % (self.wheels.len() as u32);
// Safety: This modulo operation ensures that the index is not out of bounds.
unsafe { self.wheels.get_unchecked(index as usize).lock() }
}

// Check whether the driver has been shutdown
pub(super) fn is_shutdown(&self) -> bool {
self.is_shutdown.load(Ordering::SeqCst)
}

// Gets the number of shards.
fn get_shard_size(&self) -> u32 {
self.wheels.len() as u32
self.wheels
.read()
.expect("Timer wheel shards poisened")
.get_shard_size()
}
}

Expand All @@ -470,5 +488,31 @@ impl fmt::Debug for Inner {
}
}

// ===== impl ShardedWheel =====

impl ShardedWheel {
/// Locks the driver's sharded wheel structure.
pub(super) fn lock_sharded_wheel(
&self,
shard_id: u32,
) -> crate::loom::sync::MutexGuard<'_, Wheel> {
let index = shard_id % (self.0.len() as u32);
// Safety: This modulo operation ensures that the index is not out of bounds.
unsafe { self.0.get_unchecked(index as usize) }.lock()
}

/// Gets a mutable reference to the sharded wheel with the given id.
pub(super) fn get_sharded_wheel(&mut self, shard_id: u32) -> &mut wheel::Wheel {
let index = shard_id % (self.0.len() as u32);
// Safety: This modulo operation ensures that the index is not out of bounds.
unsafe { self.0.get_unchecked_mut(index as usize) }.get_mut()
}

/// Gets the number of shards.
fn get_shard_size(&self) -> u32 {
self.0.len() as u32
}
}

#[cfg(test)]
mod tests;

0 comments on commit 302a757

Please sign in to comment.