@@ -9,7 +9,7 @@ use std::pin::Pin;
99use std:: ptr;
1010use std:: sync:: atomic:: AtomicUsize ;
1111use std:: sync:: atomic:: Ordering :: { Acquire , SeqCst } ;
12- use std:: sync:: { Arc , Mutex , Weak } ;
12+ use std:: sync:: { Arc , Mutex , MutexGuard , Weak } ;
1313
1414/// Future for the [`shared`](super::FutureExt::shared) method.
1515#[ must_use = "futures do nothing unless you `.await` or poll them" ]
@@ -81,6 +81,7 @@ const IDLE: usize = 0;
8181const POLLING : usize = 1 ;
8282const COMPLETE : usize = 2 ;
8383const POISONED : usize = 3 ;
84+ const WOKEN_DURING_POLLING : usize = 4 ;
8485
8586const NULL_WAKER_KEY : usize = usize:: MAX ;
8687
@@ -197,35 +198,43 @@ where
197198 }
198199}
199200
201+ /// Registers the current task to receive a wakeup when we are awoken.
202+ fn record_waker ( wakers_guard : & mut MutexGuard < ' _ , Option < Slab < Option < Waker > > > > , waker_key : & mut usize , cx : & mut Context < ' _ > ) {
203+ let wakers = match wakers_guard. as_mut ( ) {
204+ Some ( wakers) => wakers,
205+ None => return ,
206+ } ;
207+
208+ let new_waker = cx. waker ( ) ;
209+
210+ if * waker_key == NULL_WAKER_KEY {
211+ * waker_key = wakers. insert ( Some ( new_waker. clone ( ) ) ) ;
212+ } else {
213+ match wakers[ * waker_key] {
214+ Some ( ref old_waker) if new_waker. will_wake ( old_waker) => { }
215+ // Could use clone_from here, but Waker doesn't specialize it.
216+ ref mut slot => * slot = Some ( new_waker. clone ( ) ) ,
217+ }
218+ }
219+ debug_assert ! ( * waker_key != NULL_WAKER_KEY ) ;
220+ }
221+
222+ /// Wakes all tasks that are registered to be woken.
223+ fn wake_all ( waker_guard : & mut MutexGuard < ' _ , Option < Slab < Option < Waker > > > > ) {
224+ if let Some ( wakers) = waker_guard. as_mut ( ) {
225+ for ( _key, opt_waker) in wakers {
226+ if let Some ( waker) = opt_waker. take ( ) {
227+ waker. wake ( ) ;
228+ }
229+ }
230+ }
231+ }
232+
200233impl < Fut > Inner < Fut >
201234where
202235 Fut : Future ,
203236 Fut :: Output : Clone ,
204237{
205- /// Registers the current task to receive a wakeup when we are awoken.
206- fn record_waker ( & self , waker_key : & mut usize , cx : & mut Context < ' _ > ) {
207- let mut wakers_guard = self . notifier . wakers . lock ( ) . unwrap ( ) ;
208-
209- let wakers_mut = wakers_guard. as_mut ( ) ;
210-
211- let wakers = match wakers_mut {
212- Some ( wakers) => wakers,
213- None => return ,
214- } ;
215-
216- let new_waker = cx. waker ( ) ;
217-
218- if * waker_key == NULL_WAKER_KEY {
219- * waker_key = wakers. insert ( Some ( new_waker. clone ( ) ) ) ;
220- } else {
221- match wakers[ * waker_key] {
222- Some ( ref old_waker) if new_waker. will_wake ( old_waker) => { }
223- // Could use clone_from here, but Waker doesn't specialize it.
224- ref mut slot => * slot = Some ( new_waker. clone ( ) ) ,
225- }
226- }
227- debug_assert ! ( * waker_key != NULL_WAKER_KEY ) ;
228- }
229238
230239 /// Safety: callers must first ensure that `inner.state`
231240 /// is `COMPLETE`
@@ -268,18 +277,18 @@ where
268277 return unsafe { Poll :: Ready ( inner. take_or_clone_output ( ) ) } ;
269278 }
270279
271- inner. record_waker ( & mut this. waker_key , cx) ;
280+ // Guard the state transition with mutex too
281+ let mut wakers_guard = inner. notifier . wakers . lock ( ) . unwrap ( ) ;
282+ record_waker ( & mut wakers_guard, & mut this. waker_key , cx) ;
272283
273- match inner
274- . notifier
275- . state
276- . compare_exchange ( IDLE , POLLING , SeqCst , SeqCst )
277- . unwrap_or_else ( |x| x)
278- {
284+ let prev = inner. notifier . state . compare_exchange ( IDLE , POLLING , SeqCst , SeqCst ) . unwrap_or_else ( |x| x) ;
285+ drop ( wakers_guard) ;
286+
287+ match prev {
279288 IDLE => {
280289 // Lock acquired, fall through
281290 }
282- POLLING => {
291+ POLLING | WOKEN_DURING_POLLING => {
283292 // Another task is currently polling, at this point we just want
284293 // to ensure that the waker for this task is registered
285294 this. inner = Some ( inner) ;
@@ -324,15 +333,22 @@ where
324333
325334 match poll_result {
326335 Poll :: Pending => {
327- if inner. notifier . state . compare_exchange ( POLLING , IDLE , SeqCst , SeqCst ) . is_ok ( )
336+ match inner. notifier . state . compare_exchange ( POLLING , IDLE , SeqCst , SeqCst )
328337 {
329- // Success
330- drop ( reset) ;
331- this. inner = Some ( inner) ;
332- return Poll :: Pending ;
333- } else {
334- unreachable ! ( )
338+ Ok ( POLLING ) => { } // success
339+ Err ( WOKEN_DURING_POLLING ) => {
340+ // waker has been called inside future.poll, need to wake any new wakers registered
341+ let mut wakers = inner. notifier . wakers . lock ( ) . unwrap ( ) ;
342+ wake_all ( & mut wakers) ;
343+ let prev = inner. notifier . state . swap ( IDLE , SeqCst ) ;
344+ assert_eq ! ( prev, WOKEN_DURING_POLLING ) ;
345+ drop ( wakers) ;
346+ }
347+ _ => unreachable ! ( ) ,
335348 }
349+ drop ( reset) ;
350+ this. inner = Some ( inner) ;
351+ return Poll :: Pending ;
336352 }
337353 Poll :: Ready ( output) => output,
338354 }
@@ -387,14 +403,9 @@ where
387403
388404impl ArcWake for Notifier {
389405 fn wake_by_ref ( arc_self : & Arc < Self > ) {
390- let wakers = & mut * arc_self. wakers . lock ( ) . unwrap ( ) ;
391- if let Some ( wakers) = wakers. as_mut ( ) {
392- for ( _key, opt_waker) in wakers {
393- if let Some ( waker) = opt_waker. take ( ) {
394- waker. wake ( ) ;
395- }
396- }
397- }
406+ let mut wakers = arc_self. wakers . lock ( ) . unwrap ( ) ;
407+ let _ = arc_self. state . compare_exchange ( POLLING , WOKEN_DURING_POLLING , SeqCst , SeqCst ) ;
408+ wake_all ( & mut wakers) ;
398409 }
399410}
400411
0 commit comments