Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 45 additions & 11 deletions tokio-test/src/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@
//! ```

use std::future::Future;
use std::mem;
use std::ops;
use std::pin::Pin;
use std::sync::{Arc, Condvar, Mutex};
use std::task::{Context, Poll, Wake, Waker};
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};

use tokio_stream::Stream;

Expand Down Expand Up @@ -170,7 +171,7 @@ impl MockTask {
F: FnOnce(&mut Context<'_>) -> R,
{
self.waker.clear();
let waker = self.clone().into_waker();
let waker = self.waker();
let mut cx = Context::from_waker(&waker);

f(&mut cx)
Expand All @@ -189,8 +190,11 @@ impl MockTask {
Arc::strong_count(&self.waker)
}

fn into_waker(self) -> Waker {
self.waker.into()
fn waker(&self) -> Waker {
unsafe {
let raw = to_raw(self.waker.clone());
Waker::from_raw(raw)
}
}
}

Expand Down Expand Up @@ -222,14 +226,8 @@ impl ThreadWaker {
_ => unreachable!(),
}
}
}

impl Wake for ThreadWaker {
fn wake(self: Arc<Self>) {
self.wake_by_ref();
}

fn wake_by_ref(self: &Arc<Self>) {
fn wake(&self) {
// First, try transitioning from IDLE -> NOTIFY, this does not require a lock.
let mut state = self.state.lock().unwrap();
let prev = *state;
Expand All @@ -249,3 +247,39 @@ impl Wake for ThreadWaker {
self.condvar.notify_one();
}
}

static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker);

unsafe fn to_raw(waker: Arc<ThreadWaker>) -> RawWaker {
RawWaker::new(Arc::into_raw(waker) as *const (), &VTABLE)
}

unsafe fn from_raw(raw: *const ()) -> Arc<ThreadWaker> {
Arc::from_raw(raw as *const ThreadWaker)
}

unsafe fn clone(raw: *const ()) -> RawWaker {
let waker = from_raw(raw);

// Increment the ref count
mem::forget(waker.clone());

to_raw(waker)
}

unsafe fn wake(raw: *const ()) {
let waker = from_raw(raw);
waker.wake();
}

unsafe fn wake_by_ref(raw: *const ()) {
let waker = from_raw(raw);
waker.wake();

// We don't actually own a reference to the unparker
mem::forget(waker);
}

unsafe fn drop_waker(raw: *const ()) {
let _ = from_raw(raw);
}
45 changes: 37 additions & 8 deletions tokio/src/runtime/park.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::{Arc, Condvar, Mutex};
use crate::util::{waker, Wake};

use std::sync::atomic::Ordering::SeqCst;
use std::time::Duration;
Expand Down Expand Up @@ -227,7 +226,7 @@ use crate::loom::thread::AccessError;
use std::future::Future;
use std::marker::PhantomData;
use std::rc::Rc;
use std::task::Waker;
use std::task::{RawWaker, RawWakerVTable, Waker};

/// Blocks the current thread using a condition variable.
#[derive(Debug)]
Expand Down Expand Up @@ -293,20 +292,50 @@ impl CachedParkThread {

impl UnparkThread {
pub(crate) fn into_waker(self) -> Waker {
waker(self.inner)
unsafe {
let raw = unparker_to_raw_waker(self.inner);
Waker::from_raw(raw)
}
}
}

impl Wake for Inner {
fn wake(arc_self: Arc<Self>) {
arc_self.unpark();
impl Inner {
#[allow(clippy::wrong_self_convention)]
fn into_raw(this: Arc<Inner>) -> *const () {
Arc::into_raw(this) as *const ()
}

fn wake_by_ref(arc_self: &Arc<Self>) {
arc_self.unpark();
unsafe fn from_raw(ptr: *const ()) -> Arc<Inner> {
Arc::from_raw(ptr as *const Inner)
}
}

unsafe fn unparker_to_raw_waker(unparker: Arc<Inner>) -> RawWaker {
RawWaker::new(
Inner::into_raw(unparker),
&RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker),
)
}

unsafe fn clone(raw: *const ()) -> RawWaker {
Arc::increment_strong_count(raw as *const Inner);
unparker_to_raw_waker(Inner::from_raw(raw))
}

unsafe fn drop_waker(raw: *const ()) {
drop(Inner::from_raw(raw));
}

unsafe fn wake(raw: *const ()) {
let unparker = Inner::from_raw(raw);
unparker.unpark();
}

unsafe fn wake_by_ref(raw: *const ()) {
let raw = raw as *const Inner;
(*raw).unpark();
}

#[cfg(loom)]
pub(crate) fn current_thread_park_count() -> usize {
CURRENT_THREAD_PARK_COUNT.with(|count| count.load(SeqCst))
Expand Down
7 changes: 3 additions & 4 deletions tokio/src/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@ pub(crate) use blocking_check::check_socket_for_blocking;

pub(crate) mod metric_atomics;

mod wake;
pub(crate) use wake::{waker, Wake};

#[cfg(any(
// io driver uses `WakeList` directly
feature = "net",
Expand Down Expand Up @@ -70,7 +67,9 @@ cfg_rt! {

pub(crate) use self::rand::RngSeedGenerator;

pub(crate) use wake::{waker_ref, WakerRef};
mod wake;
pub(crate) use wake::WakerRef;
pub(crate) use wake::{waker_ref, Wake};

mod sync_wrapper;
pub(crate) use sync_wrapper::SyncWrapper;
Expand Down
53 changes: 20 additions & 33 deletions tokio/src/util/wake.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::loom::sync::Arc;

use std::marker::PhantomData;
use std::mem::ManuallyDrop;
use std::ops::Deref;
use std::task::{RawWaker, RawWakerVTable, Waker};

/// Simplified waking interface based on Arcs.
Expand All @@ -12,45 +14,30 @@ pub(crate) trait Wake: Send + Sync + Sized + 'static {
fn wake_by_ref(arc_self: &Arc<Self>);
}

cfg_rt! {
use std::marker::PhantomData;
use std::ops::Deref;

/// A `Waker` that is only valid for a given lifetime.
#[derive(Debug)]
pub(crate) struct WakerRef<'a> {
waker: ManuallyDrop<Waker>,
_p: PhantomData<&'a ()>,
}
/// A `Waker` that is only valid for a given lifetime.
#[derive(Debug)]
pub(crate) struct WakerRef<'a> {
waker: ManuallyDrop<Waker>,
_p: PhantomData<&'a ()>,
}

impl Deref for WakerRef<'_> {
type Target = Waker;
impl Deref for WakerRef<'_> {
type Target = Waker;

fn deref(&self) -> &Waker {
&self.waker
}
fn deref(&self) -> &Waker {
&self.waker
}
}

/// Creates a reference to a `Waker` from a reference to `Arc<impl Wake>`.
pub(crate) fn waker_ref<W: Wake>(wake: &Arc<W>) -> WakerRef<'_> {
let ptr = Arc::as_ptr(wake).cast::<()>();

let waker = unsafe { Waker::from_raw(RawWaker::new(ptr, waker_vtable::<W>())) };
/// Creates a reference to a `Waker` from a reference to `Arc<impl Wake>`.
pub(crate) fn waker_ref<W: Wake>(wake: &Arc<W>) -> WakerRef<'_> {
let ptr = Arc::as_ptr(wake).cast::<()>();

WakerRef {
waker: ManuallyDrop::new(waker),
_p: PhantomData,
}
}
}
let waker = unsafe { Waker::from_raw(RawWaker::new(ptr, waker_vtable::<W>())) };

/// Creates a waker from a `Arc<impl Wake>`.
pub(crate) fn waker<W: Wake>(wake: Arc<W>) -> Waker {
unsafe {
Waker::from_raw(RawWaker::new(
Arc::into_raw(wake).cast(),
waker_vtable::<W>(),
))
WakerRef {
waker: ManuallyDrop::new(waker),
_p: PhantomData,
}
}

Expand Down