Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

task: wake local tasks to the local queue when woken by the same thread #5095

Merged
merged 22 commits into from
Oct 13, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 172 additions & 22 deletions tokio/src/task/local.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Runs `!Send` futures on the current thread.
use crate::loom::sync::{Arc, Mutex};
use crate::loom::thread::{self, ThreadId};
use crate::runtime::task::{self, JoinHandle, LocalOwnedTasks, Task};
use crate::sync::AtomicWaker;
use crate::util::{RcCell, VecDequeCell};
Expand Down Expand Up @@ -228,9 +229,6 @@ struct Context {
/// Collection of all active tasks spawned onto this executor.
owned: LocalOwnedTasks<Arc<Shared>>,

/// Local run queue sender and receiver.
queue: VecDequeCell<task::Notified<Arc<Shared>>>,

/// State shared between threads.
shared: Arc<Shared>,

Expand All @@ -241,6 +239,19 @@ struct Context {

/// LocalSet state shared between threads.
struct Shared {
/// Local run queue sender and receiver.
///
/// # Safety
///
/// This field must *only* be accessed from the thread that owns the
/// `LocalSet` (i.e., `Thread::current().id() == owner`).
local_queue: VecDequeCell<task::Notified<Arc<Shared>>>,

/// The `ThreadId` of the thread that owns the `LocalSet`.
///
/// Since `LocalSet` is `!Send`, this will never change.
owner: ThreadId,

/// Remote run queue sender.
queue: Mutex<Option<VecDeque<task::Notified<Arc<Shared>>>>>,

Expand All @@ -262,10 +273,21 @@ pin_project! {
}

#[cfg(any(loom, tokio_no_const_thread_local))]
thread_local!(static CURRENT: RcCell<Context> = RcCell::new());
thread_local!(static CURRENT: LocalData = LocalData {
thread_id: Cell::new(None),
ctx: RcCell::new(),
});

#[cfg(not(any(loom, tokio_no_const_thread_local)))]
thread_local!(static CURRENT: RcCell<Context> = const { RcCell::new() });
thread_local!(static CURRENT: LocalData = const { LocalData {
thread_id: Cell::new(None),
ctx: RcCell::new(),
} });

struct LocalData {
thread_id: Cell<Option<ThreadId>>,
ctx: RcCell<Context>,
}

cfg_rt! {
/// Spawns a `!Send` future on the local task set.
Expand Down Expand Up @@ -314,7 +336,7 @@ cfg_rt! {
where F: Future + 'static,
F::Output: 'static
{
match CURRENT.with(|maybe_cx| maybe_cx.get()) {
match CURRENT.with(|LocalData { ctx, .. }| ctx.get()) {
None => panic!("`spawn_local` called from outside of a `task::LocalSet`"),
Some(cx) => cx.spawn(future, name)
}
Expand All @@ -335,7 +357,7 @@ pub struct LocalEnterGuard(Option<Rc<Context>>);

impl Drop for LocalEnterGuard {
fn drop(&mut self) {
CURRENT.with(|ctx| {
CURRENT.with(|LocalData { ctx, .. }| {
ctx.set(self.0.take());
})
}
Expand All @@ -354,8 +376,9 @@ impl LocalSet {
tick: Cell::new(0),
context: Rc::new(Context {
owned: LocalOwnedTasks::new(),
queue: VecDequeCell::with_capacity(INITIAL_CAPACITY),
shared: Arc::new(Shared {
local_queue: VecDequeCell::with_capacity(INITIAL_CAPACITY),
owner: thread::current().id(),
queue: Mutex::new(Some(VecDeque::with_capacity(INITIAL_CAPACITY))),
waker: AtomicWaker::new(),
#[cfg(tokio_unstable)]
Expand All @@ -374,7 +397,7 @@ impl LocalSet {
///
/// [`spawn_local`]: fn@crate::task::spawn_local
pub fn enter(&self) -> LocalEnterGuard {
CURRENT.with(|ctx| {
CURRENT.with(|LocalData { ctx, .. }| {
let old = ctx.replace(Some(self.context.clone()));
LocalEnterGuard(old)
})
Expand Down Expand Up @@ -597,9 +620,9 @@ impl LocalSet {
.lock()
.as_mut()
.and_then(|queue| queue.pop_front())
.or_else(|| self.context.queue.pop_front())
.or_else(|| self.pop_local())
} else {
self.context.queue.pop_front().or_else(|| {
self.pop_local().or_else(|| {
self.context
.shared
.queue
Expand All @@ -612,8 +635,17 @@ impl LocalSet {
task.map(|task| self.context.owned.assert_owner(task))
}

fn pop_local(&self) -> Option<task::Notified<Arc<Shared>>> {
unsafe {
// Safety: because the `LocalSet` itself is `!Send`, we know we are
// on the same thread if we have access to the `LocalSet`, and can
// therefore access the local run queue.
self.context.shared.local_queue().pop_front()
}
}

fn with<T>(&self, f: impl FnOnce() -> T) -> T {
CURRENT.with(|ctx| {
CURRENT.with(|LocalData { ctx, .. }| {
struct Reset<'a> {
ctx_ref: &'a RcCell<Context>,
val: Option<Rc<Context>>,
Expand All @@ -639,7 +671,7 @@ impl LocalSet {
fn with_if_possible<T>(&self, f: impl FnOnce() -> T) -> T {
let mut f = Some(f);

let res = CURRENT.try_with(|ctx| {
let res = CURRENT.try_with(|LocalData { ctx, .. }| {
struct Reset<'a> {
ctx_ref: &'a RcCell<Context>,
val: Option<Rc<Context>>,
Expand Down Expand Up @@ -782,7 +814,21 @@ impl Drop for LocalSet {

// We already called shutdown on all tasks above, so there is no
// need to call shutdown.
for task in self.context.queue.take() {

// Safety: note that this *intentionally* bypasses the unsafe
// `Shared::local_queue()` method. This is in order to avoid the
// debug assertion that we are on the thread that owns the
// `LocalSet`, because on some systems (e.g. at least some macOS
// versions), attempting to get the current thread ID can panic due
// to the thread's local data that stores the thread ID being
// dropped *before* the `LocalSet`.
//
// Despite avoiding the assertion here, it is safe for us to access
// the local queue in `Drop`, because the `LocalSet` itself is
// `!Send`, so we can reasonably guarantee that it will not be
// `Drop`ped from another thread.
let local_queue = self.context.shared.local_queue.take();
for task in local_queue {
drop(task);
}

Expand Down Expand Up @@ -854,15 +900,48 @@ impl<T: Future> Future for RunUntil<'_, T> {
}

impl Shared {
/// # Safety
///
/// This is safe to call if and ONLY if we are on the thread that owns this
/// `LocalSet`.
unsafe fn local_queue(&self) -> &VecDequeCell<task::Notified<Arc<Self>>> {
debug_assert!(
// if we couldn't get the thread ID because we're dropping the local
// data, skip the assertion --- the `Drop` impl is not going to be
// called from another thread, because `LocalSet` is `!Send`
thread_id().map(|id| id == self.owner).unwrap_or(true),
"`LocalSet`'s local run queue must not be accessed by another thread!"
);
&self.local_queue
}

/// Schedule the provided task on the scheduler.
fn schedule(&self, task: task::Notified<Arc<Self>>) {
CURRENT.with(|maybe_cx| {
match maybe_cx.get() {
Some(cx) if cx.shared.ptr_eq(self) => {
cx.queue.push_back(task);
CURRENT.with(|localdata| {
match localdata.ctx.get() {
Some(cx) if cx.shared.ptr_eq(self) => unsafe {
// Safety: if the current `LocalSet` context points to this
// `LocalSet`, then we are on the thread that owns it.
cx.shared.local_queue().push_back(task);
},

// We are on the thread that owns the `LocalSet`, so we can
// wake to the local queue.
_ if localdata.get_or_insert_id() == self.owner => {
unsafe {
// Safety: we just checked that the thread ID matches
// the localset's owner, so this is safe.
self.local_queue().push_back(task);
}
// We still have to wake the `LocalSet`, because it isn't
// currently being polled.
self.waker.wake();
}

// We are *not* on the thread that owns the `LocalSet`, so we
// have to wake to the remote queue.
_ => {
// First check whether the queue is still there (if not, the
// First, check whether the queue is still there (if not, the
// LocalSet is dropped). Then push to it if so, and if not,
// do nothing.
let mut lock = self.queue.lock();
Expand All @@ -882,9 +961,13 @@ impl Shared {
}
}

// This is safe because (and only because) we *pinky pwomise* to never touch the
// local run queue except from the thread that owns the `LocalSet`.
unsafe impl Sync for Shared {}

impl task::Schedule for Arc<Shared> {
fn release(&self, task: &Task<Self>) -> Option<Task<Self>> {
CURRENT.with(|maybe_cx| match maybe_cx.get() {
CURRENT.with(|LocalData { ctx, .. }| match ctx.get() {
None => panic!("scheduler context missing"),
Some(cx) => {
assert!(cx.shared.ptr_eq(self));
Expand All @@ -909,7 +992,7 @@ impl task::Schedule for Arc<Shared> {
// This hook is only called from within the runtime, so
// `CURRENT` should match with `&self`, i.e. there is no
// opportunity for a nested scheduler to be called.
CURRENT.with(|maybe_cx| match maybe_cx.get() {
CURRENT.with(|LocalData { ctx, .. }| match ctx.get() {
Some(cx) if Arc::ptr_eq(self, &cx.shared) => {
cx.unhandled_panic.set(true);
cx.owned.close_and_shutdown_all();
Expand All @@ -922,9 +1005,31 @@ impl task::Schedule for Arc<Shared> {
}
}

#[cfg(test)]
impl LocalData {
fn get_or_insert_id(&self) -> ThreadId {
self.thread_id.get().unwrap_or_else(|| {
let id = thread::current().id();
self.thread_id.set(Some(id));
id
})
}
}

fn thread_id() -> Option<ThreadId> {
CURRENT
.try_with(|localdata| localdata.get_or_insert_id())
.ok()
}

#[cfg(all(test, not(loom)))]
mod tests {
use super::*;

// Does a `LocalSet` running on a current-thread runtime...basically work?
//
// This duplicates a test in `tests/task_local_set.rs`, but because this is
// a lib test, it wil run under Miri, so this is necessary to catch stacked
// borrows violations in the `LocalSet` implementation.
#[test]
fn local_current_thread_scheduler() {
let f = async {
Expand All @@ -939,4 +1044,49 @@ mod tests {
.expect("rt")
.block_on(f)
}

// Tests that when a task on a `LocalSet` is woken by an io driver on the
// same thread, the task is woken to the localset's local queue rather than
// its remote queue.
//
// This test has to be defined in the `local.rs` file as a lib test, rather
// than in `tests/`, because it makes assertions about the local set's
// internal state.
#[test]
fn wakes_to_local_queue() {
use super::*;
use crate::sync::Notify;
let rt = crate::runtime::Builder::new_current_thread()
.build()
.expect("rt");
rt.block_on(async {
let local = LocalSet::new();
let notify = Arc::new(Notify::new());
let task = local.spawn_local({
let notify = notify.clone();
async move {
notify.notified().await;
}
});
let mut run_until = Box::pin(local.run_until(async move {
task.await.unwrap();
}));

// poll the run until future once
crate::future::poll_fn(|cx| {
let _ = run_until.as_mut().poll(cx);
Poll::Ready(())
})
.await;

notify.notify_one();
let task = unsafe { local.context.shared.local_queue().pop_front() };
// TODO(eliza): it would be nice to be able to assert that this is
// the local task.
assert!(
task.is_some(),
"task should have been notified to the LocalSet's local queue"
);
})
}
}