@@ -28,6 +28,7 @@ use crate::util::WakeList;
2828
2929use crate :: loom:: sync:: atomic:: AtomicU64 ;
3030use std:: fmt;
31+ use std:: sync:: RwLock ;
3132use std:: { num:: NonZeroU64 , ptr:: NonNull } ;
3233
3334struct AtomicOptionNonZeroU64 ( AtomicU64 ) ;
@@ -115,7 +116,10 @@ struct Inner {
115116 next_wake : AtomicOptionNonZeroU64 ,
116117
117118 /// Sharded Timer wheels.
118- wheels : Box < [ Mutex < wheel:: Wheel > ] > ,
119+ wheels : RwLock < ShardedWheel > ,
120+
121+ /// Number of entries in the sharded timer wheels.
122+ wheels_len : u32 ,
119123
120124 /// True if the driver is being shutdown.
121125 pub ( super ) is_shutdown : AtomicBool ,
@@ -130,6 +134,9 @@ struct Inner {
130134 did_wake : AtomicBool ,
131135}
132136
137+ /// Wrapper around the sharded timer wheels.
138+ struct ShardedWheel ( Box < [ Mutex < wheel:: Wheel > ] > ) ;
139+
133140// ===== impl Driver =====
134141
135142impl Driver {
@@ -149,7 +156,8 @@ impl Driver {
149156 time_source,
150157 inner : Inner {
151158 next_wake : AtomicOptionNonZeroU64 :: new ( None ) ,
152- wheels : wheels. into_boxed_slice ( ) ,
159+ wheels : RwLock :: new ( ShardedWheel ( wheels. into_boxed_slice ( ) ) ) ,
160+ wheels_len : shards,
153161 is_shutdown : AtomicBool :: new ( false ) ,
154162 #[ cfg( feature = "test-util" ) ]
155163 did_wake : AtomicBool :: new ( false ) ,
@@ -190,23 +198,27 @@ impl Driver {
190198 assert ! ( !handle. is_shutdown( ) ) ;
191199
192200 // Finds out the min expiration time to park.
193- let locks = ( 0 ..rt_handle. time ( ) . inner . get_shard_size ( ) )
194- . map ( |id| rt_handle. time ( ) . inner . lock_sharded_wheel ( id) )
195- . collect :: < Vec < _ > > ( ) ;
196-
197- let expiration_time = locks
198- . iter ( )
199- . filter_map ( |lock| lock. next_expiration_time ( ) )
200- . min ( ) ;
201-
202- rt_handle
203- . time ( )
204- . inner
205- . next_wake
206- . store ( next_wake_time ( expiration_time) ) ;
207-
208- // Safety: After updating the `next_wake`, we drop all the locks.
209- drop ( locks) ;
201+ let expiration_time = {
202+ let mut wheels_lock = rt_handle
203+ . time ( )
204+ . inner
205+ . wheels
206+ . write ( )
207+ . expect ( "Timer wheel shards poisoned" ) ;
208+ let expiration_time = wheels_lock
209+ . 0
210+ . iter_mut ( )
211+ . filter_map ( |wheel| wheel. get_mut ( ) . next_expiration_time ( ) )
212+ . min ( ) ;
213+
214+ rt_handle
215+ . time ( )
216+ . inner
217+ . next_wake
218+ . store ( next_wake_time ( expiration_time) ) ;
219+
220+ expiration_time
221+ } ;
210222
211223 match expiration_time {
212224 Some ( when) => {
@@ -312,7 +324,12 @@ impl Handle {
312324 // Returns the next wakeup time of this shard.
313325 pub ( self ) fn process_at_sharded_time ( & self , id : u32 , mut now : u64 ) -> Option < u64 > {
314326 let mut waker_list = WakeList :: new ( ) ;
315- let mut lock = self . inner . lock_sharded_wheel ( id) ;
327+ let mut wheels_lock = self
328+ . inner
329+ . wheels
330+ . read ( )
331+ . expect ( "Timer wheel shards poisoned" ) ;
332+ let mut lock = wheels_lock. lock_sharded_wheel ( id) ;
316333
317334 if now < lock. elapsed ( ) {
318335 // Time went backwards! This normally shouldn't happen as the Rust language
@@ -334,15 +351,22 @@ impl Handle {
334351 if !waker_list. can_push ( ) {
335352 // Wake a batch of wakers. To avoid deadlock, we must do this with the lock temporarily dropped.
336353 drop ( lock) ;
354+ drop ( wheels_lock) ;
337355
338356 waker_list. wake_all ( ) ;
339357
340- lock = self . inner . lock_sharded_wheel ( id) ;
358+ wheels_lock = self
359+ . inner
360+ . wheels
361+ . read ( )
362+ . expect ( "Timer wheel shards poisoned" ) ;
363+ lock = wheels_lock. lock_sharded_wheel ( id) ;
341364 }
342365 }
343366 }
344367 let next_wake_up = lock. poll_at ( ) ;
345368 drop ( lock) ;
369+ drop ( wheels_lock) ;
346370
347371 waker_list. wake_all ( ) ;
348372 next_wake_up
@@ -360,7 +384,12 @@ impl Handle {
360384 /// `add_entry` must not be called concurrently.
361385 pub ( self ) unsafe fn clear_entry ( & self , entry : NonNull < TimerShared > ) {
362386 unsafe {
363- let mut lock = self . inner . lock_sharded_wheel ( entry. as_ref ( ) . shard_id ( ) ) ;
387+ let wheels_lock = self
388+ . inner
389+ . wheels
390+ . read ( )
391+ . expect ( "Timer wheel shards poisoned" ) ;
392+ let mut lock = wheels_lock. lock_sharded_wheel ( entry. as_ref ( ) . shard_id ( ) ) ;
364393
365394 if entry. as_ref ( ) . might_be_registered ( ) {
366395 lock. remove ( entry) ;
@@ -383,7 +412,13 @@ impl Handle {
383412 entry : NonNull < TimerShared > ,
384413 ) {
385414 let waker = unsafe {
386- let mut lock = self . inner . lock_sharded_wheel ( entry. as_ref ( ) . shard_id ( ) ) ;
415+ let wheels_lock = self
416+ . inner
417+ . wheels
418+ . read ( )
419+ . expect ( "Timer wheel shards poisoned" ) ;
420+
421+ let mut lock = wheels_lock. lock_sharded_wheel ( entry. as_ref ( ) . shard_id ( ) ) ;
387422
388423 // We may have raced with a firing/deregistration, so check before
389424 // deregistering.
@@ -443,24 +478,14 @@ impl Handle {
443478// ===== impl Inner =====
444479
445480impl Inner {
446- /// Locks the driver's sharded wheel structure.
447- pub ( super ) fn lock_sharded_wheel (
448- & self ,
449- shard_id : u32 ,
450- ) -> crate :: loom:: sync:: MutexGuard < ' _ , Wheel > {
451- let index = shard_id % ( self . wheels . len ( ) as u32 ) ;
452- // Safety: This modulo operation ensures that the index is not out of bounds.
453- unsafe { self . wheels . get_unchecked ( index as usize ) . lock ( ) }
454- }
455-
456481 // Check whether the driver has been shutdown
457482 pub ( super ) fn is_shutdown ( & self ) -> bool {
458483 self . is_shutdown . load ( Ordering :: SeqCst )
459484 }
460485
461486 // Gets the number of shards.
462487 fn get_shard_size ( & self ) -> u32 {
463- self . wheels . len ( ) as u32
488+ self . wheels_len
464489 }
465490}
466491
@@ -470,5 +495,19 @@ impl fmt::Debug for Inner {
470495 }
471496}
472497
498+ // ===== impl ShardedWheel =====
499+
500+ impl ShardedWheel {
501+ /// Locks the driver's sharded wheel structure.
502+ pub ( super ) fn lock_sharded_wheel (
503+ & self ,
504+ shard_id : u32 ,
505+ ) -> crate :: loom:: sync:: MutexGuard < ' _ , Wheel > {
506+ let index = shard_id % ( self . 0 . len ( ) as u32 ) ;
507+ // Safety: This modulo operation ensures that the index is not out of bounds.
508+ unsafe { self . 0 . get_unchecked ( index as usize ) } . lock ( )
509+ }
510+ }
511+
473512#[ cfg( test) ]
474513mod tests;
0 commit comments