Skip to content

Commit 970d880

Browse files
authored
task: drop the join waker of a task eagerly (#6986)
1 parent 4ca13e6 commit 970d880

File tree

6 files changed

+197
-20
lines changed

6 files changed

+197
-20
lines changed

spellcheck.dic

+1
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ unparks
280280
Unparks
281281
unreceived
282282
unsafety
283+
unsets
283284
Unsets
284285
unsynchronized
285286
untrusted

tokio/src/runtime/task/harness.rs

+37-4
Original file line numberDiff line numberDiff line change
@@ -284,9 +284,11 @@ where
284284
}
285285

286286
pub(super) fn drop_join_handle_slow(self) {
287-
// Try to unset `JOIN_INTEREST`. This must be done as a first step in
287+
// Try to unset `JOIN_INTEREST` and `JOIN_WAKER`. This must be done as a first step in
288288
// case the task concurrently completed.
289-
if self.state().unset_join_interested().is_err() {
289+
let transition = self.state().transition_to_join_handle_dropped();
290+
291+
if transition.drop_output {
290292
// It is our responsibility to drop the output. This is critical as
291293
// the task output may not be `Send` and as such must remain with
292294
// the scheduler or `JoinHandle`. i.e. if the output remains in the
@@ -301,6 +303,23 @@ where
301303
}));
302304
}
303305

306+
if transition.drop_waker {
307+
// If the JOIN_WAKER flag is unset at this point, the task is either
308+
// already terminal or not complete so the `JoinHandle` is responsible
309+
// for dropping the waker.
310+
// SAFETY:
311+
// If the JOIN_WAKER bit is not set the join handle has exclusive
312+
// access to the waker as per rule 2 in task/mod.rs.
313+
// This can only be the case at this point in two scenarios:
314+
// 1. The task completed and the runtime unset `JOIN_WAKER` flag
315+
// after accessing the waker during task completion. So the
316+
// `JoinHandle` is the only one to access the join waker here.
317+
// 2. The task is not completed so the `JoinHandle` was able to unset
318+
// `JOIN_WAKER` bit itself to get mutable access to the waker.
319+
// The runtime will not access the waker when this flag is unset.
320+
unsafe { self.trailer().set_waker(None) };
321+
}
322+
304323
// Drop the `JoinHandle` reference, possibly deallocating the task
305324
self.drop_reference();
306325
}
@@ -311,7 +330,6 @@ where
311330
fn complete(self) {
312331
// The future has completed and its output has been written to the task
313332
// stage. We transition from running to complete.
314-
315333
let snapshot = self.state().transition_to_complete();
316334

317335
// We catch panics here in case dropping the future or waking the
@@ -320,13 +338,28 @@ where
320338
if !snapshot.is_join_interested() {
321339
// The `JoinHandle` is not interested in the output of
322340
// this task. It is our responsibility to drop the
323-
// output.
341+
// output. The join waker was already dropped by the
342+
// `JoinHandle` before.
324343
self.core().drop_future_or_output();
325344
} else if snapshot.is_join_waker_set() {
326345
// Notify the waker. Reading the waker field is safe per rule 4
327346
// in task/mod.rs, since the JOIN_WAKER bit is set and the call
328347
// to transition_to_complete() above set the COMPLETE bit.
329348
self.trailer().wake_join();
349+
350+
// Inform the `JoinHandle` that we are done waking the waker by
351+
// unsetting the `JOIN_WAKER` bit. If the `JoinHandle` has
352+
// already been dropped and `JOIN_INTEREST` is unset, then we must
353+
// drop the waker ourselves.
354+
if !self
355+
.state()
356+
.unset_waker_after_complete()
357+
.is_join_interested()
358+
{
359+
// SAFETY: We have COMPLETE=1 and JOIN_INTEREST=0, so
360+
// we have exclusive access to the waker.
361+
unsafe { self.trailer().set_waker(None) };
362+
}
330363
}
331364
}));
332365

tokio/src/runtime/task/mod.rs

