Skip to content

Commit 3e12f33

Browse files
committed
Eliminate timer wheel allocations by wrapping the wheel shards in an
RwLock
1 parent 53ea44b commit 3e12f33

File tree

2 files changed

+65
-29
lines changed

2 files changed

+65
-29
lines changed

tokio/src/loom/std/mutex.rs

+8
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,12 @@ impl<T> Mutex<T> {
3333
Err(TryLockError::WouldBlock) => None,
3434
}
3535
}
36+
37+
#[inline]
38+
pub(crate) fn get_mut(&mut self) -> &mut T {
39+
match self.0.get_mut() {
40+
Ok(val) => val,
41+
Err(p_err) => p_err.into_inner(),
42+
}
43+
}
3644
}

tokio/src/runtime/time/mod.rs

+57-29
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ use crate::util::WakeList;
2828

2929
use crate::loom::sync::atomic::AtomicU64;
3030
use std::fmt;
31+
use std::sync::RwLock;
3132
use std::{num::NonZeroU64, ptr::NonNull};
3233

3334
struct AtomicOptionNonZeroU64(AtomicU64);
@@ -115,7 +116,7 @@ struct Inner {
115116
next_wake: AtomicOptionNonZeroU64,
116117

117118
/// Sharded Timer wheels.
118-
wheels: Box<[Mutex<wheel::Wheel>]>,
119+
wheels: RwLock<ShardedWheel>,
119120

120121
/// True if the driver is being shutdown.
121122
pub(super) is_shutdown: AtomicBool,
@@ -130,6 +131,9 @@ struct Inner {
130131
did_wake: AtomicBool,
131132
}
132133

134+
/// Wrapper around the sharded timer wheels.
135+
struct ShardedWheel(Box<[Mutex<wheel::Wheel>]>);
136+
133137
// ===== impl Driver =====
134138

135139
impl Driver {
@@ -149,7 +153,7 @@ impl Driver {
149153
time_source,
150154
inner: Inner {
151155
next_wake: AtomicOptionNonZeroU64::new(None),
152-
wheels: wheels.into_boxed_slice(),
156+
wheels: RwLock::new(ShardedWheel(wheels.into_boxed_slice())),
153157
is_shutdown: AtomicBool::new(false),
154158
#[cfg(feature = "test-util")]
155159
did_wake: AtomicBool::new(false),
@@ -190,18 +194,23 @@ impl Driver {
190194
assert!(!handle.is_shutdown());
191195

192196
// Finds out the min expiration time to park.
193-
let expiration_time = (0..rt_handle.time().inner.get_shard_size())
194-
.filter_map(|id| {
195-
let lock = rt_handle.time().inner.lock_sharded_wheel(id);
196-
lock.next_expiration_time()
197-
})
198-
.min();
199-
200-
rt_handle
201-
.time()
202-
.inner
203-
.next_wake
204-
.store(next_wake_time(expiration_time));
197+
let expiration_time = {
198+
let mut wheels_lock = rt_handle.time().inner.wheels.write().expect("");
199+
let expiration_time = (0..wheels_lock.get_shard_size())
200+
.filter_map(|id| {
201+
let wheel = wheels_lock.get_sharded_wheel(id);
202+
wheel.next_expiration_time()
203+
})
204+
.min();
205+
206+
rt_handle
207+
.time()
208+
.inner
209+
.next_wake
210+
.store(next_wake_time(expiration_time));
211+
212+
expiration_time
213+
};
205214

206215
match expiration_time {
207216
Some(when) => {
@@ -307,7 +316,8 @@ impl Handle {
307316
// Returns the next wakeup time of this shard.
308317
pub(self) fn process_at_sharded_time(&self, id: u32, mut now: u64) -> Option<u64> {
309318
let mut waker_list = WakeList::new();
310-
let mut lock = self.inner.lock_sharded_wheel(id);
319+
let wheels_lock = self.inner.wheels.read().expect("");
320+
let mut lock = wheels_lock.lock_sharded_wheel(id);
311321

312322
if now < lock.elapsed() {
313323
// Time went backwards! This normally shouldn't happen as the Rust language
@@ -332,7 +342,7 @@ impl Handle {
332342

333343
waker_list.wake_all();
334344

335-
lock = self.inner.lock_sharded_wheel(id);
345+
lock = wheels_lock.lock_sharded_wheel(id);
336346
}
337347
}
338348
}
@@ -355,7 +365,8 @@ impl Handle {
355365
/// `add_entry` must not be called concurrently.
356366
pub(self) unsafe fn clear_entry(&self, entry: NonNull<TimerShared>) {
357367
unsafe {
358-
let mut lock = self.inner.lock_sharded_wheel(entry.as_ref().shard_id());
368+
let wheels_lock = self.inner.wheels.read().expect("");
369+
let mut lock = wheels_lock.lock_sharded_wheel(entry.as_ref().shard_id());
359370

360371
if entry.as_ref().might_be_registered() {
361372
lock.remove(entry);
@@ -378,7 +389,8 @@ impl Handle {
378389
entry: NonNull<TimerShared>,
379390
) {
380391
let waker = unsafe {
381-
let mut lock = self.inner.lock_sharded_wheel(entry.as_ref().shard_id());
392+
let wheels_lock = self.inner.wheels.read().expect("");
393+
let mut lock = wheels_lock.lock_sharded_wheel(entry.as_ref().shard_id());
382394

383395
// We may have raced with a firing/deregistration, so check before
384396
// deregistering.
@@ -438,24 +450,14 @@ impl Handle {
438450
// ===== impl Inner =====
439451

440452
impl Inner {
441-
/// Locks the driver's sharded wheel structure.
442-
pub(super) fn lock_sharded_wheel(
443-
&self,
444-
shard_id: u32,
445-
) -> crate::loom::sync::MutexGuard<'_, Wheel> {
446-
let index = shard_id % (self.wheels.len() as u32);
447-
// Safety: This modulo operation ensures that the index is not out of bounds.
448-
unsafe { self.wheels.get_unchecked(index as usize).lock() }
449-
}
450-
451453
// Check whether the driver has been shutdown
452454
pub(super) fn is_shutdown(&self) -> bool {
453455
self.is_shutdown.load(Ordering::SeqCst)
454456
}
455457

456458
// Gets the number of shards.
457459
fn get_shard_size(&self) -> u32 {
458-
self.wheels.len() as u32
460+
self.wheels.read().expect("").get_shard_size()
459461
}
460462
}
461463

@@ -465,5 +467,31 @@ impl fmt::Debug for Inner {
465467
}
466468
}
467469

470+
// ===== impl ShardedWheel =====
471+
472+
impl ShardedWheel {
473+
/// Locks the driver's sharded wheel structure.
474+
pub(super) fn lock_sharded_wheel(
475+
&self,
476+
shard_id: u32,
477+
) -> crate::loom::sync::MutexGuard<'_, Wheel> {
478+
let index = shard_id % (self.0.len() as u32);
479+
// Safety: This modulo operation ensures that the index is not out of bounds.
480+
unsafe { self.0.get_unchecked(index as usize) }.lock()
481+
}
482+
483+
/// Gets a mutable reference to the sharded wheel with the given id.
484+
pub(super) fn get_sharded_wheel(&mut self, shard_id: u32) -> &mut wheel::Wheel {
485+
let index = shard_id % (self.0.len() as u32);
486+
// Safety: This modulo operation ensures that the index is not out of bounds.
487+
unsafe { self.0.get_unchecked_mut(index as usize) }.get_mut()
488+
}
489+
490+
/// Gets the number of shards.
491+
fn get_shard_size(&self) -> u32 {
492+
self.0.len() as u32
493+
}
494+
}
495+
468496
#[cfg(test)]
469497
mod tests;

0 commit comments

Comments
 (0)