diff --git a/tokio-test/src/task.rs b/tokio-test/src/task.rs index c781d85ea91..2e646d44bf8 100644 --- a/tokio-test/src/task.rs +++ b/tokio-test/src/task.rs @@ -26,10 +26,11 @@ //! ``` use std::future::Future; +use std::mem; use std::ops; use std::pin::Pin; use std::sync::{Arc, Condvar, Mutex}; -use std::task::{Context, Poll, Wake, Waker}; +use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker}; use tokio_stream::Stream; @@ -170,7 +171,7 @@ impl MockTask { F: FnOnce(&mut Context<'_>) -> R, { self.waker.clear(); - let waker = self.clone().into_waker(); + let waker = self.waker(); let mut cx = Context::from_waker(&waker); f(&mut cx) @@ -189,8 +190,11 @@ impl MockTask { Arc::strong_count(&self.waker) } - fn into_waker(self) -> Waker { - self.waker.into() + fn waker(&self) -> Waker { + unsafe { + let raw = to_raw(self.waker.clone()); + Waker::from_raw(raw) + } } } @@ -222,14 +226,8 @@ impl ThreadWaker { _ => unreachable!(), } } -} -impl Wake for ThreadWaker { - fn wake(self: Arc) { - self.wake_by_ref(); - } - - fn wake_by_ref(self: &Arc) { + fn wake(&self) { // First, try transitioning from IDLE -> NOTIFY, this does not require a lock. let mut state = self.state.lock().unwrap(); let prev = *state; @@ -249,3 +247,39 @@ impl Wake for ThreadWaker { self.condvar.notify_one(); } } + +static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker); + +unsafe fn to_raw(waker: Arc) -> RawWaker { + RawWaker::new(Arc::into_raw(waker) as *const (), &VTABLE) +} + +unsafe fn from_raw(raw: *const ()) -> Arc { + Arc::from_raw(raw as *const ThreadWaker) +} + +unsafe fn clone(raw: *const ()) -> RawWaker { + let waker = from_raw(raw); + + // Increment the ref count + mem::forget(waker.clone()); + + to_raw(waker) +} + +unsafe fn wake(raw: *const ()) { + let waker = from_raw(raw); + waker.wake(); +} + +unsafe fn wake_by_ref(raw: *const ()) { + let waker = from_raw(raw); + waker.wake(); + + // We don't actually own a reference to the unparker + mem::forget(waker); +} + +unsafe fn drop_waker(raw: *const ()) { + let _ = from_raw(raw); +} diff --git a/tokio/src/runtime/park.rs b/tokio/src/runtime/park.rs index 27bcd334c45..08d3e719bc4 100644 --- a/tokio/src/runtime/park.rs +++ b/tokio/src/runtime/park.rs @@ -2,7 +2,6 @@ use crate::loom::sync::atomic::AtomicUsize; use crate::loom::sync::{Arc, Condvar, Mutex}; -use crate::util::{waker, Wake}; use std::sync::atomic::Ordering::SeqCst; use std::time::Duration; @@ -227,7 +226,7 @@ use crate::loom::thread::AccessError; use std::future::Future; use std::marker::PhantomData; use std::rc::Rc; -use std::task::Waker; +use std::task::{RawWaker, RawWakerVTable, Waker}; /// Blocks the current thread using a condition variable. #[derive(Debug)] @@ -293,20 +292,50 @@ impl CachedParkThread { impl UnparkThread { pub(crate) fn into_waker(self) -> Waker { - waker(self.inner) + unsafe { + let raw = unparker_to_raw_waker(self.inner); + Waker::from_raw(raw) + } } } -impl Wake for Inner { - fn wake(arc_self: Arc) { - arc_self.unpark(); +impl Inner { + #[allow(clippy::wrong_self_convention)] + fn into_raw(this: Arc) -> *const () { + Arc::into_raw(this) as *const () } - fn wake_by_ref(arc_self: &Arc) { - arc_self.unpark(); + unsafe fn from_raw(ptr: *const ()) -> Arc { + Arc::from_raw(ptr as *const Inner) } } +unsafe fn unparker_to_raw_waker(unparker: Arc) -> RawWaker { + RawWaker::new( + Inner::into_raw(unparker), + &RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker), + ) +} + +unsafe fn clone(raw: *const ()) -> RawWaker { + Arc::increment_strong_count(raw as *const Inner); + unparker_to_raw_waker(Inner::from_raw(raw)) +} + +unsafe fn drop_waker(raw: *const ()) { + drop(Inner::from_raw(raw)); +} + +unsafe fn wake(raw: *const ()) { + let unparker = Inner::from_raw(raw); + unparker.unpark(); +} + +unsafe fn wake_by_ref(raw: *const ()) { + let raw = raw as *const Inner; + (*raw).unpark(); +} + #[cfg(loom)] pub(crate) fn current_thread_park_count() -> usize { CURRENT_THREAD_PARK_COUNT.with(|count| count.load(SeqCst)) diff --git a/tokio/src/util/mod.rs b/tokio/src/util/mod.rs index eeddd0af2e8..c671fd6a1da 100644 --- a/tokio/src/util/mod.rs +++ b/tokio/src/util/mod.rs @@ -16,9 +16,6 @@ pub(crate) use blocking_check::check_socket_for_blocking; pub(crate) mod metric_atomics; -mod wake; -pub(crate) use wake::{waker, Wake}; - #[cfg(any( // io driver uses `WakeList` directly feature = "net", @@ -70,7 +67,9 @@ cfg_rt! { pub(crate) use self::rand::RngSeedGenerator; - pub(crate) use wake::{waker_ref, WakerRef}; + mod wake; + pub(crate) use wake::WakerRef; + pub(crate) use wake::{waker_ref, Wake}; mod sync_wrapper; pub(crate) use sync_wrapper::SyncWrapper; diff --git a/tokio/src/util/wake.rs b/tokio/src/util/wake.rs index d583937b8ba..896ec73e7b1 100644 --- a/tokio/src/util/wake.rs +++ b/tokio/src/util/wake.rs @@ -1,6 +1,8 @@ use crate::loom::sync::Arc; +use std::marker::PhantomData; use std::mem::ManuallyDrop; +use std::ops::Deref; use std::task::{RawWaker, RawWakerVTable, Waker}; /// Simplified waking interface based on Arcs. @@ -12,45 +14,30 @@ pub(crate) trait Wake: Send + Sync + Sized + 'static { fn wake_by_ref(arc_self: &Arc); } -cfg_rt! { - use std::marker::PhantomData; - use std::ops::Deref; - - /// A `Waker` that is only valid for a given lifetime. - #[derive(Debug)] - pub(crate) struct WakerRef<'a> { - waker: ManuallyDrop, - _p: PhantomData<&'a ()>, - } +/// A `Waker` that is only valid for a given lifetime. +#[derive(Debug)] +pub(crate) struct WakerRef<'a> { + waker: ManuallyDrop, + _p: PhantomData<&'a ()>, +} - impl Deref for WakerRef<'_> { - type Target = Waker; +impl Deref for WakerRef<'_> { + type Target = Waker; - fn deref(&self) -> &Waker { - &self.waker - } + fn deref(&self) -> &Waker { + &self.waker } +} - /// Creates a reference to a `Waker` from a reference to `Arc`. - pub(crate) fn waker_ref(wake: &Arc) -> WakerRef<'_> { - let ptr = Arc::as_ptr(wake).cast::<()>(); - - let waker = unsafe { Waker::from_raw(RawWaker::new(ptr, waker_vtable::())) }; +/// Creates a reference to a `Waker` from a reference to `Arc`. +pub(crate) fn waker_ref(wake: &Arc) -> WakerRef<'_> { + let ptr = Arc::as_ptr(wake).cast::<()>(); - WakerRef { - waker: ManuallyDrop::new(waker), - _p: PhantomData, - } - } -} + let waker = unsafe { Waker::from_raw(RawWaker::new(ptr, waker_vtable::())) }; -/// Creates a waker from a `Arc`. -pub(crate) fn waker(wake: Arc) -> Waker { - unsafe { - Waker::from_raw(RawWaker::new( - Arc::into_raw(wake).cast(), - waker_vtable::(), - )) + WakerRef { + waker: ManuallyDrop::new(waker), + _p: PhantomData, } }