+16-2
Original file line numberDiff line numberDiff line change
@@ -94,16 +94,30 @@
9494
//! `JoinHandle` needs to (i) successfully set `JOIN_WAKER` to zero if it is
9595
//! not already zero to gain exclusive access to the waker field per rule
9696
//! 2, (ii) write a waker, and (iii) successfully set `JOIN_WAKER` to one.
97+
//! If the `JoinHandle` unsets `JOIN_WAKER` in the process of being dropped
98+
//! to clear the waker field, only steps (i) and (ii) are relevant.
9799
//!
98100
//! 6. The `JoinHandle` can change `JOIN_WAKER` only if COMPLETE is zero (i.e.
99-
//! the task hasn't yet completed).
101+
//! the task hasn't yet completed). The runtime can change `JOIN_WAKER` only
102+
//! if COMPLETE is one.
103+
//!
104+
//! 7. If `JOIN_INTEREST` is zero and COMPLETE is one, then the runtime has
105+
//! exclusive (mutable) access to the waker field. This might happen if the
106+
//! `JoinHandle` gets dropped right after the task completes and the runtime
107+
//! sets the `COMPLETE` bit. In this case the runtime needs the mutable access
108+
//! to the waker field to drop it.
100109
//!
101110
//! Rule 6 implies that the steps (i) or (iii) of rule 5 may fail due to a
102111
//! race. If step (i) fails, then the attempt to write a waker is aborted. If
103112
//! step (iii) fails because COMPLETE is set to one by another thread after
104113
//! step (i), then the waker field is cleared. Once COMPLETE is one (i.e.
105114
//! task has completed), the `JoinHandle` will not modify `JOIN_WAKER`. After the
106-
//! runtime sets COMPLETE to one, it invokes the waker if there is one.
115+
//! runtime sets COMPLETE to one, it invokes the waker if there is one, so in this
116+
//! case when a task completes the `JOIN_WAKER` bit indicates to the runtime
117+
//! whether it should invoke the waker or not. After the runtime is done with
118+
//! using the waker during task completion, it unsets the `JOIN_WAKER` bit to give
119+
//! the `JoinHandle` exclusive access again so that it is able to drop the waker
120+
//! at a later point.
107121
//!
108122
//! All other fields are immutable and can be accessed immutably without
109123
//! synchronization by anyone.

tokio/src/runtime/task/state.rs

+51-12
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,12 @@ pub(crate) enum TransitionToNotifiedByRef {
8989
Submit,
9090
}
9191

