@@ -28,6 +28,7 @@ use crate::util::WakeList;
28
28
29
29
use crate :: loom:: sync:: atomic:: AtomicU64 ;
30
30
use std:: fmt;
31
+ use std:: sync:: RwLock ;
31
32
use std:: { num:: NonZeroU64 , ptr:: NonNull } ;
32
33
33
34
struct AtomicOptionNonZeroU64 ( AtomicU64 ) ;
@@ -115,7 +116,7 @@ struct Inner {
115
116
next_wake : AtomicOptionNonZeroU64 ,
116
117
117
118
/// Sharded Timer wheels.
118
- wheels : Box < [ Mutex < wheel :: Wheel > ] > ,
119
+ wheels : RwLock < ShardedWheel > ,
119
120
120
121
/// True if the driver is being shutdown.
121
122
pub ( super ) is_shutdown : AtomicBool ,
@@ -130,6 +131,9 @@ struct Inner {
130
131
did_wake : AtomicBool ,
131
132
}
132
133
134
+ /// Wrapper around the sharded timer wheels.
135
+ struct ShardedWheel ( Box < [ Mutex < wheel:: Wheel > ] > ) ;
136
+
133
137
// ===== impl Driver =====
134
138
135
139
impl Driver {
@@ -149,7 +153,7 @@ impl Driver {
149
153
time_source,
150
154
inner : Inner {
151
155
next_wake : AtomicOptionNonZeroU64 :: new ( None ) ,
152
- wheels : wheels. into_boxed_slice ( ) ,
156
+ wheels : RwLock :: new ( ShardedWheel ( wheels. into_boxed_slice ( ) ) ) ,
153
157
is_shutdown : AtomicBool :: new ( false ) ,
154
158
#[ cfg( feature = "test-util" ) ]
155
159
did_wake : AtomicBool :: new ( false ) ,
@@ -190,18 +194,23 @@ impl Driver {
190
194
assert ! ( !handle. is_shutdown( ) ) ;
191
195
192
196
// Finds out the min expiration time to park.
193
- let expiration_time = ( 0 ..rt_handle. time ( ) . inner . get_shard_size ( ) )
194
- . filter_map ( |id| {
195
- let lock = rt_handle. time ( ) . inner . lock_sharded_wheel ( id) ;
196
- lock. next_expiration_time ( )
197
- } )
198
- . min ( ) ;
199
-
200
- rt_handle
201
- . time ( )
202
- . inner
203
- . next_wake
204
- . store ( next_wake_time ( expiration_time) ) ;
197
+ let expiration_time = {
198
+ let mut wheels_lock = rt_handle. time ( ) . inner . wheels . write ( ) . expect ( "" ) ;
199
+ let expiration_time = ( 0 ..wheels_lock. get_shard_size ( ) )
200
+ . filter_map ( |id| {
201
+ let wheel = wheels_lock. get_sharded_wheel ( id) ;
202
+ wheel. next_expiration_time ( )
203
+ } )
204
+ . min ( ) ;
205
+
206
+ rt_handle
207
+ . time ( )
208
+ . inner
209
+ . next_wake
210
+ . store ( next_wake_time ( expiration_time) ) ;
211
+
212
+ expiration_time
213
+ } ;
205
214
206
215
match expiration_time {
207
216
Some ( when) => {
@@ -307,7 +316,8 @@ impl Handle {
307
316
// Returns the next wakeup time of this shard.
308
317
pub ( self ) fn process_at_sharded_time ( & self , id : u32 , mut now : u64 ) -> Option < u64 > {
309
318
let mut waker_list = WakeList :: new ( ) ;
310
- let mut lock = self . inner . lock_sharded_wheel ( id) ;
319
+ let wheels_lock = self . inner . wheels . read ( ) . expect ( "" ) ;
320
+ let mut lock = wheels_lock. lock_sharded_wheel ( id) ;
311
321
312
322
if now < lock. elapsed ( ) {
313
323
// Time went backwards! This normally shouldn't happen as the Rust language
@@ -332,7 +342,7 @@ impl Handle {
332
342
333
343
waker_list. wake_all ( ) ;
334
344
335
- lock = self . inner . lock_sharded_wheel ( id) ;
345
+ lock = wheels_lock . lock_sharded_wheel ( id) ;
336
346
}
337
347
}
338
348
}
@@ -355,7 +365,8 @@ impl Handle {
355
365
/// `add_entry` must not be called concurrently.
356
366
pub ( self ) unsafe fn clear_entry ( & self , entry : NonNull < TimerShared > ) {
357
367
unsafe {
358
- let mut lock = self . inner . lock_sharded_wheel ( entry. as_ref ( ) . shard_id ( ) ) ;
368
+ let wheels_lock = self . inner . wheels . read ( ) . expect ( "" ) ;
369
+ let mut lock = wheels_lock. lock_sharded_wheel ( entry. as_ref ( ) . shard_id ( ) ) ;
359
370
360
371
if entry. as_ref ( ) . might_be_registered ( ) {
361
372
lock. remove ( entry) ;
@@ -378,7 +389,8 @@ impl Handle {
378
389
entry : NonNull < TimerShared > ,
379
390
) {
380
391
let waker = unsafe {
381
- let mut lock = self . inner . lock_sharded_wheel ( entry. as_ref ( ) . shard_id ( ) ) ;
392
+ let wheels_lock = self . inner . wheels . read ( ) . expect ( "" ) ;
393
+ let mut lock = wheels_lock. lock_sharded_wheel ( entry. as_ref ( ) . shard_id ( ) ) ;
382
394
383
395
// We may have raced with a firing/deregistration, so check before
384
396
// deregistering.
@@ -438,24 +450,14 @@ impl Handle {
438
450
// ===== impl Inner =====
439
451
440
452
impl Inner {
441
- /// Locks the driver's sharded wheel structure.
442
- pub ( super ) fn lock_sharded_wheel (
443
- & self ,
444
- shard_id : u32 ,
445
- ) -> crate :: loom:: sync:: MutexGuard < ' _ , Wheel > {
446
- let index = shard_id % ( self . wheels . len ( ) as u32 ) ;
447
- // Safety: This modulo operation ensures that the index is not out of bounds.
448
- unsafe { self . wheels . get_unchecked ( index as usize ) . lock ( ) }
449
- }
450
-
451
453
// Check whether the driver has been shutdown
452
454
pub ( super ) fn is_shutdown ( & self ) -> bool {
453
455
self . is_shutdown . load ( Ordering :: SeqCst )
454
456
}
455
457
456
458
// Gets the number of shards.
457
459
fn get_shard_size ( & self ) -> u32 {
458
- self . wheels . len ( ) as u32
460
+ self . wheels . read ( ) . expect ( "" ) . get_shard_size ( )
459
461
}
460
462
}
461
463
@@ -465,5 +467,31 @@ impl fmt::Debug for Inner {
465
467
}
466
468
}
467
469
470
+ // ===== impl ShardedWheel =====
471
+
472
+ impl ShardedWheel {
473
+ /// Locks the driver's sharded wheel structure.
474
+ pub ( super ) fn lock_sharded_wheel (
475
+ & self ,
476
+ shard_id : u32 ,
477
+ ) -> crate :: loom:: sync:: MutexGuard < ' _ , Wheel > {
478
+ let index = shard_id % ( self . 0 . len ( ) as u32 ) ;
479
+ // Safety: This modulo operation ensures that the index is not out of bounds.
480
+ unsafe { self . 0 . get_unchecked ( index as usize ) } . lock ( )
481
+ }
482
+
483
+ /// Gets a mutable reference to the sharded wheel with the given id.
484
+ pub ( super ) fn get_sharded_wheel ( & mut self , shard_id : u32 ) -> & mut wheel:: Wheel {
485
+ let index = shard_id % ( self . 0 . len ( ) as u32 ) ;
486
+ // Safety: This modulo operation ensures that the index is not out of bounds.
487
+ unsafe { self . 0 . get_unchecked_mut ( index as usize ) } . get_mut ( )
488
+ }
489
+
490
+ /// Gets the number of shards.
491
+ fn get_shard_size ( & self ) -> u32 {
492
+ self . 0 . len ( ) as u32
493
+ }
494
+ }
495
+
468
496
#[ cfg( test) ]
469
497
mod tests;
0 commit comments