From 302a7577c3b315fa11216fb22e2de649cd0678ac Mon Sep 17 00:00:00 2001 From: Timo Glane Date: Thu, 15 Aug 2024 11:44:37 +0200 Subject: [PATCH] Eliminate timer wheel allocations by wrapping the wheel shards in an RwLock --- tokio/src/loom/mocked.rs | 5 ++ tokio/src/loom/std/mutex.rs | 8 +++ tokio/src/runtime/time/mod.rs | 112 +++++++++++++++++++++++----------- 3 files changed, 91 insertions(+), 34 deletions(-) diff --git a/tokio/src/loom/mocked.rs b/tokio/src/loom/mocked.rs index d40e2c1f8ea..c25018e7e8c 100644 --- a/tokio/src/loom/mocked.rs +++ b/tokio/src/loom/mocked.rs @@ -24,6 +24,11 @@ pub(crate) mod sync { pub(crate) fn try_lock(&self) -> Option> { self.0.try_lock().ok() } + + #[inline] + pub(crate) fn get_mut(&mut self) -> &mut T { + self.0.get_mut().unwrap() + } } pub(crate) use loom::sync::*; diff --git a/tokio/src/loom/std/mutex.rs b/tokio/src/loom/std/mutex.rs index 7b8f9ba1e24..3ea8e1df861 100644 --- a/tokio/src/loom/std/mutex.rs +++ b/tokio/src/loom/std/mutex.rs @@ -33,4 +33,12 @@ impl Mutex { Err(TryLockError::WouldBlock) => None, } } + + #[inline] + pub(crate) fn get_mut(&mut self) -> &mut T { + match self.0.get_mut() { + Ok(val) => val, + Err(p_err) => p_err.into_inner(), + } + } } diff --git a/tokio/src/runtime/time/mod.rs b/tokio/src/runtime/time/mod.rs index c01a5f2b25e..2eabf5ad922 100644 --- a/tokio/src/runtime/time/mod.rs +++ b/tokio/src/runtime/time/mod.rs @@ -28,6 +28,7 @@ use crate::util::WakeList; use crate::loom::sync::atomic::AtomicU64; use std::fmt; +use std::sync::RwLock; use std::{num::NonZeroU64, ptr::NonNull}; struct AtomicOptionNonZeroU64(AtomicU64); @@ -115,7 +116,7 @@ struct Inner { next_wake: AtomicOptionNonZeroU64, /// Sharded Timer wheels. - wheels: Box<[Mutex]>, + wheels: RwLock, /// True if the driver is being shutdown. pub(super) is_shutdown: AtomicBool, @@ -130,6 +131,9 @@ struct Inner { did_wake: AtomicBool, } +/// Wrapper around the sharded timer wheels. +struct ShardedWheel(Box<[Mutex]>); + // ===== impl Driver ===== impl Driver { @@ -149,7 +153,7 @@ impl Driver { time_source, inner: Inner { next_wake: AtomicOptionNonZeroU64::new(None), - wheels: wheels.into_boxed_slice(), + wheels: RwLock::new(ShardedWheel(wheels.into_boxed_slice())), is_shutdown: AtomicBool::new(false), #[cfg(feature = "test-util")] did_wake: AtomicBool::new(false), @@ -190,23 +194,28 @@ impl Driver { assert!(!handle.is_shutdown()); // Finds out the min expiration time to park. - let locks = (0..rt_handle.time().inner.get_shard_size()) - .map(|id| rt_handle.time().inner.lock_sharded_wheel(id)) - .collect::>(); - - let expiration_time = locks - .iter() - .filter_map(|lock| lock.next_expiration_time()) - .min(); - - rt_handle - .time() - .inner - .next_wake - .store(next_wake_time(expiration_time)); - - // Safety: After updating the `next_wake`, we drop all the locks. - drop(locks); + let expiration_time = { + let mut wheels_lock = rt_handle + .time() + .inner + .wheels + .write() + .expect("Timer wheel shards poisened"); + let expiration_time = (0..wheels_lock.get_shard_size()) + .filter_map(|id| { + let wheel = wheels_lock.get_sharded_wheel(id); + wheel.next_expiration_time() + }) + .min(); + + rt_handle + .time() + .inner + .next_wake + .store(next_wake_time(expiration_time)); + + expiration_time + }; match expiration_time { Some(when) => { @@ -312,7 +321,12 @@ impl Handle { // Returns the next wakeup time of this shard. pub(self) fn process_at_sharded_time(&self, id: u32, mut now: u64) -> Option { let mut waker_list = WakeList::new(); - let mut lock = self.inner.lock_sharded_wheel(id); + let wheels_lock = self + .inner + .wheels + .read() + .expect("Timer wheel shards poisened"); + let mut lock = wheels_lock.lock_sharded_wheel(id); if now < lock.elapsed() { // Time went backwards! This normally shouldn't happen as the Rust language @@ -337,7 +351,7 @@ impl Handle { waker_list.wake_all(); - lock = self.inner.lock_sharded_wheel(id); + lock = wheels_lock.lock_sharded_wheel(id); } } } @@ -360,7 +374,12 @@ impl Handle { /// `add_entry` must not be called concurrently. pub(self) unsafe fn clear_entry(&self, entry: NonNull) { unsafe { - let mut lock = self.inner.lock_sharded_wheel(entry.as_ref().shard_id()); + let wheels_lock = self + .inner + .wheels + .read() + .expect("Timer wheel shards poisened"); + let mut lock = wheels_lock.lock_sharded_wheel(entry.as_ref().shard_id()); if entry.as_ref().might_be_registered() { lock.remove(entry); @@ -383,7 +402,13 @@ impl Handle { entry: NonNull, ) { let waker = unsafe { - let mut lock = self.inner.lock_sharded_wheel(entry.as_ref().shard_id()); + let wheels_lock = self + .inner + .wheels + .read() + .expect("Timer wheel shards poisened"); + + let mut lock = wheels_lock.lock_sharded_wheel(entry.as_ref().shard_id()); // We may have raced with a firing/deregistration, so check before // deregistering. @@ -443,16 +468,6 @@ impl Handle { // ===== impl Inner ===== impl Inner { - /// Locks the driver's sharded wheel structure. - pub(super) fn lock_sharded_wheel( - &self, - shard_id: u32, - ) -> crate::loom::sync::MutexGuard<'_, Wheel> { - let index = shard_id % (self.wheels.len() as u32); - // Safety: This modulo operation ensures that the index is not out of bounds. - unsafe { self.wheels.get_unchecked(index as usize).lock() } - } - // Check whether the driver has been shutdown pub(super) fn is_shutdown(&self) -> bool { self.is_shutdown.load(Ordering::SeqCst) @@ -460,7 +475,10 @@ impl Inner { // Gets the number of shards. fn get_shard_size(&self) -> u32 { - self.wheels.len() as u32 + self.wheels + .read() + .expect("Timer wheel shards poisened") + .get_shard_size() } } @@ -470,5 +488,31 @@ impl fmt::Debug for Inner { } } +// ===== impl ShardedWheel ===== + +impl ShardedWheel { + /// Locks the driver's sharded wheel structure. + pub(super) fn lock_sharded_wheel( + &self, + shard_id: u32, + ) -> crate::loom::sync::MutexGuard<'_, Wheel> { + let index = shard_id % (self.0.len() as u32); + // Safety: This modulo operation ensures that the index is not out of bounds. + unsafe { self.0.get_unchecked(index as usize) }.lock() + } + + /// Gets a mutable reference to the sharded wheel with the given id. + pub(super) fn get_sharded_wheel(&mut self, shard_id: u32) -> &mut wheel::Wheel { + let index = shard_id % (self.0.len() as u32); + // Safety: This modulo operation ensures that the index is not out of bounds. + unsafe { self.0.get_unchecked_mut(index as usize) }.get_mut() + } + + /// Gets the number of shards. + fn get_shard_size(&self) -> u32 { + self.0.len() as u32 + } +} + #[cfg(test)] mod tests;