92+
#[must_use]
93+
pub(super) struct TransitionToJoinHandleDrop {
94+
pub(super) drop_waker: bool,
95+
pub(super) drop_output: bool,
96+
}
97+
9298
/// All transitions are performed via RMW operations. This establishes an
9399
/// unambiguous modification order.
94100
impl State {
@@ -371,22 +377,45 @@ impl State {
371377
.map_err(|_| ())
372378
}
373379

374-
/// Tries to unset the `JOIN_INTEREST` flag.
375-
///
376-
/// Returns `Ok` if the operation happens before the task transitions to a
377-
/// completed state, `Err` otherwise.
378-
pub(super) fn unset_join_interested(&self) -> UpdateResult {
379-
self.fetch_update(|curr| {
380-
assert!(curr.is_join_interested());
380+
/// Unsets the `JOIN_INTEREST` flag. If `COMPLETE` is not set, the `JOIN_WAKER`
381+
/// flag is also unset.
382+
/// The returned `TransitionToJoinHandleDrop` indicates whether the `JoinHandle` should drop
383+
/// the output of the future or the join waker after the transition.
384+
pub(super) fn transition_to_join_handle_dropped(&self) -> TransitionToJoinHandleDrop {
385+
self.fetch_update_action(|mut snapshot| {
386+
assert!(snapshot.is_join_interested());
381387

382-
if curr.is_complete() {
383-
return None;
388+
let mut transition = TransitionToJoinHandleDrop {
389+
drop_waker: false,
390+
drop_output: false,
391+
};
392+
393+
snapshot.unset_join_interested();
394+
395+
if !snapshot.is_complete() {
396+
// If `COMPLETE` is unset we also unset `JOIN_WAKER` to give the
397+
// `JoinHandle` exclusive access to the waker following rule 6 in task/mod.rs.
398+
// The `JoinHandle` will drop the waker if it has exclusive access
399+
// to drop it.
400+
snapshot.unset_join_waker();
401+
} else {
402+
// If `COMPLETE` is set the task is completed so the `JoinHandle` is responsible
403+
// for dropping the output.
404+
transition.drop_output = true;
384405
}
385406

386-
let mut next = curr;
387-
next.unset_join_interested();
407+
if !snapshot.is_join_waker_set() {
408+
// If the `JOIN_WAKER` bit is unset, the `JoinHandle` has exclusive access to
409+
// the join waker and should drop it following this transition.
410+
// This might happen in two situations:
411+
// 1. The task is not completed and we just unset the `JOIN_WAKER` above in this
412+
// function.
413+
// 2. The task is completed. In that case the `JOIN_WAKER` bit was already unset
414+
// by the runtime during completion.
415+
transition.drop_waker = true;
416+
}
388417

389-
Some(next)
418+
(transition, Some(snapshot))
390419
})
391420
}
392421

@@ -430,6 +459,16 @@ impl State {
430459
})
431460
}
432461

462+
/// Unsets the `JOIN_WAKER` bit unconditionally after task completion.
463+
///
464+
/// This operation requires the task to be completed.
465+
pub(super) fn unset_waker_after_complete(&self) -> Snapshot {
466+
let prev = Snapshot(self.val.fetch_and(!JOIN_WAKER, AcqRel));
467+
assert!(prev.is_complete());
468+
assert!(prev.is_join_waker_set());
469+
Snapshot(prev.0 & !JOIN_WAKER)
470+
}
471+
433472
pub(super) fn ref_inc(&self) {
434473
use std::process;
435474
use std::sync::atomic::Ordering::Relaxed;

tokio/src/runtime/tests/loom_current_thread.rs

+56-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
mod yield_now;
22

3-
use crate::loom::sync::atomic::AtomicUsize;
3+
use crate::loom::sync::atomic::{AtomicUsize, Ordering};
44
use crate::loom::sync::Arc;
55
use crate::loom::thread;
66
use crate::runtime::{Builder, Runtime};
@@ -9,7 +9,7 @@ use crate::task;
99
use std::future::Future;
1010
use std::pin::Pin;
1111
use std::sync::atomic::Ordering::{Acquire, Release};
12-
use std::task::{Context, Poll};
12+
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
1313

1414
fn assert_at_most_num_polls(rt: Arc<Runtime>, at_most_polls: usize) {
1515
let (tx, rx) = oneshot::channel();
@@ -106,6 +106,60 @@ fn assert_no_unnecessary_polls() {
106106
});
107107
}
108108

109+
#[test]
110+
fn drop_jh_during_schedule() {
111+
unsafe fn waker_clone(ptr: *const ()) -> RawWaker {
112+
let atomic = unsafe { &*(ptr as *const AtomicUsize) };
113+
atomic.fetch_add(1, Ordering::Relaxed);
114+
RawWaker::new(ptr, &VTABLE)
115+
}
116+
unsafe fn waker_drop(ptr: *const ()) {
117+
let atomic = unsafe { &*(ptr as *const AtomicUsize) };
118+
atomic.fetch_sub(1, Ordering::Relaxed);
119+
}
120+
unsafe fn waker_nop(_ptr: *const ()) {}
121+
122+
static VTABLE: RawWakerVTable =
123+
RawWakerVTable::new(waker_clone, waker_drop, waker_nop, waker_drop);
124+
125+
loom::model(|| {
126+
let rt = Builder::new_current_thread().build().unwrap();
127+
128+
let mut jh = rt.spawn(async {});
129+
// Using AbortHandle to increment task refcount. This ensures that the waker is not
130+
// destroyed due to the refcount hitting zero.
131+
let task_refcnt = jh.abort_handle();
132+
133+
let waker_refcnt = AtomicUsize::new(1);
134+
{
135+
// Set up the join waker.
136+
use std::future::Future;
137+
use std::pin::Pin;
138+
139+
// SAFETY: Before `waker_refcnt` goes out of scope, this test asserts that the refcnt
140+
// has dropped to zero.
141+
let join_waker = unsafe {
142+
Waker::from_raw(RawWaker::new(
143+
(&waker_refcnt) as *const AtomicUsize as *const (),
144+
&VTABLE,
145+
))
146+
};
147+
148+
assert!(Pin::new(&mut jh)
149+
.poll(&mut Context::from_waker(&join_waker))
150+
.is_pending());
151+
}
152+
assert_eq!(waker_refcnt.load(Ordering::Relaxed), 1);
153+
154+
let bg_thread = loom::thread::spawn(move || drop(jh));
155+
rt.block_on(crate::task::yield_now());
156+
bg_thread.join().unwrap();
157+
158+
assert_eq!(waker_refcnt.load(Ordering::Relaxed), 0);
159+
drop(task_refcnt);
160+
});
161+
}
162+
109163
struct BlockedFuture {
110164
rx: Receiver<()>,
111165
num_polls: Arc<AtomicUsize>,

tokio/tests/rt_handle.rs

+36
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
#![warn(rust_2018_idioms)]
33
#![cfg(feature = "full")]
44

5+
use std::sync::Arc;
56
use tokio::runtime::Runtime;
7+
use tokio::sync::{mpsc, Barrier};
68

79
#[test]
810
#[cfg_attr(panic = "abort", ignore)]
@@ -65,6 +67,40 @@ fn interleave_then_enter() {
6567
let _enter = rt3.enter();
6668
}
6769

70+
// If the cycle causes a leak, then miri will catch it.
71+
#[test]
72+
fn drop_tasks_with_reference_cycle() {
73+
rt().block_on(async {
74+
let (tx, mut rx) = mpsc::channel(1);
75+
76+
let barrier = Arc::new(Barrier::new(3));
77+
let barrier_a = barrier.clone();
78+
let barrier_b = barrier.clone();
79+
80+
let a = tokio::spawn(async move {
81+
let b = rx.recv().await.unwrap();
82+
83+
// Poll the JoinHandle once. This registers the waker.
84+
// The other task cannot have finished at this point due to the barrier below.
85+
futures::future::select(b, std::future::ready(())).await;
86+
87+
barrier_a.wait().await;
88+
});
89+
90+
let b = tokio::spawn(async move {
91+
// Poll the JoinHandle once. This registers the waker.
92+
// The other task cannot have finished at this point due to the barrier below.
93+
futures::future::select(a, std::future::ready(())).await;
94+
95+
barrier_b.wait().await;
96+
});
97+
98+
tx.send(b).await.unwrap();
99+
100+
barrier.wait().await;
101+
});
102+
}
103+
68104
#[cfg(tokio_unstable)]
69105
mod unstable {
70106
use super::*;

0 commit comments

Comments
 (0)