diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 885b4f212d4..e6d81683f6b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1058,6 +1058,7 @@ jobs: CARGO_TARGET_WASM32_WASIP1_RUNNER: "wasmtime run --" CARGO_TARGET_WASM32_WASIP1_THREADS_RUNNER: "wasmtime run -W bulk-memory=y -W threads=y -S threads=y --" RUSTFLAGS: --cfg tokio_unstable -Dwarnings -C target-feature=+atomics,+bulk-memory -C link-args=--max-memory=67108864 + RUSTDOCFLAGS: -C link-args=--max-memory=67108864 - name: WASI test tokio-stream run: cargo test -p tokio-stream --target ${{ matrix.target }} --features time,net,io-util,sync diff --git a/.github/workflows/loom.yml b/.github/workflows/loom.yml index 2eefe9d50c3..c400a28cd68 100644 --- a/.github/workflows/loom.yml +++ b/.github/workflows/loom.yml @@ -52,7 +52,7 @@ jobs: toolchain: ${{ env.rust_stable }} - uses: Swatinem/rust-cache@v2 - name: run tests - run: cargo test --lib --release --features full -- --nocapture runtime::time::tests + run: cargo test --lib --release --features full -- --nocapture runtime::time working-directory: tokio loom-current-thread: diff --git a/spellcheck.dic b/spellcheck.dic index 2baf2df351f..e377506bac6 100644 --- a/spellcheck.dic +++ b/spellcheck.dic @@ -1,4 +1,4 @@ -308 +311 & + < @@ -64,6 +64,7 @@ codec codecs combinator combinators +condvar config Config connectionless @@ -163,6 +164,7 @@ Lauck libc lifecycle lifo +LLVM lookups macOS MacOS @@ -307,3 +309,4 @@ Wakers wakeup wakeups workstealing +ZST diff --git a/tokio/Cargo.toml b/tokio/Cargo.toml index e87d2ad0381..7caea7d09d1 100644 --- a/tokio/Cargo.toml +++ b/tokio/Cargo.toml @@ -140,6 +140,7 @@ tokio-test = { version = "0.4.0", path = "../tokio-test" } tokio-stream = { version = "0.1", path = "../tokio-stream" } tokio-util = { version = "0.7", path = "../tokio-util", features = ["rt"] } futures = { version = "0.3.0", features = ["async-await"] } +futures-test = "0.3.31" mockall = "0.13.0" async-stream = "0.3" 
futures-concurrency = "7.6.3" diff --git a/tokio/src/runtime/builder.rs b/tokio/src/runtime/builder.rs index 9aae69ab98f..3a40717e5ce 100644 --- a/tokio/src/runtime/builder.rs +++ b/tokio/src/runtime/builder.rs @@ -1,7 +1,9 @@ #![cfg_attr(loom, allow(unused_imports))] use crate::runtime::handle::Handle; -use crate::runtime::{blocking, driver, Callback, HistogramBuilder, Runtime, TaskCallback}; +use crate::runtime::{ + blocking, driver, Callback, HistogramBuilder, Runtime, TaskCallback, TimerFlavor, +}; #[cfg(tokio_unstable)] use crate::runtime::{metrics::HistogramConfiguration, LocalOptions, LocalRuntime, TaskMeta}; use crate::util::rand::{RngSeed, RngSeedGenerator}; @@ -133,6 +135,8 @@ pub struct Builder { #[cfg(tokio_unstable)] pub(super) unhandled_panic: UnhandledPanic, + + timer_flavor: TimerFlavor, } cfg_unstable! { @@ -318,6 +322,8 @@ impl Builder { metrics_poll_count_histogram: HistogramBuilder::default(), disable_lifo_slot: false, + + timer_flavor: TimerFlavor::Traditional, } } @@ -363,6 +369,41 @@ impl Builder { self } + /// Enables the alternative timer implementation, which is disabled by default. + /// + /// The alternative timer implementation is an unstable feature that may + /// provide better performance on multi-threaded runtimes with a large number + /// of worker threads. + /// + /// This option only applies to multi-threaded runtimes. Attempting to use + /// this option with any other runtime type will have no effect. 
+ /// + /// [Click here to share your experience with the alternative timer](https://github.com/tokio-rs/tokio/issues/7745) + /// + /// # Examples + /// + /// ``` + /// # #[cfg(not(target_family = "wasm"))] + /// # { + /// use tokio::runtime; + /// + /// let rt = runtime::Builder::new_multi_thread() + /// .enable_alt_timer() + /// .build() + /// .unwrap(); + /// # } + /// ``` + #[cfg(all(tokio_unstable, feature = "time", feature = "rt-multi-thread"))] + #[cfg_attr( + docsrs, + doc(cfg(all(tokio_unstable, feature = "time", feature = "rt-multi-thread"))) + )] + pub fn enable_alt_timer(&mut self) -> &mut Self { + self.enable_time(); + self.timer_flavor = TimerFlavor::Alternative; + self + } + /// Sets the number of worker threads the `Runtime` will use. /// /// This can be any number above 0 though it is advised to keep this value @@ -992,6 +1033,7 @@ impl Builder { enable_time: self.enable_time, start_paused: self.start_paused, nevents: self.nevents, + timer_flavor: self.timer_flavor, } } @@ -1544,7 +1586,9 @@ impl Builder { use crate::runtime::scheduler; use crate::runtime::Config; - let (driver, driver_handle) = driver::Driver::new(self.get_cfg())?; + let mut cfg = self.get_cfg(); + cfg.timer_flavor = TimerFlavor::Traditional; + let (driver, driver_handle) = driver::Driver::new(cfg)?; // Blocking pool let blocking_pool = blocking::create_blocking_pool(self, self.max_blocking_threads); @@ -1761,6 +1805,7 @@ cfg_rt_multi_thread! 
{ seed_generator: seed_generator_1, metrics_poll_count_histogram: self.metrics_poll_count_histogram_builder(), }, + self.timer_flavor, ); let handle = Handle { inner: scheduler::Handle::MultiThread(handle) }; diff --git a/tokio/src/runtime/driver.rs b/tokio/src/runtime/driver.rs index a1a6df8e007..92b2350db9d 100644 --- a/tokio/src/runtime/driver.rs +++ b/tokio/src/runtime/driver.rs @@ -40,6 +40,7 @@ pub(crate) struct Cfg { pub(crate) enable_pause_time: bool, pub(crate) start_paused: bool, pub(crate) nevents: usize, + pub(crate) timer_flavor: crate::runtime::TimerFlavor, } impl Driver { @@ -48,7 +49,8 @@ impl Driver { let clock = create_clock(cfg.enable_pause_time, cfg.start_paused); - let (time_driver, time_handle) = create_time_driver(cfg.enable_time, io_stack, &clock); + let (time_driver, time_handle) = + create_time_driver(cfg.enable_time, cfg.timer_flavor, io_stack, &clock); Ok(( Self { inner: time_driver }, @@ -113,6 +115,14 @@ impl Handle { .expect("A Tokio 1.x context was found, but timers are disabled. Call `enable_time` on the runtime builder to enable timers.") } + #[cfg(tokio_unstable)] + pub(crate) fn with_time(&self, f: F) -> R + where + F: FnOnce(Option<&crate::runtime::time::Handle>) -> R, + { + f(self.time.as_ref()) + } + pub(crate) fn clock(&self) -> &Clock { &self.clock } @@ -281,6 +291,7 @@ cfg_time! { Enabled { driver: crate::runtime::time::Driver, }, + EnabledAlt(IoStack), Disabled(IoStack), } @@ -293,13 +304,21 @@ cfg_time! 
{ fn create_time_driver( enable: bool, + timer_flavor: crate::runtime::TimerFlavor, io_stack: IoStack, clock: &Clock, ) -> (TimeDriver, TimeHandle) { if enable { - let (driver, handle) = crate::runtime::time::Driver::new(io_stack, clock); - - (TimeDriver::Enabled { driver }, Some(handle)) + match timer_flavor { + crate::runtime::TimerFlavor::Traditional => { + let (driver, handle) = crate::runtime::time::Driver::new(io_stack, clock); + (TimeDriver::Enabled { driver }, Some(handle)) + } + #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] + crate::runtime::TimerFlavor::Alternative => { + (TimeDriver::EnabledAlt(io_stack), Some(crate::runtime::time::Driver::new_alt(clock))) + } + } } else { (TimeDriver::Disabled(io_stack), None) } @@ -309,6 +328,7 @@ cfg_time! { pub(crate) fn park(&mut self, handle: &Handle) { match self { TimeDriver::Enabled { driver, .. } => driver.park(handle), + TimeDriver::EnabledAlt(v) => v.park(handle), TimeDriver::Disabled(v) => v.park(handle), } } @@ -316,6 +336,7 @@ cfg_time! { pub(crate) fn park_timeout(&mut self, handle: &Handle, duration: Duration) { match self { TimeDriver::Enabled { driver } => driver.park_timeout(handle, duration), + TimeDriver::EnabledAlt(v) => v.park_timeout(handle, duration), TimeDriver::Disabled(v) => v.park_timeout(handle, duration), } } @@ -323,6 +344,7 @@ cfg_time! { pub(crate) fn shutdown(&mut self, handle: &Handle) { match self { TimeDriver::Enabled { driver } => driver.shutdown(handle), + TimeDriver::EnabledAlt(v) => v.shutdown(handle), TimeDriver::Disabled(v) => v.shutdown(handle), } } @@ -341,6 +363,7 @@ cfg_not_time! { fn create_time_driver( _enable: bool, + _timer_flavor: crate::runtime::TimerFlavor, io_stack: IoStack, _clock: &Clock, ) -> (TimeDriver, TimeHandle) { diff --git a/tokio/src/runtime/mod.rs b/tokio/src/runtime/mod.rs index ae58ce6da86..92a159b38cb 100644 --- a/tokio/src/runtime/mod.rs +++ b/tokio/src/runtime/mod.rs @@ -389,8 +389,135 @@ cfg_process_driver! 
{ mod process; } +#[cfg_attr(not(feature = "time"), allow(dead_code))] +#[derive(Debug, Copy, Clone, PartialEq)] +pub(crate) enum TimerFlavor { + Traditional, + #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] + Alternative, +} + cfg_time! { pub(crate) mod time; + + #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] + pub(crate) mod time_alt; + + use std::task::{Context, Poll}; + use std::pin::Pin; + + #[derive(Debug)] + pub(crate) enum Timer { + Traditional(time::TimerEntry), + + #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] + Alternative(time_alt::Timer), + } + + impl Timer { + #[track_caller] + pub(crate) fn new( + handle: crate::runtime::scheduler::Handle, + deadline: crate::time::Instant, + ) -> Self { + match handle.timer_flavor() { + crate::runtime::TimerFlavor::Traditional => { + Timer::Traditional(time::TimerEntry::new(handle, deadline)) + } + #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] + crate::runtime::TimerFlavor::Alternative => { + Timer::Alternative(time_alt::Timer::new(handle, deadline)) + } + } + } + + pub(crate) fn deadline(&self) -> crate::time::Instant { + match self { + Timer::Traditional(entry) => entry.deadline(), + #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] + Timer::Alternative(entry) => entry.deadline(), + } + } + + pub(crate) fn is_elapsed(&self) -> bool { + match self { + Timer::Traditional(entry) => entry.is_elapsed(), + #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] + Timer::Alternative(entry) => entry.is_elapsed(), + } + } + + pub(crate) fn flavor(self: Pin<&Self>) -> TimerFlavor { + match self.get_ref() { + Timer::Traditional(_) => TimerFlavor::Traditional, + #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] + Timer::Alternative(_) => TimerFlavor::Alternative, + } + } + + pub(crate) fn reset( + self: Pin<&mut Self>, + new_time: crate::time::Instant, + reregister: bool + ) { + // Safety: we never move the inner entries. 
+ let this = unsafe { self.get_unchecked_mut() }; + match this { + Timer::Traditional(entry) => { + // Safety: we never move the inner entries. + unsafe { Pin::new_unchecked(entry).reset(new_time, reregister); } + } + #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] + Timer::Alternative(_) => panic!("not implemented yet"), + } + } + + pub(crate) fn poll_elapsed( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + // Safety: we never move the inner entries. + let this = unsafe { self.get_unchecked_mut() }; + match this { + Timer::Traditional(entry) => { + // Safety: we never move the inner entries. + unsafe { Pin::new_unchecked(entry).poll_elapsed(cx) } + } + #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] + Timer::Alternative(entry) => { + // Safety: we never move the inner entries. + unsafe { Pin::new_unchecked(entry).poll_elapsed(cx).map(Ok) } + } + } + } + + #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] + pub(crate) fn scheduler_handle(&self) -> &crate::runtime::scheduler::Handle { + match self { + Timer::Traditional(_) => unreachable!("we should not call this on Traditional Timer"), + #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] + Timer::Alternative(entry) => entry.scheduler_handle(), + } + } + + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(crate) fn driver(self: Pin<&Self>) -> &crate::runtime::time::Handle { + match self.get_ref() { + Timer::Traditional(entry) => entry.driver(), + #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] + Timer::Alternative(entry) => entry.driver(), + } + } + + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(crate) fn clock(self: Pin<&Self>) -> &crate::time::Clock { + match self.get_ref() { + Timer::Traditional(entry) => entry.clock(), + #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] + Timer::Alternative(entry) => entry.clock(), + } + } + } } cfg_signal_internal_and_unix! 
{ diff --git a/tokio/src/runtime/scheduler/current_thread/mod.rs b/tokio/src/runtime/scheduler/current_thread/mod.rs index 3b13cf6d9a4..b505035aa45 100644 --- a/tokio/src/runtime/scheduler/current_thread/mod.rs +++ b/tokio/src/runtime/scheduler/current_thread/mod.rs @@ -387,12 +387,7 @@ impl Context { core.metrics.about_to_park(); core.submit_metrics(handle); - let (c, ()) = self.enter(core, || { - driver.park(&handle.driver); - self.defer.wake(); - }); - - core = c; + core = self.park_internal(core, handle, &mut driver, None); core.metrics.unparked(); core.submit_metrics(handle); @@ -413,12 +408,27 @@ impl Context { core.submit_metrics(handle); - let (mut core, ()) = self.enter(core, || { - driver.park_timeout(&handle.driver, Duration::from_millis(0)); + core = self.park_internal(core, handle, &mut driver, Some(Duration::from_millis(0))); + + core.driver = Some(driver); + core + } + + fn park_internal( + &self, + core: Box, + handle: &Handle, + driver: &mut Driver, + duration: Option, + ) -> Box { + let (core, ()) = self.enter(core, || { + match duration { + Some(dur) => driver.park_timeout(&handle.driver, dur), + None => driver.park(&handle.driver), + } self.defer.wake(); }); - core.driver = Some(driver); core } diff --git a/tokio/src/runtime/scheduler/mod.rs b/tokio/src/runtime/scheduler/mod.rs index d52216f06f6..3f142120d33 100644 --- a/tokio/src/runtime/scheduler/mod.rs +++ b/tokio/src/runtime/scheduler/mod.rs @@ -24,6 +24,8 @@ cfg_rt_multi_thread! { pub(crate) use multi_thread::MultiThread; } +pub(super) mod util; + use crate::runtime::driver; #[derive(Debug, Clone)] @@ -107,6 +109,48 @@ cfg_rt! 
{ } } + #[cfg(feature = "time")] + pub(crate) fn timer_flavor(&self) -> crate::runtime::TimerFlavor { + match self { + Handle::CurrentThread(_) => crate::runtime::TimerFlavor::Traditional, + + #[cfg(feature = "rt-multi-thread")] + Handle::MultiThread(h) => h.timer_flavor, + } + } + + #[cfg(all(tokio_unstable, feature = "rt-multi-thread", feature = "time"))] + /// Returns true if both handles belong to the same runtime instance. + pub(crate) fn is_same_runtime(&self, other: &Handle) -> bool { + match (self, other) { + (Handle::CurrentThread(a), Handle::CurrentThread(b)) => Arc::ptr_eq(a, b), + #[cfg(feature = "rt-multi-thread")] + (Handle::MultiThread(a), Handle::MultiThread(b)) => Arc::ptr_eq(a, b), + #[cfg(feature = "rt-multi-thread")] + _ => false, // different runtime types + } + } + + #[cfg(all(tokio_unstable, feature = "rt-multi-thread", feature = "time"))] + /// Returns true if the runtime is shutting down. + pub(crate) fn is_shutdown(&self) -> bool { + match self { + Handle::CurrentThread(_) => panic!("the alternative timer implementation is not supported on CurrentThread runtime"), + Handle::MultiThread(h) => h.is_shutdown(), + } + } + + #[cfg(all(tokio_unstable, feature = "rt-multi-thread", feature = "time"))] + /// Push a timer entry that was created outside of this runtime + /// into the runtime-global queue. The pushed timer will be + /// processed by a random worker thread. + pub(crate) fn push_remote_timer(&self, entry_hdl: crate::runtime::time_alt::EntryHandle) { + match self { + Handle::CurrentThread(_) => panic!("the alternative timer implementation is not supported on CurrentThread runtime"), + Handle::MultiThread(h) => h.push_remote_timer(entry_hdl), + } + } + /// Returns true if this is a local runtime and the runtime is owned by the current thread. pub(crate) fn can_spawn_local_on_local_runtime(&self) -> bool { match self { @@ -249,6 +293,17 @@ cfg_rt! 
{ match_flavor!(self, Context(context) => context.defer(waker)); } + #[cfg(all(tokio_unstable, feature = "time", feature = "rt-multi-thread"))] + pub(crate) fn with_time_temp_local_context(&self, f: F) -> R + where + F: FnOnce(Option>) -> R, + { + match self { + Context::CurrentThread(_) => panic!("the alternative timer implementation is not supported on CurrentThread runtime"), + Context::MultiThread(context) => context.with_time_temp_local_context(f), + } + } + cfg_rt_multi_thread! { #[track_caller] pub(crate) fn expect_multi_thread(&self) -> &multi_thread::Context { @@ -273,5 +328,11 @@ cfg_not_rt! { pub(crate) fn current() -> Handle { panic!("{}", crate::util::error::CONTEXT_MISSING_ERROR) } + + #[cfg_attr(not(feature = "time"), allow(dead_code))] + #[track_caller] + pub(crate) fn timer_flavor(&self) -> crate::runtime::TimerFlavor { + panic!("{}", crate::util::error::CONTEXT_MISSING_ERROR) + } } } diff --git a/tokio/src/runtime/scheduler/multi_thread/handle.rs b/tokio/src/runtime/scheduler/multi_thread/handle.rs index 9acfcb270d6..14d65294c08 100644 --- a/tokio/src/runtime/scheduler/multi_thread/handle.rs +++ b/tokio/src/runtime/scheduler/multi_thread/handle.rs @@ -5,7 +5,7 @@ use crate::runtime::task::{Notified, Task, TaskHarnessScheduleHooks}; use crate::runtime::{ blocking, driver, task::{self, JoinHandle, SpawnLocation}, - TaskHooks, TaskMeta, + TaskHooks, TaskMeta, TimerFlavor, }; use crate::util::RngSeedGenerator; @@ -17,6 +17,9 @@ cfg_taskdump! 
{ mod taskdump; } +#[cfg(all(tokio_unstable, feature = "time"))] +use crate::loom::sync::atomic::{AtomicBool, Ordering::SeqCst}; + /// Handle to the multi thread scheduler pub(crate) struct Handle { /// Task spawner @@ -33,6 +36,14 @@ pub(crate) struct Handle { /// User-supplied hooks to invoke for things pub(crate) task_hooks: TaskHooks, + + #[cfg_attr(not(feature = "time"), allow(dead_code))] + /// Timer flavor used by the runtime + pub(crate) timer_flavor: TimerFlavor, + + #[cfg(all(tokio_unstable, feature = "time"))] + /// Indicates that the runtime is shutting down. + pub(crate) is_shutdown: AtomicBool, } impl Handle { @@ -50,8 +61,16 @@ impl Handle { Self::bind_new_task(me, future, id, spawned_at) } + #[cfg(all(tokio_unstable, feature = "time"))] + pub(crate) fn is_shutdown(&self) -> bool { + self.is_shutdown + .load(crate::loom::sync::atomic::Ordering::SeqCst) + } + pub(crate) fn shutdown(&self) { self.close(); + #[cfg(all(tokio_unstable, feature = "time"))] + self.is_shutdown.store(true, SeqCst); } #[track_caller] diff --git a/tokio/src/runtime/scheduler/multi_thread/mod.rs b/tokio/src/runtime/scheduler/multi_thread/mod.rs index d85a0ae0a2a..1c5e1a88884 100644 --- a/tokio/src/runtime/scheduler/multi_thread/mod.rs +++ b/tokio/src/runtime/scheduler/multi_thread/mod.rs @@ -41,7 +41,7 @@ use crate::loom::sync::Arc; use crate::runtime::{ blocking, driver::{self, Driver}, - scheduler, Config, + scheduler, Config, TimerFlavor, }; use crate::util::RngSeedGenerator; @@ -61,6 +61,7 @@ impl MultiThread { blocking_spawner: blocking::Spawner, seed_generator: RngSeedGenerator, config: Config, + timer_flavor: TimerFlavor, ) -> (MultiThread, Arc, Launch) { let parker = Parker::new(driver); let (handle, launch) = worker::create( @@ -70,6 +71,7 @@ impl MultiThread { blocking_spawner, seed_generator, config, + timer_flavor, ); (MultiThread, handle, launch) diff --git a/tokio/src/runtime/scheduler/multi_thread/park.rs b/tokio/src/runtime/scheduler/multi_thread/park.rs index 
b00c648e6d3..79a3507f87c 100644 --- a/tokio/src/runtime/scheduler/multi_thread/park.rs +++ b/tokio/src/runtime/scheduler/multi_thread/park.rs @@ -8,7 +8,7 @@ use crate::runtime::driver::{self, Driver}; use crate::util::TryLock; use std::sync::atomic::Ordering::SeqCst; -use std::time::Duration; +use std::time::{Duration, Instant}; #[cfg(loom)] use crate::runtime::park::CURRENT_THREAD_PARK_COUNT; @@ -70,12 +70,16 @@ impl Parker { self.inner.park(handle); } + /// Parks the current thread for up to `duration`. + /// + /// This function tries to acquire the driver lock. If it succeeds, it + /// parks using the driver. Otherwise, it fails back to using a condvar, + /// unless the duration is zero, in which case it returns immediately. pub(crate) fn park_timeout(&mut self, handle: &driver::Handle, duration: Duration) { - // Only parking with zero is supported... - assert_eq!(duration, Duration::from_millis(0)); - if let Some(mut driver) = self.inner.shared.driver.try_lock() { - driver.park_timeout(handle, duration); + self.inner.park_driver(&mut driver, handle, Some(duration)); + } else if !duration.is_zero() { + self.inner.park_condvar(Some(duration)); } else { // https://github.com/tokio-rs/tokio/issues/6536 // Hacky, but it's just for loom tests. The counter gets incremented during @@ -124,13 +128,20 @@ impl Inner { } if let Some(mut driver) = self.shared.driver.try_lock() { - self.park_driver(&mut driver, handle); + self.park_driver(&mut driver, handle, None); } else { - self.park_condvar(); + self.park_condvar(None); } } - fn park_condvar(&self) { + /// Parks the current thread using a condvar for up to `duration`. + /// + /// If `duration` is `None`, parks indefinitely until notified. + /// + /// # Panics + /// + /// Panics if `duration` is `Some` and the duration is zero. 
+ fn park_condvar(&self, duration: Option) { // Otherwise we need to coordinate going to sleep let mut m = self.mutex.lock(); @@ -154,10 +165,40 @@ impl Inner { Err(actual) => panic!("inconsistent park state; actual = {actual}"), } - loop { - m = self.condvar.wait(m).unwrap(); + let timeout_at = duration.map(|d| { + Instant::now() + .checked_add(d) + // best effort to avoid overflow and still provide a usable timeout + .unwrap_or(Instant::now() + Duration::from_secs(1)) + }); - if self + loop { + let is_timeout; + (m, is_timeout) = match timeout_at { + Some(timeout_at) => { + let dur = timeout_at.saturating_duration_since(Instant::now()); + if !dur.is_zero() { + // Ideally, we would use `condvar.wait_timeout_until` here, but it is not available + // in `loom`. So we manually compute the timeout. + let (m, res) = self.condvar.wait_timeout(m, dur).unwrap(); + (m, res.timed_out()) + } else { + (m, true) + } + } + None => (self.condvar.wait(m).unwrap(), false), + }; + + if is_timeout { + match self.state.swap(EMPTY, SeqCst) { + PARKED_CONDVAR => return, // timed out, and no notification received + NOTIFIED => return, // notification and timeout happened concurrently + actual @ (PARKED_DRIVER | EMPTY) => { + panic!("inconsistent park_timeout state, actual = {actual}") + } + invalid => panic!("invalid park_timeout state, actual = {invalid}"), + } + } else if self .state .compare_exchange(NOTIFIED, EMPTY, SeqCst, SeqCst) .is_ok() @@ -170,7 +211,19 @@ impl Inner { } } - fn park_driver(&self, driver: &mut Driver, handle: &driver::Handle) { + fn park_driver( + &self, + driver: &mut Driver, + handle: &driver::Handle, + duration: Option, + ) { + if duration.as_ref().is_some_and(Duration::is_zero) { + // zero duration doesn't actually park the thread, it just + // polls the I/O events, timers, etc. 
+ driver.park_timeout(handle, Duration::ZERO); + return; + } + match self .state .compare_exchange(EMPTY, PARKED_DRIVER, SeqCst, SeqCst) @@ -191,7 +244,12 @@ impl Inner { Err(actual) => panic!("inconsistent park state; actual = {actual}"), } - driver.park(handle); + if let Some(duration) = duration { + debug_assert_ne!(duration, Duration::ZERO); + driver.park_timeout(handle, duration); + } else { + driver.park(handle); + } match self.state.swap(EMPTY, SeqCst) { NOTIFIED => {} // got a notification, hurray! diff --git a/tokio/src/runtime/scheduler/multi_thread/worker.rs b/tokio/src/runtime/scheduler/multi_thread/worker.rs index 7ec3f126467..ae9f2556dfb 100644 --- a/tokio/src/runtime/scheduler/multi_thread/worker.rs +++ b/tokio/src/runtime/scheduler/multi_thread/worker.rs @@ -63,7 +63,9 @@ use crate::runtime::scheduler::multi_thread::{ }; use crate::runtime::scheduler::{inject, Defer, Lock}; use crate::runtime::task::OwnedTasks; -use crate::runtime::{blocking, driver, scheduler, task, Config, SchedulerMetrics, WorkerMetrics}; +use crate::runtime::{ + blocking, driver, scheduler, task, Config, SchedulerMetrics, TimerFlavor, WorkerMetrics, +}; use crate::runtime::{context, TaskHooks}; use crate::task::coop; use crate::util::atomic_cell::AtomicCell; @@ -84,6 +86,15 @@ cfg_not_taskdump! { mod taskdump_mock; } +#[cfg(all(tokio_unstable, feature = "time"))] +use crate::loom::sync::atomic::AtomicBool; + +#[cfg(all(tokio_unstable, feature = "time"))] +use crate::runtime::time_alt; + +#[cfg(all(tokio_unstable, feature = "time"))] +use crate::runtime::scheduler::util; + /// A scheduler worker pub(super) struct Worker { /// Reference to scheduler's handle @@ -115,6 +126,9 @@ struct Core { /// The worker-local run queue. run_queue: queue::Local>, + #[cfg(all(tokio_unstable, feature = "time"))] + time_context: time_alt::LocalContext, + /// True if the worker is currently searching for more work. Searching /// involves attempting to steal from other workers. 
is_searching: bool, @@ -193,6 +207,12 @@ pub(crate) struct Synced { /// Synchronized state for `Inject`. pub(crate) inject: inject::Synced, + + #[cfg(all(tokio_unstable, feature = "time"))] + /// Timers pending to be registered. + /// This is used to register a timer but the [`Core`] + /// is not available in the current thread. + inject_timers: Vec, } /// Used to communicate with a worker from other threads. @@ -241,6 +261,7 @@ pub(super) fn create( blocking_spawner: blocking::Spawner, seed_generator: RngSeedGenerator, config: Config, + timer_flavor: TimerFlavor, ) -> (Arc, Launch) { let mut cores = Vec::with_capacity(size); let mut remotes = Vec::with_capacity(size); @@ -260,6 +281,8 @@ pub(super) fn create( lifo_slot: None, lifo_enabled: !config.disable_lifo_slot, run_queue, + #[cfg(all(tokio_unstable, feature = "time"))] + time_context: time_alt::LocalContext::new(), is_searching: false, is_shutdown: false, is_traced: false, @@ -287,6 +310,8 @@ pub(super) fn create( synced: Mutex::new(Synced { idle: idle_synced, inject: inject_synced, + #[cfg(all(tokio_unstable, feature = "time"))] + inject_timers: Vec::new(), }), shutdown_cores: Mutex::new(vec![]), trace_status: TraceStatus::new(remotes_len), @@ -298,6 +323,9 @@ pub(super) fn create( driver: driver_handle, blocking_spawner, seed_generator, + timer_flavor, + #[cfg(all(tokio_unstable, feature = "time"))] + is_shutdown: AtomicBool::new(false), }); let mut launch = Launch(vec![]); @@ -552,7 +580,7 @@ impl Context { } else { // Wait for work core = if !self.defer.is_empty() { - self.park_timeout(core, Some(Duration::from_millis(0))) + self.park_yield(core) } else { self.park(core) }; @@ -560,6 +588,21 @@ impl Context { } } + #[cfg(all(tokio_unstable, feature = "time"))] + { + match self.worker.handle.timer_flavor { + TimerFlavor::Traditional => {} + TimerFlavor::Alternative => { + util::time_alt::shutdown_local_timers( + &mut core.time_context.wheel, + &mut core.time_context.canc_rx, + 
self.worker.handle.take_remote_timers(), + &self.worker.handle.driver, + ); + } + } + } + core.pre_shutdown(&self.worker); // Signal shutdown self.worker.handle.shutdown_core(core); @@ -701,7 +744,7 @@ impl Context { // Call `park` with a 0 timeout. This enables the I/O driver, timer, ... // to run without actually putting the thread to sleep. - core = self.park_timeout(core, Some(Duration::from_millis(0))); + core = self.park_yield(core); // Run regularly scheduled maintenance core.maintenance(&self.worker); @@ -734,7 +777,7 @@ impl Context { core.stats .submit(&self.worker.handle.shared.worker_metrics[self.worker.index]); - core = self.park_timeout(core, None); + core = self.park_internal(core, None); core.stats.unparked(); @@ -753,15 +796,35 @@ impl Context { core } - fn park_timeout(&self, mut core: Box, duration: Option) -> Box { + fn park_yield(&self, core: Box) -> Box { + self.park_internal(core, Some(Duration::from_millis(0))) + } + + fn park_internal(&self, mut core: Box, duration: Option) -> Box { self.assert_lifo_enabled_is_correct(&core); // Take the parker out of core let mut park = core.park.take().expect("park missing"); - // Store `core` in context *self.core.borrow_mut() = Some(core); + #[cfg(feature = "time")] + let (duration, auto_advance_duration) = match self.worker.handle.timer_flavor { + TimerFlavor::Traditional => (duration, None::), + #[cfg(tokio_unstable)] + TimerFlavor::Alternative => { + // Must happens after taking out the parker, as the `Handle::schedule_local` + // will delay the notify if the parker taken out. + // + // See comments in `Handle::schedule_local` for more details. 
+ let MaintainLocalTimer { + park_duration: duration, + auto_advance_duration, + } = self.maintain_local_timers_before_parking(duration); + (duration, auto_advance_duration) + } + }; + // Park thread if let Some(timeout) = duration { park.park_timeout(&self.worker.handle.driver, timeout); @@ -771,16 +834,30 @@ impl Context { self.defer.wake(); + #[cfg(feature = "time")] + match self.worker.handle.timer_flavor { + TimerFlavor::Traditional => { + // suppress unused variable warning + let _ = auto_advance_duration; + } + #[cfg(tokio_unstable)] + TimerFlavor::Alternative => { + // Must happens before placing back the parker, as the `Handle::schedule_local` + // will delay the notify if the parker is still in `core`. + // + // See comments in `Handle::schedule_local` for more details. + self.maintain_local_timers_after_parking(auto_advance_duration); + } + } + // Remove `core` from context core = self.core.borrow_mut().take().expect("core missing"); // Place `park` back in `core` core.park = Some(park); - if core.should_notify_others() { self.worker.handle.notify_parked_local(); } - core } @@ -793,6 +870,138 @@ impl Context { self.defer.defer(waker); } } + + #[cfg(all(tokio_unstable, feature = "time"))] + /// Maintain local timers before parking the resource driver. + /// + /// * Remove cancelled timers from the local timer wheel. + /// * Register remote timers to the local timer wheel. + /// * Adjust the park duration based on + /// * the next timer expiration time. + /// * whether auto-advancing is required (feature = "test-util"). 
+ /// + /// # Returns + /// + /// `(Box, park_duration, auto_advance_duration)` + fn maintain_local_timers_before_parking( + &self, + park_duration: Option, + ) -> MaintainLocalTimer { + let handle = &self.worker.handle; + let mut wake_queue = time_alt::WakeQueue::new(); + + let (should_yield, next_timer) = with_current(|maybe_cx| { + let cx = maybe_cx.expect("function should be called when core is present"); + assert_eq!( + Arc::as_ptr(&cx.worker.handle), + Arc::as_ptr(&self.worker.handle), + "function should be called on the exact same worker" + ); + + let mut maybe_core = cx.core.borrow_mut(); + let core = maybe_core.as_mut().expect("core missing"); + let time_cx = &mut core.time_context; + + util::time_alt::process_registration_queue( + &mut time_cx.registration_queue, + &mut time_cx.wheel, + &time_cx.canc_tx, + &mut wake_queue, + ); + util::time_alt::insert_inject_timers( + &mut time_cx.wheel, + &time_cx.canc_tx, + handle.take_remote_timers(), + &mut wake_queue, + ); + util::time_alt::remove_cancelled_timers(&mut time_cx.wheel, &mut time_cx.canc_rx); + let should_yield = !wake_queue.is_empty(); + + let next_timer = util::time_alt::next_expiration_time(&time_cx.wheel, &handle.driver); + + (should_yield, next_timer) + }); + + wake_queue.wake_all(); + + if should_yield { + MaintainLocalTimer { + park_duration: Some(Duration::from_millis(0)), + auto_advance_duration: None, + } + } else { + // get the minimum duration + let dur = util::time_alt::min_duration(park_duration, next_timer); + if util::time_alt::pre_auto_advance(&handle.driver, dur) { + MaintainLocalTimer { + park_duration: Some(Duration::ZERO), + auto_advance_duration: dur, + } + } else { + MaintainLocalTimer { + park_duration: dur, + auto_advance_duration: None, + } + } + } + } + + #[cfg(all(tokio_unstable, feature = "time"))] + /// Maintain local timers after unparking the resource driver. + /// + /// * Auto-advance time, if required (feature = "test-util"). + /// * Process expired timers. 
+ fn maintain_local_timers_after_parking(&self, auto_advance_duration: Option) { + let handle = &self.worker.handle; + let mut wake_queue = time_alt::WakeQueue::new(); + + with_current(|maybe_cx| { + let cx = maybe_cx.expect("function should be called when core is present"); + assert_eq!( + Arc::as_ptr(&cx.worker.handle), + Arc::as_ptr(&self.worker.handle), + "function should be called on the exact same worker" + ); + + let mut maybe_core = cx.core.borrow_mut(); + let core = maybe_core.as_mut().expect("core missing"); + let time_cx = &mut core.time_context; + + util::time_alt::post_auto_advance(&handle.driver, auto_advance_duration); + util::time_alt::process_expired_timers( + &mut time_cx.wheel, + &handle.driver, + &mut wake_queue, + ); + }); + + wake_queue.wake_all(); + } + + #[cfg(all(tokio_unstable, feature = "time"))] + fn with_core(&self, f: F) -> R + where + F: FnOnce(Option<&mut Core>) -> R, + { + match self.core.borrow_mut().as_mut() { + Some(core) => f(Some(core)), + None => f(None), + } + } + + #[cfg(all(tokio_unstable, feature = "time"))] + pub(crate) fn with_time_temp_local_context(&self, f: F) -> R + where + F: FnOnce(Option>) -> R, + { + self.with_core(|maybe_core| match maybe_core { + Some(core) if core.is_shutdown => f(Some(time_alt::TempLocalContext::new_shutdown())), + Some(core) => f(Some(time_alt::TempLocalContext::new_running( + &mut core.time_context, + ))), + None => f(None), + }) + } } impl Core { @@ -1131,6 +1340,27 @@ impl Handle { } } + #[cfg(all(tokio_unstable, feature = "time"))] + pub(crate) fn push_remote_timer(&self, hdl: time_alt::EntryHandle) { + assert_eq!(self.timer_flavor, TimerFlavor::Alternative); + { + let mut synced = self.shared.synced.lock(); + synced.inject_timers.push(hdl); + } + self.notify_parked_remote(); + } + + #[cfg(all(tokio_unstable, feature = "time"))] + pub(crate) fn take_remote_timers(&self) -> Vec { + assert_eq!(self.timer_flavor, TimerFlavor::Alternative); + // It's ok to lost the race, as another worker is 
+ // draining the inject_timers. + match self.shared.synced.try_lock() { + Some(mut synced) => std::mem::take(&mut synced.inject_timers), + None => Vec::new(), + } + } + pub(super) fn close(&self) { if self .shared @@ -1249,6 +1479,13 @@ impl<'a> Lock for &'a Handle { } } +#[cfg(all(tokio_unstable, feature = "time"))] +/// Returned by [`Context::maintain_local_timers_before_parking`]. +struct MaintainLocalTimer { + park_duration: Option, + auto_advance_duration: Option, +} + #[track_caller] fn with_current(f: impl FnOnce(Option<&Context>) -> R) -> R { use scheduler::Context::MultiThread; diff --git a/tokio/src/runtime/scheduler/util/mod.rs b/tokio/src/runtime/scheduler/util/mod.rs new file mode 100644 index 00000000000..bea582887fe --- /dev/null +++ b/tokio/src/runtime/scheduler/util/mod.rs @@ -0,0 +1,2 @@ +#[cfg(all(tokio_unstable, feature = "time", feature = "rt-multi-thread"))] +pub(in crate::runtime) mod time_alt; diff --git a/tokio/src/runtime/scheduler/util/time_alt.rs b/tokio/src/runtime/scheduler/util/time_alt.rs new file mode 100644 index 00000000000..e6ea35843ac --- /dev/null +++ b/tokio/src/runtime/scheduler/util/time_alt.rs @@ -0,0 +1,181 @@ +use crate::runtime::scheduler::driver; +use crate::runtime::time_alt::cancellation_queue::{Receiver, Sender}; +use crate::runtime::time_alt::{EntryHandle, RegistrationQueue, WakeQueue, Wheel}; +use std::time::Duration; + +pub(crate) fn min_duration(a: Option, b: Option) -> Option { + match (a, b) { + (Some(dur_a), Some(dur_b)) => Some(std::cmp::min(dur_a, dur_b)), + (Some(dur_a), None) => Some(dur_a), + (None, Some(dur_b)) => Some(dur_b), + (None, None) => None, + } +} + +pub(crate) fn process_registration_queue( + registration_queue: &mut RegistrationQueue, + wheel: &mut Wheel, + tx: &Sender, + wake_queue: &mut WakeQueue, +) { + while let Some(hdl) = registration_queue.pop_front() { + if hdl.deadline() <= wheel.elapsed() { + unsafe { + wake_queue.push_front(hdl); + } + } else { + // Safety: the entry is not 
registered yet + unsafe { + wheel.insert(hdl, tx.clone()); + } + } + } +} + +pub(crate) fn insert_inject_timers( + wheel: &mut Wheel, + tx: &Sender, + inject: Vec, + wake_queue: &mut WakeQueue, +) { + for hdl in inject { + if hdl.deadline() <= wheel.elapsed() { + unsafe { + wake_queue.push_front(hdl); + } + } else { + // Safety: the entry is not registered yet + unsafe { + wheel.insert(hdl, tx.clone()); + } + } + } +} + +pub(crate) fn remove_cancelled_timers(wheel: &mut Wheel, rx: &mut Receiver) { + for hdl in rx.recv_all() { + debug_assert!(hdl.is_cancelled()); + + if hdl.deadline() > wheel.elapsed() { + // Safety: the entry is registered in THIS wheel + unsafe { + wheel.remove(hdl); + } + } + } +} + +pub(crate) fn next_expiration_time(wheel: &Wheel, drv_hdl: &driver::Handle) -> Option { + drv_hdl.with_time(|maybe_time_hdl| { + let Some(time_hdl) = maybe_time_hdl else { + // time driver is not enabled, nothing to do. + return None; + }; + + let clock = drv_hdl.clock(); + let time_source = time_hdl.time_source(); + + wheel.next_expiration_time().map(|tick| { + let now = time_source.now(clock); + time_source.tick_to_duration(tick.saturating_sub(now)) + }) + }) +} + +#[cfg(feature = "test-util")] +pub(crate) fn pre_auto_advance(drv_hdl: &driver::Handle, duration: Option) -> bool { + drv_hdl.with_time(|maybe_time_hdl| { + if maybe_time_hdl.is_none() { + // time driver is not enabled, nothing to do. + return false; + } + + if duration.is_some() { + let clock = drv_hdl.clock(); + if clock.can_auto_advance() { + return true; + } + + false + } else { + false + } + }) +} + +pub(crate) fn process_expired_timers( + wheel: &mut Wheel, + drv_hdl: &driver::Handle, + wake_queue: &mut WakeQueue, +) { + drv_hdl.with_time(|maybe_time_hdl| { + let Some(time_hdl) = maybe_time_hdl else { + // time driver is not enabled, nothing to do. 
+ return; + }; + + let clock = drv_hdl.clock(); + let time_source = time_hdl.time_source(); + + let now = time_source.now(clock); + time_hdl.process_at_time_alt(wheel, now, wake_queue); + }); +} + +pub(crate) fn shutdown_local_timers( + wheel: &mut Wheel, + rx: &mut Receiver, + inject: Vec, + drv_hdl: &driver::Handle, +) { + drv_hdl.with_time(|maybe_time_hdl| { + let Some(time_hdl) = maybe_time_hdl else { + // time driver is not enabled, nothing to do. + return; + }; + + remove_cancelled_timers(wheel, rx); + time_hdl.shutdown_alt(wheel); + + let mut wake_queue = WakeQueue::new(); + // simply wake all unregistered timers + for hdl in inject { + if !hdl.is_cancelled() { + unsafe { + wake_queue.push_front(hdl); + } + } + } + + wake_queue.wake_all(); + }); +} + +#[cfg(feature = "test-util")] +pub(crate) fn post_auto_advance(drv_hdl: &driver::Handle, duration: Option) { + drv_hdl.with_time(|maybe_time_hdl| { + let Some(time_hdl) = maybe_time_hdl else { + // time driver is not enabled, nothing to do. 
+ return; + }; + + if let Some(park_duration) = duration { + let clock = drv_hdl.clock(); + if clock.can_auto_advance() && !time_hdl.did_wake() { + if let Err(msg) = clock.advance(park_duration) { + panic!("{msg}"); + } + } + } + }) +} + +#[cfg(not(feature = "test-util"))] +pub(crate) fn pre_auto_advance(_drv_hdl: &driver::Handle, _duration: Option) -> bool { + false +} + +#[cfg(not(feature = "test-util"))] +pub(crate) fn post_auto_advance(_drv_hdl: &driver::Handle, _duration: Option) { + // No-op in non-test util builds +} diff --git a/tokio/src/runtime/time/handle.rs b/tokio/src/runtime/time/handle.rs index fce791d998c..33319031cc1 100644 --- a/tokio/src/runtime/time/handle.rs +++ b/tokio/src/runtime/time/handle.rs @@ -21,9 +21,15 @@ impl Handle { /// Track that the driver is being unparked pub(crate) fn unpark(&self) { #[cfg(feature = "test-util")] - self.inner - .did_wake - .store(true, std::sync::atomic::Ordering::SeqCst); + match self.inner { + super::Inner::Traditional { ref did_wake, .. } => { + did_wake.store(true, std::sync::atomic::Ordering::SeqCst); + } + #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] + super::Inner::Alternative { ref did_wake, .. } => { + did_wake.store(true, std::sync::atomic::Ordering::SeqCst); + } + } } } diff --git a/tokio/src/runtime/time/mod.rs b/tokio/src/runtime/time/mod.rs index 3250dce97f6..cecd5d0f25e 100644 --- a/tokio/src/runtime/time/mod.rs +++ b/tokio/src/runtime/time/mod.rs @@ -18,6 +18,9 @@ pub(crate) use source::TimeSource; mod wheel; +#[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] +use super::time_alt; + use crate::loom::sync::atomic::{AtomicBool, Ordering}; use crate::loom::sync::Mutex; use crate::runtime::driver::{self, IoHandle, IoStack}; @@ -89,22 +92,38 @@ pub(crate) struct Driver { park: IoStack, } -/// Timer state shared between `Driver`, `Handle`, and `Registration`. 
-struct Inner { - // The state is split like this so `Handle` can access `is_shutdown` without locking the mutex - state: Mutex, - - /// True if the driver is being shutdown. - is_shutdown: AtomicBool, - - // When `true`, a call to `park_timeout` should immediately return and time - // should not advance. One reason for this to be `true` is if the task - // passed to `Runtime::block_on` called `task::yield_now()`. - // - // While it may look racy, it only has any effect when the clock is paused - // and pausing the clock is restricted to a single-threaded runtime. - #[cfg(feature = "test-util")] - did_wake: AtomicBool, +enum Inner { + Traditional { + // The state is split like this so `Handle` can access `is_shutdown` without locking the mutex + state: Mutex, + + /// True if the driver is being shutdown. + is_shutdown: AtomicBool, + + // When `true`, a call to `park_timeout` should immediately return and time + // should not advance. One reason for this to be `true` is if the task + // passed to `Runtime::block_on` called `task::yield_now()`. + // + // While it may look racy, it only has any effect when the clock is paused + // and pausing the clock is restricted to a single-threaded runtime. + #[cfg(feature = "test-util")] + did_wake: AtomicBool, + }, + + #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] + Alternative { + /// True if the driver is being shutdown. + is_shutdown: AtomicBool, + + // When `true`, a call to `park_timeout` should immediately return and time + // should not advance. One reason for this to be `true` is if the task + // passed to `Runtime::block_on` called `task::yield_now()`. + // + // While it may look racy, it only has any effect when the clock is paused + // and pausing the clock is restricted to a single-threaded runtime. 
+ #[cfg(feature = "test-util")] + did_wake: AtomicBool, + }, } /// Time state shared which must be protected by a `Mutex` @@ -128,7 +147,7 @@ impl Driver { let handle = Handle { time_source, - inner: Inner { + inner: Inner::Traditional { state: Mutex::new(InnerState { next_wake: None, wheel: wheel::Wheel::new(), @@ -145,6 +164,20 @@ impl Driver { (driver, handle) } + #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] + pub(crate) fn new_alt(clock: &Clock) -> Handle { + let time_source = TimeSource::new(clock); + + Handle { + time_source, + inner: Inner::Alternative { + is_shutdown: AtomicBool::new(false), + #[cfg(feature = "test-util")] + did_wake: AtomicBool::new(false), + }, + } + } + pub(crate) fn park(&mut self, handle: &driver::Handle) { self.park_internal(handle, None); } @@ -160,7 +193,15 @@ impl Driver { return; } - handle.inner.is_shutdown.store(true, Ordering::SeqCst); + match &handle.inner { + Inner::Traditional { is_shutdown, .. } => { + is_shutdown.store(true, Ordering::SeqCst); + } + #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] + Inner::Alternative { is_shutdown, .. } => { + is_shutdown.store(true, Ordering::SeqCst); + } + } // Advance time forward to the end of time. @@ -295,6 +336,37 @@ impl Handle { waker_list.wake_all(); } + #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] + pub(crate) fn process_at_time_alt( + &self, + wheel: &mut time_alt::Wheel, + mut now: u64, + wake_queue: &mut time_alt::WakeQueue, + ) { + if now < wheel.elapsed() { + // Time went backwards! This normally shouldn't happen as the Rust language + // guarantees that an Instant is monotonic, but can happen when running + // Linux in a VM on a Windows host due to std incorrectly trusting the + // hardware clock to be monotonic. + // + // See for more information. 
+ now = wheel.elapsed(); + } + + wheel.take_expired(now, wake_queue); + } + + #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] + pub(crate) fn shutdown_alt(&self, wheel: &mut time_alt::Wheel) { + // self.is_shutdown.store(true, Ordering::SeqCst); + // Advance time forward to the end of time. + // This will ensure that all timers are fired. + let max_tick = u64::MAX; + let mut wake_queue = time_alt::WakeQueue::new(); + self.process_at_time_alt(wheel, max_tick, &mut wake_queue); + wake_queue.wake_all(); + } + /// Removes a registered timer from the driver. /// /// The timer will be moved to the cancelled state. Wakers will _not_ be @@ -379,8 +451,12 @@ impl Handle { } cfg_test_util! { - fn did_wake(&self) -> bool { - self.inner.did_wake.swap(false, Ordering::SeqCst) + pub(super) fn did_wake(&self) -> bool { + match &self.inner { + Inner::Traditional { did_wake, .. } => did_wake.swap(false, Ordering::SeqCst), + #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] + Inner::Alternative { did_wake, .. } => did_wake.swap(false, Ordering::SeqCst), + } } } } @@ -390,12 +466,20 @@ impl Handle { impl Inner { /// Locks the driver's inner structure pub(super) fn lock(&self) -> crate::loom::sync::MutexGuard<'_, InnerState> { - self.state.lock() + match self { + Inner::Traditional { state, .. } => state.lock(), + #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] + Inner::Alternative { .. } => unreachable!("unreachable in alternative timer"), + } } // Check whether the driver has been shutdown pub(super) fn is_shutdown(&self) -> bool { - self.is_shutdown.load(Ordering::SeqCst) + match self { + Inner::Traditional { is_shutdown, .. } => is_shutdown.load(Ordering::SeqCst), + #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] + Inner::Alternative { is_shutdown, .. 
} => is_shutdown.load(Ordering::SeqCst), + } } } diff --git a/tokio/src/runtime/time_alt/cancellation_queue.rs b/tokio/src/runtime/time_alt/cancellation_queue.rs new file mode 100644 index 00000000000..cfbd1ad2fa1 --- /dev/null +++ b/tokio/src/runtime/time_alt/cancellation_queue.rs @@ -0,0 +1,102 @@ +use super::{CancellationQueueEntry, Entry, EntryHandle}; +use crate::loom::sync::{Arc, Mutex}; +use crate::util::linked_list; + +type EntryList = linked_list::LinkedList; + +#[derive(Debug)] +struct Inner { + list: EntryList, +} + +impl Drop for Inner { + fn drop(&mut self) { + // consume all entries + let _ = self.iter().count(); + } +} + +impl Inner { + fn new() -> Self { + Self { + list: EntryList::new(), + } + } + + /// # Safety + /// + /// Behavior is undefined if any of the following conditions are violated: + /// + /// - `hdl` must not in any [`super::cancellation_queue`], and also mus not in any [`super::WakeQueue`]. + unsafe fn push_front(&mut self, hdl: EntryHandle) { + self.list.push_front(hdl); + } + + fn iter(&mut self) -> impl Iterator { + struct Iter { + list: EntryList, + } + + impl Drop for Iter { + fn drop(&mut self) { + while let Some(hdl) = self.list.pop_front() { + drop(hdl); + } + } + } + + impl Iterator for Iter { + type Item = EntryHandle; + + fn next(&mut self) -> Option { + self.list.pop_front() + } + } + + Iter { + list: std::mem::take(&mut self.list), + } + } +} + +#[derive(Debug, Clone)] +pub(crate) struct Sender { + inner: Arc>, +} + +impl Sender { + /// # Safety + /// + /// Behavior is undefined if any of the following conditions are violated: + /// + /// - `hdl` must not in any cancellation queue. 
+ pub(crate) unsafe fn send(&self, hdl: EntryHandle) { + unsafe { + self.inner.lock().push_front(hdl); + } + } +} + +#[derive(Debug)] +pub(crate) struct Receiver { + inner: Arc>, +} + +impl Receiver { + pub(crate) fn recv_all(&mut self) -> impl Iterator { + self.inner.lock().iter() + } +} + +pub(crate) fn new() -> (Sender, Receiver) { + let inner = Arc::new(Mutex::new(Inner::new())); + ( + Sender { + inner: inner.clone(), + }, + Receiver { inner }, + ) +} + +#[cfg(test)] +mod tests; diff --git a/tokio/src/runtime/time_alt/cancellation_queue/tests.rs b/tokio/src/runtime/time_alt/cancellation_queue/tests.rs new file mode 100644 index 00000000000..b20e316ac9d --- /dev/null +++ b/tokio/src/runtime/time_alt/cancellation_queue/tests.rs @@ -0,0 +1,97 @@ +use super::*; + +use futures::task::noop_waker; + +#[cfg(loom)] +const NUM_ITEMS: usize = 16; + +#[cfg(not(loom))] +const NUM_ITEMS: usize = 64; + +fn new_handle() -> EntryHandle { + EntryHandle::new(0, noop_waker()) +} + +fn model(f: F) { + #[cfg(loom)] + loom::model(f); + + #[cfg(not(loom))] + f(); +} + +#[test] +fn single_thread() { + model(|| { + for i in 0..NUM_ITEMS { + let (tx, mut rx) = new(); + + for _ in 0..i { + unsafe { tx.send(new_handle()) }; + } + + assert_eq!(rx.recv_all().count(), i); + } + }); +} + +#[test] +#[cfg(not(target_os = "wasi"))] // No thread on wasi. 
+fn multi_thread() { + use crate::loom::sync::atomic::{AtomicUsize, Ordering::SeqCst}; + use crate::loom::sync::Arc; + use crate::loom::thread; + + #[cfg(loom)] + const NUM_THREADS: usize = 3; + #[cfg(not(loom))] + const NUM_THREADS: usize = 8; + + model(|| { + let (tx, mut rx) = new(); + let mut jhs = Vec::new(); + let sent = Arc::new(AtomicUsize::new(0)); + + for _ in 0..NUM_THREADS { + let tx = tx.clone(); + let sent = sent.clone(); + jhs.push(thread::spawn(move || { + for _ in 0..NUM_ITEMS { + unsafe { tx.send(new_handle()) }; + sent.fetch_add(1, SeqCst); + } + })); + } + + let mut count = 0; + loop { + count += rx.recv_all().count(); + if sent.fetch_add(0, SeqCst) == NUM_ITEMS * NUM_THREADS { + jhs.into_iter().for_each(|jh| { + jh.join().unwrap(); + }); + count += rx.recv_all().count(); + break; + } + thread::yield_now(); + } + + assert_eq!(count, NUM_ITEMS * NUM_THREADS); + }) +} + +#[test] +fn drop_iter_should_not_leak_memory() { + model(|| { + let (tx, mut rx) = new(); + + let hdls = (0..NUM_ITEMS).map(|_| new_handle()).collect::>(); + for hdl in hdls.iter() { + unsafe { tx.send(hdl.clone()) }; + } + + drop(rx.recv_all()); + + assert!(hdls.into_iter().all(|hdl| hdl.inner_strong_count() == 1)); + }); +} diff --git a/tokio/src/runtime/time_alt/context.rs b/tokio/src/runtime/time_alt/context.rs new file mode 100644 index 00000000000..76035634a26 --- /dev/null +++ b/tokio/src/runtime/time_alt/context.rs @@ -0,0 +1,47 @@ +use super::{cancellation_queue, RegistrationQueue, Wheel}; + +/// Local context for the time driver, used when the runtime wants to +/// fire/cancel timers. 
+pub(crate) struct LocalContext { + pub(crate) wheel: Wheel, + pub(crate) registration_queue: RegistrationQueue, + pub(crate) canc_tx: cancellation_queue::Sender, + pub(crate) canc_rx: cancellation_queue::Receiver, +} + +impl LocalContext { + pub(crate) fn new() -> Self { + let (canc_tx, canc_rx) = cancellation_queue::new(); + Self { + wheel: Wheel::new(), + registration_queue: RegistrationQueue::new(), + canc_tx, + canc_rx, + } + } +} + +pub(crate) enum TempLocalContext<'a> { + /// The runtime is running, we can access it. + Running { + registration_queue: &'a mut RegistrationQueue, + elapsed: u64, + }, + #[cfg(feature = "rt-multi-thread")] + /// The runtime is shutting down, no timers can be registered. + Shutdown, +} + +impl<'a> TempLocalContext<'a> { + pub(crate) fn new_running(cx: &'a mut LocalContext) -> Self { + TempLocalContext::Running { + registration_queue: &mut cx.registration_queue, + elapsed: cx.wheel.elapsed(), + } + } + + #[cfg(feature = "rt-multi-thread")] + pub(crate) fn new_shutdown() -> Self { + TempLocalContext::Shutdown + } +} diff --git a/tokio/src/runtime/time_alt/entry.rs b/tokio/src/runtime/time_alt/entry.rs new file mode 100644 index 00000000000..b7b5627e0b1 --- /dev/null +++ b/tokio/src/runtime/time_alt/entry.rs @@ -0,0 +1,276 @@ +use super::cancellation_queue::Sender; +use crate::loom::sync::{Arc, Mutex}; +use crate::util::linked_list; + +use std::marker::PhantomPinned; +use std::ptr::NonNull; +use std::task::Waker; + +pub(super) type EntryList = linked_list::LinkedList; + +#[derive(Debug)] +struct State { + cancelled: bool, + woken_up: bool, + waker: Option, + cancel_tx: Option, +} + +#[derive(Debug)] +pub(crate) struct Entry { + /// The intrusive pointer used by [`super::cancellation_queue`]. 
+    cancel_pointers: linked_list::Pointers<Entry>,
+
+    /// The intrusive pointer used by any of the following queues:
+    ///
+    /// - [`Wheel`]
+    /// - [`RegistrationQueue`]
+    /// - [`WakeQueue`]
+    ///
+    /// We can guarantee that this pointer is only used by one of the above
+    /// at any given time. See below for the journey of this pointer.
+    ///
+    /// Initially, this pointer is used by the [`RegistrationQueue`].
+    ///
+    /// And then, before parking the resource driver,
+    /// the scheduler removes the entry from the [`RegistrationQueue`]
+    /// and inserts it into the [`Wheel`].
+    ///
+    /// Finally, after parking the resource driver, the scheduler removes
+    /// the entry from the [`Wheel`] and inserts it into the [`WakeQueue`].
+    ///
+    /// [`RegistrationQueue`]: super::RegistrationQueue
+    /// [`Wheel`]: super::Wheel
+    /// [`WakeQueue`]: super::WakeQueue
+    extra_pointers: linked_list::Pointers<Entry>,
+
+    /// The tick when this entry is scheduled to expire.
+    deadline: u64,
+
+    state: Mutex<State>,
+
+    /// Make the type `!Unpin` to prevent LLVM from emitting
+    /// the `noalias` attribute for mutable references.
+    ///
+    /// See <https://github.com/rust-lang/rust/issues/63818>.
+    _pin: PhantomPinned,
+}
+
+// Safety: `Entry` is always in an `Arc`.
+unsafe impl linked_list::Link for Entry {
+    type Handle = Handle;
+    type Target = Entry;
+
+    fn as_raw(hdl: &Self::Handle) -> NonNull<Self::Target> {
+        unsafe { NonNull::new_unchecked(Arc::as_ptr(&hdl.entry).cast_mut()) }
+    }
+
+    unsafe fn from_raw(ptr: NonNull<Self::Target>) -> Self::Handle {
+        Handle {
+            entry: unsafe { Arc::from_raw(ptr.as_ptr()) },
+        }
+    }
+
+    unsafe fn pointers(
+        target: NonNull<Self::Target>,
+    ) -> NonNull<linked_list::Pointers<Self::Target>> {
+        let this = target.as_ptr();
+        let field = unsafe { std::ptr::addr_of_mut!((*this).extra_pointers) };
+        unsafe { NonNull::new_unchecked(field) }
+    }
+}
+
+/// A ZST to allow [`super::registration_queue`] to utilize the [`Entry::extra_pointers`]
+/// by impl [`linked_list::Link`] as we cannot impl it on [`Entry`]
+/// directly due to the conflicting implementations.
+/// +/// This type should never be constructed. +pub(super) struct RegistrationQueueEntry; + +// Safety: `Entry` is always in an `Arc`. +unsafe impl linked_list::Link for RegistrationQueueEntry { + type Handle = Handle; + type Target = Entry; + + fn as_raw(hdl: &Self::Handle) -> NonNull { + unsafe { NonNull::new_unchecked(Arc::as_ptr(&hdl.entry).cast_mut()) } + } + + unsafe fn from_raw(ptr: NonNull) -> Self::Handle { + Handle { + entry: unsafe { Arc::from_raw(ptr.as_ptr()) }, + } + } + + unsafe fn pointers( + target: NonNull, + ) -> NonNull> { + let this = target.as_ptr(); + let field = unsafe { std::ptr::addr_of_mut!((*this).extra_pointers) }; + unsafe { NonNull::new_unchecked(field) } + } +} + +/// An ZST to allow [`super::cancellation_queue`] to utilize the [`Entry::cancel_pointers`] +/// by impl [`linked_list::Link`] as we cannot impl it on [`Entry`] +/// directly due to the conflicting implementations. +/// +/// This type should never be constructed. +pub(super) struct CancellationQueueEntry; + +// Safety: `Entry` is always in an `Arc`. +unsafe impl linked_list::Link for CancellationQueueEntry { + type Handle = Handle; + type Target = Entry; + + fn as_raw(hdl: &Self::Handle) -> NonNull { + unsafe { NonNull::new_unchecked(Arc::as_ptr(&hdl.entry).cast_mut()) } + } + + unsafe fn from_raw(ptr: NonNull) -> Self::Handle { + Handle { + entry: unsafe { Arc::from_raw(ptr.as_ptr()) }, + } + } + + unsafe fn pointers( + target: NonNull, + ) -> NonNull> { + let this = target.as_ptr(); + let field = unsafe { std::ptr::addr_of_mut!((*this).cancel_pointers) }; + unsafe { NonNull::new_unchecked(field) } + } +} + +/// An ZST to allow [`super::WakeQueue`] to utilize the [`Entry::extra_pointers`] +/// by impl [`linked_list::Link`] as we cannot impl it on [`Entry`] +/// directly due to the conflicting implementations. +/// +/// This type should never be constructed. +pub(super) struct WakeQueueEntry; + +// Safety: `Entry` is always in an `Arc`. 
+unsafe impl linked_list::Link for WakeQueueEntry {
+    type Handle = Handle;
+    type Target = Entry;
+
+    fn as_raw(hdl: &Self::Handle) -> NonNull<Self::Target> {
+        unsafe { NonNull::new_unchecked(Arc::as_ptr(&hdl.entry).cast_mut()) }
+    }
+
+    unsafe fn from_raw(ptr: NonNull<Self::Target>) -> Self::Handle {
+        Handle {
+            entry: unsafe { Arc::from_raw(ptr.as_ptr()) },
+        }
+    }
+
+    unsafe fn pointers(
+        target: NonNull<Self::Target>,
+    ) -> NonNull<linked_list::Pointers<Self::Target>> {
+        let this = target.as_ptr();
+        let field = unsafe { std::ptr::addr_of_mut!((*this).extra_pointers) };
+        unsafe { NonNull::new_unchecked(field) }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub(crate) struct Handle {
+    pub(crate) entry: Arc<Entry>,
+}
+
+impl From<&Handle> for NonNull<Entry> {
+    fn from(hdl: &Handle) -> Self {
+        // Safety: entry is in an `Arc`, so the pointer is valid.
+        unsafe { NonNull::new_unchecked(Arc::as_ptr(&hdl.entry) as *mut Entry) }
+    }
+}
+
+impl Handle {
+    pub(crate) fn new(deadline: u64, waker: Waker) -> Self {
+        let state = State {
+            cancelled: false,
+            woken_up: false,
+            waker: Some(waker),
+            cancel_tx: None,
+        };
+
+        let entry = Arc::new(Entry {
+            cancel_pointers: linked_list::Pointers::new(),
+            extra_pointers: linked_list::Pointers::new(),
+            deadline,
+            state: Mutex::new(state),
+            _pin: PhantomPinned,
+        });
+
+        Handle { entry }
+    }
+
+    /// Wake the entry if it is already in the pending queue of the timer wheel.
+    pub(crate) fn wake(&self) {
+        let mut lock = self.entry.state.lock();
+
+        if !lock.cancelled {
+            lock.woken_up = true;
+            if let Some(waker) = lock.waker.take() {
+                // unlock before calling waker
+                drop(lock);
+                waker.wake();
+            }
+        }
+    }
+
+    pub(crate) fn register_cancel_tx(&self, cancel_tx: Sender) {
+        let mut lock = self.entry.state.lock();
+        if !lock.cancelled && !lock.woken_up {
+            let old_tx = lock.cancel_tx.replace(cancel_tx);
+            // `cancel_tx` may be registered at most once per entry; a second
+            // registration indicates a bug, so fail loudly here.
+ assert!(old_tx.is_none(), "cancel_tx is already registered"); + } + } + + pub(crate) fn register_waker(&self, waker: Waker) { + let mut lock = self.entry.state.lock(); + if !lock.cancelled && !lock.woken_up { + let maybe_old_waker = lock.waker.replace(waker); + // unlock before calling waker + drop(lock); + drop(maybe_old_waker); + } + } + + pub(crate) fn cancel(&self) { + let mut lock = self.entry.state.lock(); + if !lock.cancelled { + lock.cancelled = true; + if let Some(cancel_tx) = lock.cancel_tx.take() { + drop(lock); + + // Safety: we can guarantee that `self` is not in any cancellation queue + // because the `self.cancelled` was just set to `true`. + unsafe { + cancel_tx.send(self.clone()); + } + } + } + } + + pub(crate) fn deadline(&self) -> u64 { + self.entry.deadline + } + + pub(crate) fn is_woken_up(&self) -> bool { + let lock = self.entry.state.lock(); + lock.woken_up + } + + pub(crate) fn is_cancelled(&self) -> bool { + let lock = self.entry.state.lock(); + lock.cancelled + } + + #[cfg(test)] + /// Only used for unit tests. 
+    pub(crate) fn inner_strong_count(&self) -> usize {
+        Arc::strong_count(&self.entry)
+    }
+}
diff --git a/tokio/src/runtime/time_alt/mod.rs b/tokio/src/runtime/time_alt/mod.rs
new file mode 100644
index 00000000000..5d528461ced
--- /dev/null
+++ b/tokio/src/runtime/time_alt/mod.rs
@@ -0,0 +1,24 @@
+pub(crate) mod context;
+pub(super) use context::{LocalContext, TempLocalContext};
+
+pub(crate) mod cancellation_queue;
+
+mod entry;
+pub(crate) use entry::Handle as EntryHandle;
+use entry::{CancellationQueueEntry, RegistrationQueueEntry, WakeQueueEntry};
+use entry::{Entry, EntryList};
+
+mod registration_queue;
+pub(crate) use registration_queue::RegistrationQueue;
+
+mod timer;
+pub(crate) use timer::Timer;
+
+mod wheel;
+pub(super) use wheel::Wheel;
+
+mod wake_queue;
+pub(crate) use wake_queue::WakeQueue;
+
+#[cfg(test)]
+mod tests;
diff --git a/tokio/src/runtime/time_alt/registration_queue.rs b/tokio/src/runtime/time_alt/registration_queue.rs
new file mode 100644
index 00000000000..d135e5b213b
--- /dev/null
+++ b/tokio/src/runtime/time_alt/registration_queue.rs
@@ -0,0 +1,43 @@
+use super::{Entry, EntryHandle, RegistrationQueueEntry};
+use crate::util::linked_list;
+
+type EntryList = linked_list::LinkedList<RegistrationQueueEntry, Entry>;
+
+/// A queue of entries that need to be registered in the timer wheel.
+#[derive(Debug)]
+pub(crate) struct RegistrationQueue {
+    list: EntryList,
+}
+
+impl Drop for RegistrationQueue {
+    fn drop(&mut self) {
+        // drain all entries without waking them up
+        while let Some(hdl) = self.list.pop_front() {
+            drop(hdl);
+        }
+    }
+}
+
+impl RegistrationQueue {
+    pub(crate) fn new() -> Self {
+        Self {
+            list: EntryList::new(),
+        }
+    }
+
+    /// # Safety
+    ///
+    /// Behavior is undefined if any of the following conditions are violated:
+    ///
+    /// - [`Entry::extra_pointers`] of `hdl` must not be in use.
+ pub(crate) unsafe fn push_front(&mut self, hdl: EntryHandle) { + self.list.push_front(hdl); + } + + pub(crate) fn pop_front(&mut self) -> Option { + self.list.pop_front() + } +} + +#[cfg(test)] +mod tests; diff --git a/tokio/src/runtime/time_alt/registration_queue/tests.rs b/tokio/src/runtime/time_alt/registration_queue/tests.rs new file mode 100644 index 00000000000..b6b3699fa3d --- /dev/null +++ b/tokio/src/runtime/time_alt/registration_queue/tests.rs @@ -0,0 +1,53 @@ +use super::*; + +use futures::task::noop_waker; + +#[cfg(loom)] +const NUM_ITEMS: usize = 16; + +#[cfg(not(loom))] +const NUM_ITEMS: usize = 64; + +fn new_handle() -> EntryHandle { + EntryHandle::new(0, noop_waker()) +} + +fn model(f: F) { + #[cfg(loom)] + loom::model(f); + + #[cfg(not(loom))] + f(); +} + +#[test] +fn sanity() { + model(|| { + let mut queue = RegistrationQueue::new(); + for _ in 0..NUM_ITEMS { + unsafe { + queue.push_front(new_handle()); + } + } + for _ in 0..NUM_ITEMS { + assert!(queue.pop_front().is_some()); + } + assert!(queue.pop_front().is_none()); + }); +} + +#[test] +fn drop_should_not_leak_memory() { + model(|| { + let mut queue = RegistrationQueue::new(); + + let hdls = (0..NUM_ITEMS).map(|_| new_handle()).collect::>(); + for hdl in hdls.iter() { + unsafe { queue.push_front(hdl.clone()) }; + } + + drop(queue); + + assert!(hdls.into_iter().all(|hdl| hdl.inner_strong_count() == 1)); + }); +} diff --git a/tokio/src/runtime/time_alt/tests.rs b/tokio/src/runtime/time_alt/tests.rs new file mode 100644 index 00000000000..29015f3bde9 --- /dev/null +++ b/tokio/src/runtime/time_alt/tests.rs @@ -0,0 +1,168 @@ +use super::*; +use crate::loom::thread; + +use futures_test::task::{new_count_waker, AwokenCount}; + +#[cfg(loom)] +const NUM_ITEMS: usize = 16; + +#[cfg(not(loom))] +const NUM_ITEMS: usize = 64; + +fn new_handle() -> (EntryHandle, AwokenCount) { + let (waker, count) = new_count_waker(); + (EntryHandle::new(0, waker), count) +} + +fn model(f: F) { + #[cfg(loom)] + 
loom::model(f); + + #[cfg(not(loom))] + f(); +} + +#[test] +fn wake_up_in_the_same_thread() { + model(|| { + let mut counts = Vec::new(); + + let mut reg_queue = RegistrationQueue::new(); + + for _ in 0..NUM_ITEMS { + let (hdl, count) = new_handle(); + counts.push(count); + unsafe { + reg_queue.push_front(hdl); + } + } + + let mut wake_queue = WakeQueue::new(); + for _ in 0..NUM_ITEMS { + if let Some(hdl) = reg_queue.pop_front() { + unsafe { + wake_queue.push_front(hdl); + } + } + } + assert!(reg_queue.pop_front().is_none()); + wake_queue.wake_all(); + + assert!(counts.into_iter().all(|c| c.get() == 1)); + }); +} + +#[test] +fn cancel_in_the_same_thread() { + model(|| { + let mut counts = Vec::new(); + let (cancel_tx, mut cancel_rx) = cancellation_queue::new(); + + let mut reg_queue = RegistrationQueue::new(); + + for _ in 0..NUM_ITEMS { + let (hdl, count) = new_handle(); + hdl.register_cancel_tx(cancel_tx.clone()); + counts.push(count); + unsafe { + reg_queue.push_front(hdl.clone()); + } + hdl.cancel(); + } + + // drain the registration queue + while let Some(hdl) = reg_queue.pop_front() { + drop(hdl); + } + + let mut wake_queue = WakeQueue::new(); + for hdl in cancel_rx.recv_all() { + unsafe { + wake_queue.push_front(hdl); + } + } + wake_queue.wake_all(); + + assert!(counts.into_iter().all(|c| c.get() == 0)); + }); +} + +#[test] +fn wake_up_in_the_different_thread() { + model(|| { + let mut counts = Vec::new(); + + let mut hdls = Vec::new(); + let mut reg_queue = RegistrationQueue::new(); + + for _ in 0..NUM_ITEMS { + let (hdl, count) = new_handle(); + counts.push(count); + hdls.push(hdl.clone()); + unsafe { + reg_queue.push_front(hdl); + } + } + + // wake up all handles in a different thread + thread::spawn(move || { + let mut wake_queue = WakeQueue::new(); + for _ in 0..NUM_ITEMS { + if let Some(hdl) = reg_queue.pop_front() { + unsafe { + wake_queue.push_front(hdl); + } + } + } + assert!(reg_queue.pop_front().is_none()); + wake_queue.wake_all(); + 
assert!(counts.into_iter().all(|c| c.get() == 1)); + }) + .join() + .unwrap(); + }); +} + +#[test] +fn cancel_in_the_different_thread() { + model(|| { + let mut counts = Vec::new(); + let (cancel_tx, mut cancel_rx) = cancellation_queue::new(); + let mut hdls = Vec::new(); + let mut reg_queue = RegistrationQueue::new(); + + for _ in 0..NUM_ITEMS { + let (hdl, count) = new_handle(); + hdl.register_cancel_tx(cancel_tx.clone()); + counts.push(count); + hdls.push(hdl.clone()); + unsafe { + reg_queue.push_front(hdl); + } + } + + // this thread cancel all handles concurrently + let jh = thread::spawn(move || { + // cancel all handles + for hdl in hdls { + hdl.cancel(); + } + }); + + // cancellation queue concurrently + while let Some(hdl) = reg_queue.pop_front() { + drop(hdl); + } + + let mut wake_queue = WakeQueue::new(); + for hdl in cancel_rx.recv_all() { + unsafe { + wake_queue.push_front(hdl); + } + } + wake_queue.wake_all(); + assert!(counts.into_iter().all(|c| c.get() == 0)); + + jh.join().unwrap(); + }) +} diff --git a/tokio/src/runtime/time_alt/timer.rs b/tokio/src/runtime/time_alt/timer.rs new file mode 100644 index 00000000000..178ab81f24e --- /dev/null +++ b/tokio/src/runtime/time_alt/timer.rs @@ -0,0 +1,169 @@ +use super::{EntryHandle, TempLocalContext}; +use crate::runtime::scheduler::Handle as SchedulerHandle; +use crate::time::Instant; + +use std::pin::Pin; +use std::task::{Context, Poll}; + +#[cfg(any(feature = "rt", feature = "rt-multi-thread"))] +use crate::util::error::RUNTIME_SHUTTING_DOWN_ERROR; + +pub(crate) struct Timer { + sched_handle: SchedulerHandle, + + /// The entry in the timing wheel. + /// + /// - `Some` if the timer is registered / pending / woken up / cancelling. + /// - `None` if the timer is unregistered. + entry: Option, + + /// The deadline for the timer. 
+ deadline: Instant, +} + +impl std::fmt::Debug for Timer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Timer") + .field("deadline", &self.deadline) + .finish() + } +} + +impl Drop for Timer { + fn drop(&mut self) { + if let Some(entry) = self.entry.take() { + entry.cancel(); + } + } +} + +impl Timer { + #[track_caller] + pub(crate) fn new(sched_hdl: SchedulerHandle, deadline: Instant) -> Self { + // Panic if the time driver is not enabled + let _ = sched_hdl.driver().time(); + Timer { + sched_handle: sched_hdl, + entry: None, + deadline, + } + } + + pub(crate) fn deadline(&self) -> Instant { + self.deadline + } + + pub(crate) fn is_elapsed(&self) -> bool { + self.entry.as_ref().is_some_and(|entry| entry.is_woken_up()) + } + + fn register(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + let this = self.get_mut(); + + with_current_temp_local_context(&this.sched_handle, |maybe_time_cx| { + let deadline = deadline_to_tick(&this.sched_handle, this.deadline); + + match maybe_time_cx { + Some(TempLocalContext::Running { + registration_queue: _, + elapsed, + }) if deadline <= elapsed => Poll::Ready(()), + + Some(TempLocalContext::Running { + registration_queue, + elapsed: _, + }) => { + let hdl = EntryHandle::new(deadline, cx.waker().clone()); + this.entry = Some(hdl.clone()); + unsafe { + registration_queue.push_front(hdl); + } + Poll::Pending + } + #[cfg(feature = "rt-multi-thread")] + Some(TempLocalContext::Shutdown) => panic!("{RUNTIME_SHUTTING_DOWN_ERROR}"), + + _ => { + let hdl = EntryHandle::new(deadline, cx.waker().clone()); + this.entry = Some(hdl.clone()); + push_from_remote(&this.sched_handle, hdl); + Poll::Pending + } + } + }) + } + + pub(crate) fn poll_elapsed(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + match self.entry.as_ref() { + Some(entry) if entry.is_woken_up() => Poll::Ready(()), + Some(entry) => { + entry.register_waker(cx.waker().clone()); + Poll::Pending + } + None => 
self.register(cx), + } + } + + pub(crate) fn scheduler_handle(&self) -> &SchedulerHandle { + &self.sched_handle + } + + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(crate) fn driver(&self) -> &crate::runtime::time::Handle { + self.sched_handle.driver().time() + } + + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(crate) fn clock(&self) -> &crate::time::Clock { + self.sched_handle.driver().clock() + } +} + +fn with_current_temp_local_context(hdl: &SchedulerHandle, f: F) -> R +where + F: FnOnce(Option>) -> R, +{ + #[cfg(not(feature = "rt"))] + { + let (_, _) = (hdl, f); + panic!("Tokio runtime is not enabled, cannot access the current wheel"); + } + + #[cfg(feature = "rt")] + { + use crate::runtime::context; + + let is_same_rt = + context::with_current(|cur_hdl| cur_hdl.is_same_runtime(hdl)).unwrap_or_default(); + + if !is_same_rt { + // We don't want to create the timer in one runtime, + // but register it in a different runtime's timer wheel. + f(None) + } else { + context::with_scheduler(|maybe_cx| match maybe_cx { + Some(cx) => cx.with_time_temp_local_context(f), + None => f(None), + }) + } + } +} + +fn push_from_remote(sched_hdl: &SchedulerHandle, entry_hdl: EntryHandle) { + #[cfg(not(feature = "rt"))] + { + let (_, _) = (sched_hdl, entry_hdl); + panic!("Tokio runtime is not enabled, cannot access the current wheel"); + } + + #[cfg(feature = "rt")] + { + assert!(!sched_hdl.is_shutdown(), "{RUNTIME_SHUTTING_DOWN_ERROR}"); + sched_hdl.push_remote_timer(entry_hdl); + } +} + +fn deadline_to_tick(sched_hdl: &SchedulerHandle, deadline: Instant) -> u64 { + let time_hdl = sched_hdl.driver().time(); + time_hdl.time_source().deadline_to_tick(deadline) +} diff --git a/tokio/src/runtime/time_alt/wake_queue.rs b/tokio/src/runtime/time_alt/wake_queue.rs new file mode 100644 index 00000000000..90ab9f6d287 --- /dev/null +++ b/tokio/src/runtime/time_alt/wake_queue.rs @@ -0,0 +1,50 @@ +use super::{Entry, EntryHandle, WakeQueueEntry}; +use 
crate::util::linked_list; + +type EntryList = linked_list::LinkedList; + +/// A queue of entries that need to be woken up. +#[derive(Debug)] +pub(crate) struct WakeQueue { + list: EntryList, +} + +impl Drop for WakeQueue { + fn drop(&mut self) { + // drain all entries without waking them up + while let Some(hdl) = self.list.pop_front() { + drop(hdl); + } + } +} + +impl WakeQueue { + pub(crate) fn new() -> Self { + Self { + list: EntryList::new(), + } + } + + pub(crate) fn is_empty(&self) -> bool { + self.list.is_empty() + } + + /// # Safety + /// + /// Behavior is undefined if any of the following conditions are violated: + /// + /// - [`Entry::extra_pointers`] of `hdl` must not being used. + pub(crate) unsafe fn push_front(&mut self, hdl: EntryHandle) { + self.list.push_front(hdl); + } + + /// Wakes all entries in the wake queue. + pub(crate) fn wake_all(mut self) { + while let Some(hdl) = self.list.pop_front() { + hdl.wake(); + } + } +} + +#[cfg(test)] +mod tests; diff --git a/tokio/src/runtime/time_alt/wake_queue/tests.rs b/tokio/src/runtime/time_alt/wake_queue/tests.rs new file mode 100644 index 00000000000..f0449ee912b --- /dev/null +++ b/tokio/src/runtime/time_alt/wake_queue/tests.rs @@ -0,0 +1,66 @@ +use super::*; + +use futures_test::task::{new_count_waker, AwokenCount}; + +#[cfg(loom)] +const NUM_ITEMS: usize = 16; + +#[cfg(not(loom))] +const NUM_ITEMS: usize = 64; + +fn new_handle() -> (EntryHandle, AwokenCount) { + let (waker, count) = new_count_waker(); + (EntryHandle::new(0, waker), count) +} + +fn model(f: F) { + #[cfg(loom)] + loom::model(f); + + #[cfg(not(loom))] + f(); +} + +#[test] +fn sanity() { + model(|| { + let mut queue = WakeQueue::new(); + let mut counts = Vec::new(); + + for _ in 0..NUM_ITEMS { + let (hdl, count) = new_handle(); + counts.push(count); + unsafe { + queue.push_front(hdl); + } + } + assert!(!queue.is_empty()); + queue.wake_all(); + assert!(counts.into_iter().all(|c| c.get() == 1)); + }); +} + +#[test] +fn 
drop_should_not_leak_memory() { + model(|| { + let mut queue = WakeQueue::new(); + + let mut hdls = vec![]; + let mut counts = vec![]; + for _ in 0..NUM_ITEMS { + let (hdl, count) = new_handle(); + hdls.push(hdl); + counts.push(count); + } + + for hdl in hdls.iter() { + unsafe { queue.push_front(hdl.clone()) }; + } + + drop(queue); + + assert!(hdls.into_iter().all(|hdl| hdl.inner_strong_count() == 1)); + // drop should not wake any entries + assert!(counts.into_iter().all(|count| count.get() == 0)); + }); +} diff --git a/tokio/src/runtime/time_alt/wheel/level.rs b/tokio/src/runtime/time_alt/wheel/level.rs new file mode 100644 index 00000000000..99309bfe0fb --- /dev/null +++ b/tokio/src/runtime/time_alt/wheel/level.rs @@ -0,0 +1,194 @@ +use super::{EntryHandle, EntryList}; +use std::ptr::NonNull; +use std::{array, fmt}; + +/// Wheel for a single level in the timer. This wheel contains 64 slots. +pub(crate) struct Level { + level: usize, + + /// Bit field tracking which slots currently contain entries. + /// + /// Using a bit field to track slots that contain entries allows avoiding a + /// scan to find entries. This field is updated when entries are added or + /// removed from a slot. + /// + /// The least-significant bit represents slot zero. + occupied: u64, + + /// Slots. We access these via the EntryInner `current_list` as well, so this needs to be an `UnsafeCell`. + slot: [EntryList; LEVEL_MULT], +} + +/// Indicates when a slot must be processed next. +#[derive(Debug)] +pub(crate) struct Expiration { + /// The level containing the slot. + pub(crate) level: usize, + + /// The slot index. + pub(crate) slot: usize, + + /// The instant at which the slot needs to be processed. + pub(crate) deadline: u64, +} + +/// Level multiplier. +/// +/// Being a power of 2 is very important. 
+const LEVEL_MULT: usize = 64; + +impl Level { + pub(crate) fn new(level: usize) -> Level { + Level { + level, + occupied: 0, + slot: array::from_fn(|_| EntryList::default()), + } + } + + /// Finds the slot that needs to be processed next and returns the slot and + /// `Instant` at which this slot must be processed. + pub(crate) fn next_expiration(&self, now: u64) -> Option { + // Use the `occupied` bit field to get the index of the next slot that + // needs to be processed. + let slot = self.next_occupied_slot(now)?; + + // From the slot index, calculate the `Instant` at which it needs to be + // processed. This value *must* be in the future with respect to `now`. + + let level_range = level_range(self.level); + let slot_range = slot_range(self.level); + + // Compute the start date of the current level by masking the low bits + // of `now` (`level_range` is a power of 2). + let level_start = now & !(level_range - 1); + let mut deadline = level_start + slot as u64 * slot_range; + + if deadline <= now { + // A timer is in a slot "prior" to the current time. This can occur + // because we do not have an infinite hierarchy of timer levels, and + // eventually a timer scheduled for a very distant time might end up + // being placed in a slot that is beyond the end of all of the + // arrays. + // + // To deal with this, we first limit timers to being scheduled no + // more than MAX_DURATION ticks in the future; that is, they're at + // most one rotation of the top level away. Then, we force timers + // that logically would go into the top+1 level, to instead go into + // the top level's slots. + // + // What this means is that the top level's slots act as a + // pseudo-ring buffer, and we rotate around them indefinitely. If we + // compute a deadline before now, and it's the top level, it + // therefore means we're actually looking at a slot in the future. 
+ debug_assert_eq!(self.level, super::NUM_LEVELS - 1); + + deadline += level_range; + } + + debug_assert!( + deadline >= now, + "deadline={:016X}; now={:016X}; level={}; lr={:016X}, sr={:016X}, slot={}; occupied={:b}", + deadline, + now, + self.level, + level_range, + slot_range, + slot, + self.occupied + ); + + Some(Expiration { + level: self.level, + slot, + deadline, + }) + } + + fn next_occupied_slot(&self, now: u64) -> Option { + if self.occupied == 0 { + return None; + } + + // Get the slot for now using Maths + let now_slot = (now / slot_range(self.level)) as usize; + let occupied = self.occupied.rotate_right(now_slot as u32); + let zeros = occupied.trailing_zeros() as usize; + let slot = (zeros + now_slot) % LEVEL_MULT; + + Some(slot) + } + + pub(crate) unsafe fn add_entry(&mut self, hdl: EntryHandle) { + // Safety: the associated entry must be valid. + let deadline = hdl.deadline(); + let slot = slot_for(deadline, self.level); + + self.slot[slot].push_front(hdl); + + self.occupied |= occupied_bit(slot); + } + + pub(crate) unsafe fn remove_entry(&mut self, hdl: EntryHandle) { + let slot = slot_for(hdl.deadline(), self.level); + + unsafe { self.slot[slot].remove(NonNull::from(&hdl)) }; + if self.slot[slot].is_empty() { + // The bit is currently set + debug_assert!(self.occupied & occupied_bit(slot) != 0); + + // Unset the bit + self.occupied ^= occupied_bit(slot); + } + } + + pub(crate) fn take_slot(&mut self, slot: usize) -> EntryList { + self.occupied &= !occupied_bit(slot); + + std::mem::take(&mut self.slot[slot]) + } +} + +impl fmt::Debug for Level { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Level") + .field("occupied", &self.occupied) + .finish() + } +} + +fn occupied_bit(slot: usize) -> u64 { + 1 << slot +} + +fn slot_range(level: usize) -> u64 { + LEVEL_MULT.pow(level as u32) as u64 +} + +fn level_range(level: usize) -> u64 { + LEVEL_MULT as u64 * slot_range(level) +} + +/// Converts a duration (milliseconds) 
and a level to a slot position. +fn slot_for(duration: u64, level: usize) -> usize { + ((duration >> (level * 6)) % LEVEL_MULT as u64) as usize +} + +#[cfg(all(test, not(loom)))] +mod test { + use super::*; + + #[test] + fn test_slot_for() { + for pos in 0..64 { + assert_eq!(pos as usize, slot_for(pos, 0)); + } + + for level in 1..5 { + for pos in level..64 { + let a = pos * 64_usize.pow(level as u32); + assert_eq!(pos, slot_for(a as u64, level)); + } + } + } +} diff --git a/tokio/src/runtime/time_alt/wheel/mod.rs b/tokio/src/runtime/time_alt/wheel/mod.rs new file mode 100644 index 00000000000..f66a91c150c --- /dev/null +++ b/tokio/src/runtime/time_alt/wheel/mod.rs @@ -0,0 +1,293 @@ +mod level; +pub(crate) use self::level::Expiration; +use self::level::Level; + +use super::cancellation_queue::Sender; +use super::{EntryHandle, EntryList, WakeQueue}; + +use std::array; + +/// Timing wheel implementation. +/// +/// This type provides the hashed timing wheel implementation that backs `Timer` +/// and `DelayQueue`. +/// +/// The structure is generic over `T: Stack`. This allows handling timeout data +/// being stored on the heap or in a slab. In order to support the latter case, +/// the slab must be passed into each function allowing the implementation to +/// lookup timer entries. +/// +/// See `Timer` documentation for some implementation notes. +#[derive(Debug)] +pub(crate) struct Wheel { + /// The number of milliseconds elapsed since the wheel started. + elapsed: u64, + + /// Timer wheel. + /// + /// Levels: + /// + /// * 1 ms slots / 64 ms range + /// * 64 ms slots / ~ 4 sec range + /// * ~ 4 sec slots / ~ 4 min range + /// * ~ 4 min slots / ~ 4 hr range + /// * ~ 4 hr slots / ~ 12 day range + /// * ~ 12 day slots / ~ 2 yr range + levels: Box<[Level; NUM_LEVELS]>, +} + +/// Number of levels. Each level has 64 slots. By using 6 levels with 64 slots +/// each, the timer is able to track time up to 2 years into the future with a +/// precision of 1 millisecond. 
+const NUM_LEVELS: usize = 6; + +/// The maximum duration of a `Sleep`. +pub(super) const MAX_DURATION: u64 = (1 << (6 * NUM_LEVELS)) - 1; + +impl Wheel { + /// Creates a new timing wheel. + pub(crate) fn new() -> Wheel { + Wheel { + elapsed: 0, + levels: Box::new(array::from_fn(Level::new)), + } + } + + /// Returns the number of milliseconds that have elapsed since the timing + /// wheel's creation. + pub(crate) fn elapsed(&self) -> u64 { + self.elapsed + } + + /// Inserts an entry into the timing wheel. + /// + /// # Arguments + /// + /// * `hdl`: The entry handle to insert into the wheel. + /// + /// # Safety + /// + /// The caller must ensure: + /// + /// * The entry is not already registered in ANY wheel. + pub(crate) unsafe fn insert(&mut self, hdl: EntryHandle, cancel_tx: Sender) { + let deadline = hdl.deadline(); + + assert!(deadline > self.elapsed); + + hdl.register_cancel_tx(cancel_tx); + + // Get the level at which the entry should be stored + let level = self.level_for(deadline); + unsafe { + self.levels[level].add_entry(hdl); + } + + debug_assert!({ + self.levels[level] + .next_expiration(self.elapsed) + .map(|e| e.deadline >= self.elapsed) + .unwrap_or(true) + }); + } + + /// Removes `item` from the timing wheel. + /// + /// # Safety + /// + /// The caller must ensure: + /// + /// * The entry is already registered in THIS wheel. + pub(crate) unsafe fn remove(&mut self, hdl: EntryHandle) { + let deadline = hdl.deadline(); + debug_assert!( + self.elapsed <= deadline, + "elapsed={}; deadline={}", + self.elapsed, + deadline + ); + + let level = self.level_for(deadline); + unsafe { self.levels[level].remove_entry(hdl.clone()) }; + } + + /// Advances the timer up to the instant represented by `now`. 
+ pub(crate) fn take_expired(&mut self, now: u64, wake_queue: &mut WakeQueue) { + loop { + match self.next_expiration() { + Some(ref expiration) if expiration.deadline <= now => { + self.process_expiration(expiration, wake_queue); + + self.set_elapsed(expiration.deadline); + } + _ => { + // in this case the poll did not indicate an expiration + // _and_ we were not able to find a next expiration in + // the current list of timers. advance to the poll's + // current time and do nothing else. + self.set_elapsed(now); + break; + } + } + } + } + + /// Returns the instant at which the next timeout expires. + fn next_expiration(&self) -> Option { + // Check all levels + for (level_num, level) in self.levels.iter().enumerate() { + if let Some(expiration) = level.next_expiration(self.elapsed) { + // There cannot be any expirations at a higher level that happen + // before this one. + debug_assert!(self.no_expirations_before(level_num + 1, expiration.deadline)); + + return Some(expiration); + } + } + + None + } + + /// Returns the tick at which this timer wheel next needs to perform some + /// processing, or None if there are no timers registered. + pub(crate) fn next_expiration_time(&self) -> Option { + self.next_expiration().map(|ex| ex.deadline) + } + + /// Used for debug assertions + fn no_expirations_before(&self, start_level: usize, before: u64) -> bool { + let mut res = true; + + for level in &self.levels[start_level..] { + if let Some(e2) = level.next_expiration(self.elapsed) { + if e2.deadline < before { + res = false; + } + } + } + + res + } + + /// iteratively find entries that are between the wheel's current + /// time and the expiration time. for each in that population either + /// queue it for notification (in the case of the last level) or tier + /// it down to the next level (in all other cases). 
+ pub(crate) fn process_expiration( + &mut self, + expiration: &Expiration, + wake_queue: &mut WakeQueue, + ) { + // Note that we need to take _all_ of the entries off the list before + // processing any of them. This is important because it's possible that + // those entries might need to be reinserted into the same slot. + // + // This happens only on the highest level, when an entry is inserted + // more than MAX_DURATION into the future. When this happens, we wrap + // around, and process some entries a multiple of MAX_DURATION before + // they actually need to be dropped down a level. We then reinsert them + // back into the same position; we must make sure we don't then process + // those entries again or we'll end up in an infinite loop. + let mut entries = self.take_entries(expiration); + + while let Some(hdl) = entries.pop_back() { + if expiration.level == 0 { + debug_assert_eq!(hdl.deadline(), expiration.deadline); + } + + let deadline = hdl.deadline(); + + if deadline > expiration.deadline { + let level = level_for(expiration.deadline, deadline); + unsafe { + self.levels[level].add_entry(hdl); + } + } else { + unsafe { + wake_queue.push_front(hdl); + } + } + } + } + + fn set_elapsed(&mut self, when: u64) { + assert!( + self.elapsed <= when, + "elapsed={:?}; when={:?}", + self.elapsed, + when + ); + + if when > self.elapsed { + self.elapsed = when; + } + } + + /// Obtains the list of entries that need processing for the given expiration. 
+ fn take_entries(&mut self, expiration: &Expiration) -> EntryList { + self.levels[expiration.level].take_slot(expiration.slot) + } + + fn level_for(&self, when: u64) -> usize { + level_for(self.elapsed, when) + } +} + +fn level_for(elapsed: u64, when: u64) -> usize { + const SLOT_MASK: u64 = (1 << 6) - 1; + + // Mask in the trailing bits ignored by the level calculation in order to cap + // the possible leading zeros + let mut masked = elapsed ^ when | SLOT_MASK; + + if masked >= MAX_DURATION { + // Fudge the timer into the top level + masked = MAX_DURATION - 1; + } + + let leading_zeros = masked.leading_zeros() as usize; + let significant = 63 - leading_zeros; + + significant / NUM_LEVELS +} + +#[cfg(all(test, not(loom)))] +mod test { + use super::*; + + #[test] + fn test_level_for() { + for pos in 0..64 { + assert_eq!(0, level_for(0, pos), "level_for({pos}) -- binary = {pos:b}"); + } + + for level in 1..5 { + for pos in level..64 { + let a = pos * 64_usize.pow(level as u32); + assert_eq!( + level, + level_for(0, a as u64), + "level_for({a}) -- binary = {a:b}" + ); + + if pos > level { + let a = a - 1; + assert_eq!( + level, + level_for(0, a as u64), + "level_for({a}) -- binary = {a:b}" + ); + } + + if pos < 64 { + let a = a + 1; + assert_eq!( + level, + level_for(0, a as u64), + "level_for({a}) -- binary = {a:b}" + ); + } + } + } + } +} diff --git a/tokio/src/time/sleep.rs b/tokio/src/time/sleep.rs index 87261057bfe..f0bbf5c2fd1 100644 --- a/tokio/src/time/sleep.rs +++ b/tokio/src/time/sleep.rs @@ -1,4 +1,4 @@ -use crate::runtime::time::TimerEntry; +use crate::runtime::Timer; use crate::time::{error::Error, Duration, Instant}; use crate::util::trace; @@ -227,7 +227,7 @@ pin_project! { // The link between the `Sleep` instance and the timer that drives it. 
#[pin] - entry: TimerEntry, + entry: Timer, } } @@ -253,7 +253,7 @@ impl Sleep { ) -> Sleep { use crate::runtime::scheduler; let handle = scheduler::Handle::current(); - let entry = TimerEntry::new(handle, deadline); + let entry = Timer::new(handle, deadline); #[cfg(all(tokio_unstable, feature = "tracing"))] let inner = { let handle = scheduler::Handle::current(); @@ -362,12 +362,30 @@ impl Sleep { /// without having it wake up the last task that polled it. pub(crate) fn reset_without_reregister(self: Pin<&mut Self>, deadline: Instant) { let mut me = self.project(); - me.entry.as_mut().reset(deadline, false); + match me.entry.as_ref().flavor() { + crate::runtime::TimerFlavor::Traditional => { + me.entry.as_mut().reset(deadline, false); + } + #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] + crate::runtime::TimerFlavor::Alternative => { + let handle = me.entry.as_ref().scheduler_handle().clone(); + me.entry.set(Timer::new(handle, deadline)); + } + } } fn reset_inner(self: Pin<&mut Self>, deadline: Instant) { let mut me = self.project(); - me.entry.as_mut().reset(deadline, true); + match me.entry.as_ref().flavor() { + crate::runtime::TimerFlavor::Traditional => { + me.entry.as_mut().reset(deadline, true); + } + #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] + crate::runtime::TimerFlavor::Alternative => { + let handle = me.entry.as_ref().scheduler_handle().clone(); + me.entry.set(Timer::new(handle, deadline)); + } + } #[cfg(all(tokio_unstable, feature = "tracing"))] { @@ -380,8 +398,8 @@ impl Sleep { tracing::trace_span!("runtime.resource.async_op.poll"); let duration = { - let clock = me.entry.clock(); - let time_source = me.entry.driver().time_source(); + let clock = me.entry.as_ref().clock(); + let time_source = me.entry.as_ref().driver().time_source(); let now = time_source.now(clock); let deadline_tick = time_source.deadline_to_tick(deadline); deadline_tick.saturating_sub(now) diff --git a/tokio/tests/time_alt.rs b/tokio/tests/time_alt.rs new 
file mode 100644 index 00000000000..360aac3c802 --- /dev/null +++ b/tokio/tests/time_alt.rs @@ -0,0 +1,108 @@ +#![warn(rust_2018_idioms)] +#![cfg(all(tokio_unstable, feature = "time", feature = "rt-multi-thread"))] + +use tokio::runtime::Runtime; +use tokio::time::*; + +fn rt_combinations() -> Vec { + let mut rts = vec![]; + + let rt = tokio::runtime::Builder::new_multi_thread() + .worker_threads(1) + .enable_all() + .build() + .unwrap(); + rts.push(rt); + + let rt = tokio::runtime::Builder::new_multi_thread() + .worker_threads(4) + .enable_all() + .build() + .unwrap(); + rts.push(rt); + + #[cfg(tokio_unstable)] + { + let rt = tokio::runtime::Builder::new_multi_thread() + .worker_threads(1) + .enable_alt_timer() + .enable_all() + .build() + .unwrap(); + rts.push(rt); + + let rt = tokio::runtime::Builder::new_multi_thread() + .worker_threads(4) + .enable_alt_timer() + .enable_all() + .build() + .unwrap(); + rts.push(rt); + } + + rts +} + +#[test] +fn sleep() { + const N: u32 = 512; + + for rt in rt_combinations() { + rt.block_on(async { + let mut jhs = vec![]; + + // sleep outside of the worker threads + let now = Instant::now(); + tokio::time::sleep(Duration::from_millis(10)).await; + assert!(now.elapsed() >= Duration::from_millis(10)); + + for _ in 0..N { + let jh = tokio::spawn(async move { + // sleep inside of the worker threads + let now = Instant::now(); + tokio::time::sleep(Duration::from_millis(10)).await; + assert!(now.elapsed() >= Duration::from_millis(10)); + }); + jhs.push(jh); + } + + for jh in jhs { + jh.await.unwrap(); + } + }); + } +} + +#[test] +fn timeout() { + const N: u32 = 512; + + for rt in rt_combinations() { + rt.block_on(async { + let mut jhs = vec![]; + + // timeout outside of the worker threads + let now = Instant::now(); + tokio::time::timeout(Duration::from_millis(10), std::future::pending::<()>()) + .await + .expect_err("timeout should occur"); + assert!(now.elapsed() >= Duration::from_millis(10)); + + for _ in 0..N { + let jh = 
tokio::spawn(async move { + let now = Instant::now(); + // timeout inside of the worker threads + tokio::time::timeout(Duration::from_millis(10), std::future::pending::<()>()) + .await + .expect_err("timeout should occur"); + assert!(now.elapsed() >= Duration::from_millis(10)); + }); + jhs.push(jh); + } + + for jh in jhs { + jh.await.unwrap(); + } + }); + } +} diff --git a/tokio/tests/time_panic.rs b/tokio/tests/time_panic.rs index 8a997f04529..aa7439cce56 100644 --- a/tokio/tests/time_panic.rs +++ b/tokio/tests/time_panic.rs @@ -13,19 +13,64 @@ mod support { } use support::panic::test_panic; +fn rt_combinations() -> Vec { + let mut rts = vec![]; + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + rts.push(rt); + + let rt = tokio::runtime::Builder::new_multi_thread() + .worker_threads(1) + .enable_all() + .build() + .unwrap(); + rts.push(rt); + + let rt = tokio::runtime::Builder::new_multi_thread() + .worker_threads(4) + .enable_all() + .build() + .unwrap(); + rts.push(rt); + + #[cfg(tokio_unstable)] + { + let rt = tokio::runtime::Builder::new_multi_thread() + .worker_threads(1) + .enable_alt_timer() + .enable_all() + .build() + .unwrap(); + rts.push(rt); + + let rt = tokio::runtime::Builder::new_multi_thread() + .worker_threads(4) + .enable_alt_timer() + .enable_all() + .build() + .unwrap(); + rts.push(rt); + } + + rts +} + #[test] fn pause_panic_caller() -> Result<(), Box> { - let panic_location_file = test_panic(|| { - let rt = current_thread(); - - rt.block_on(async { - time::pause(); - time::pause(); + for rt in rt_combinations() { + let panic_location_file = test_panic(|| { + rt.block_on(async { + time::pause(); + time::pause(); + }); }); - }); - // The panic location should be in this file - assert_eq!(&panic_location_file.unwrap(), file!()); + // The panic location should be in this file + assert_eq!(&panic_location_file.unwrap(), file!()); + } Ok(()) } diff --git a/tokio/tests/time_rt.rs 
b/tokio/tests/time_rt.rs index 13f888c1791..283967798a1 100644 --- a/tokio/tests/time_rt.rs +++ b/tokio/tests/time_rt.rs @@ -1,28 +1,96 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] +use tokio::runtime::Runtime; use tokio::time::*; use std::sync::mpsc; +fn rt_combinations() -> Vec { + let mut rts = vec![]; + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + rts.push(rt); + + let rt = tokio::runtime::Builder::new_multi_thread() + .worker_threads(1) + .enable_all() + .build() + .unwrap(); + rts.push(rt); + + let rt = tokio::runtime::Builder::new_multi_thread() + .worker_threads(4) + .enable_all() + .build() + .unwrap(); + rts.push(rt); + + #[cfg(tokio_unstable)] + { + let rt = tokio::runtime::Builder::new_multi_thread() + .worker_threads(1) + .enable_alt_timer() + .enable_all() + .build() + .unwrap(); + rts.push(rt); + + let rt = tokio::runtime::Builder::new_multi_thread() + .worker_threads(4) + .enable_alt_timer() + .enable_all() + .build() + .unwrap(); + rts.push(rt); + } + + rts +} + #[cfg(all(feature = "rt-multi-thread", not(target_os = "wasi")))] // Wasi doesn't support threads #[test] fn timer_with_threaded_runtime() { use tokio::runtime::Runtime; - let rt = Runtime::new().unwrap(); - let (tx, rx) = mpsc::channel(); + { + let rt = Runtime::new().unwrap(); + let (tx, rx) = mpsc::channel(); - rt.spawn(async move { - let when = Instant::now() + Duration::from_millis(10); + rt.spawn(async move { + let when = Instant::now() + Duration::from_millis(10); - sleep_until(when).await; - assert!(Instant::now() >= when); + sleep_until(when).await; + assert!(Instant::now() >= when); - tx.send(()).unwrap(); - }); + tx.send(()).unwrap(); + }); - rx.recv().unwrap(); + rx.recv().unwrap(); + } + + #[cfg(tokio_unstable)] + { + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_alt_timer() + .build() + .unwrap(); + let (tx, rx) = mpsc::channel(); + + rt.block_on(async move { + let when = Instant::now() + 
Duration::from_millis(10); + + sleep_until(when).await; + assert!(Instant::now() >= when); + + tx.send(()).unwrap(); + }); + + rx.recv().unwrap(); + } } #[test] @@ -44,8 +112,8 @@ fn timer_with_current_thread_scheduler() { rx.recv().unwrap(); } -#[tokio::test] -async fn starving() { +#[test] +fn starving() { use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; @@ -68,23 +136,31 @@ async fn starving() { } } - let when = Instant::now() + Duration::from_millis(10); - let starve = Starve(Box::pin(sleep_until(when)), 0); + for rt in rt_combinations() { + rt.block_on(async { + let when = Instant::now() + Duration::from_millis(10); + let starve = Starve(Box::pin(sleep_until(when)), 0); - starve.await; - assert!(Instant::now() >= when); + starve.await; + assert!(Instant::now() >= when); + }); + } } -#[tokio::test] -async fn timeout_value() { +#[test] +fn timeout_value() { use tokio::sync::oneshot; - let (_tx, rx) = oneshot::channel::<()>(); + for rt in rt_combinations() { + rt.block_on(async { + let (_tx, rx) = oneshot::channel::<()>(); - let now = Instant::now(); - let dur = Duration::from_millis(10); + let now = Instant::now(); + let dur = Duration::from_millis(10); - let res = timeout(dur, rx).await; - assert!(res.is_err()); - assert!(Instant::now() >= now + dur); + let res = timeout(dur, rx).await; + assert!(res.is_err()); + assert!(Instant::now() >= now + dur); + }); + } }