Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
rt: implement initial set of task hooks
Browse files Browse the repository at this point in the history
This change implements two hooks for per-task actions, one which is invoked on task spawn, and one which is invoked during task termination.

These hooks initially are only supplied with the task ID (on unstable only), but more information can be added in the future, as the struct used to supply parameters is opaque.

Fixes #3181.
Noah-Kennedy committed Aug 1, 2024
1 parent 1077b0b commit de28721
Showing 18 changed files with 358 additions and 17 deletions.
12 changes: 11 additions & 1 deletion tokio/src/runtime/blocking/schedule.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#[cfg(feature = "test-util")]
use crate::runtime::scheduler;
use crate::runtime::task::{self, Task};
use crate::runtime::task::{self, Task, TaskHarnessScheduleHooks};
use crate::runtime::Handle;

/// `task::Schedule` implementation that does nothing (except some bookkeeping
@@ -12,6 +12,7 @@ use crate::runtime::Handle;
pub(crate) struct BlockingSchedule {
#[cfg(feature = "test-util")]
handle: Handle,
hooks: TaskHarnessScheduleHooks,
}

impl BlockingSchedule {
@@ -32,6 +33,9 @@ impl BlockingSchedule {
BlockingSchedule {
#[cfg(feature = "test-util")]
handle: handle.clone(),
hooks: TaskHarnessScheduleHooks {
task_terminate_callback: handle.inner.hooks().task_terminate_callback.clone(),
},
}
}
}
@@ -57,4 +61,10 @@ impl task::Schedule for BlockingSchedule {
fn schedule(&self, _task: task::Notified<Self>) {
unreachable!();
}

fn hooks(&self) -> TaskHarnessScheduleHooks {
TaskHarnessScheduleHooks {
task_terminate_callback: self.hooks.task_terminate_callback.clone(),
}
}
}
104 changes: 103 additions & 1 deletion tokio/src/runtime/builder.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use crate::runtime::handle::Handle;
use crate::runtime::{blocking, driver, Callback, HistogramBuilder, Runtime};
use crate::runtime::{
blocking, driver, Callback, HistogramBuilder, Runtime, TaskCallback, TaskMeta,
};
use crate::util::rand::{RngSeed, RngSeedGenerator};

use std::fmt;
@@ -78,6 +80,12 @@ pub struct Builder {
/// To run after each thread is unparked.
pub(super) after_unpark: Option<Callback>,

/// To run before each task is spawned.
pub(super) before_spawn: Option<TaskCallback>,

/// To run after each task is terminated.
pub(super) after_termination: Option<TaskCallback>,

/// Customizable keep alive timeout for `BlockingPool`
pub(super) keep_alive: Option<Duration>,

@@ -290,6 +298,9 @@ impl Builder {
before_park: None,
after_unpark: None,

before_spawn: None,
after_termination: None,

keep_alive: None,

// Defaults for these values depend on the scheduler kind, so we get them
@@ -677,6 +688,91 @@ impl Builder {
self
}

/// Executes function `f` just before a task is spawned.
///
/// `f` is called within the Tokio context, so functions like
/// [`tokio::spawn`](crate::spawn) can be called, and may result in this callback being
/// invoked immediately.
///
/// This can be used for bookkeeping or monitoring purposes.
///
/// Note: There can only be one spawn callback for a runtime; calling this function more
/// than once replaces the last callback defined, rather than adding to it.
///
/// This *does not* support LocalSet at this time.
///
/// # Examples
///
/// ```
/// # use tokio::runtime;
/// # pub fn main() {
/// let runtime = runtime::Builder::new_current_thread()
/// .on_task_spawn(|_| {
/// println!("spawning task");
/// })
/// .build()
/// .unwrap();
///
/// runtime.block_on(async {
/// tokio::task::spawn(std::future::ready(()));
///
/// for _ in 0..64 {
/// tokio::task::yield_now().await;
/// }
/// })
/// # }
/// ```
#[cfg(not(loom))]
pub fn on_task_spawn<F>(&mut self, f: F) -> &mut Self
where
F: Fn(&TaskMeta<'_>) + Send + Sync + 'static,
{
self.before_spawn = Some(std::sync::Arc::new(f));
self
}

/// Executes function `f` just after a task is terminated.
///
/// `f` is called within the Tokio context, so functions like
/// [`tokio::spawn`](crate::spawn) can be called.
///
/// This can be used for bookkeeping or monitoring purposes.
///
/// Note: There can only be one task termination callback for a runtime; calling this
/// function more than once replaces the last callback defined, rather than adding to it.
///
/// This *does not* support LocalSet at this time.
///
/// # Examples
///
/// ```
/// # use tokio::runtime;
/// # pub fn main() {
/// let runtime = runtime::Builder::new_current_thread()
/// .on_task_terminate(|_| {
/// println!("killing task");
/// })
/// .build()
/// .unwrap();
///
/// runtime.block_on(async {
/// tokio::task::spawn(std::future::ready(()));
///
/// for _ in 0..64 {
/// tokio::task::yield_now().await;
/// }
/// })
/// # }
/// ```
#[cfg(not(loom))]
pub fn on_task_terminate<F>(&mut self, f: F) -> &mut Self
where
F: Fn(&TaskMeta<'_>) + Send + Sync + 'static,
{
self.after_termination = Some(std::sync::Arc::new(f));
self
}

/// Creates the configured `Runtime`.
///
/// The returned `Runtime` instance is ready to spawn tasks.
@@ -1118,6 +1214,8 @@ impl Builder {
Config {
before_park: self.before_park.clone(),
after_unpark: self.after_unpark.clone(),
before_spawn: self.before_spawn.clone(),
after_termination: self.after_termination.clone(),
global_queue_interval: self.global_queue_interval,
event_interval: self.event_interval,
local_queue_capacity: self.local_queue_capacity,
@@ -1269,6 +1367,8 @@ cfg_rt_multi_thread! {
Config {
before_park: self.before_park.clone(),
after_unpark: self.after_unpark.clone(),
before_spawn: self.before_spawn.clone(),
after_termination: self.after_termination.clone(),
global_queue_interval: self.global_queue_interval,
event_interval: self.event_interval,
local_queue_capacity: self.local_queue_capacity,
@@ -1316,6 +1416,8 @@ cfg_rt_multi_thread! {
Config {
before_park: self.before_park.clone(),
after_unpark: self.after_unpark.clone(),
before_spawn: self.before_spawn.clone(),
after_termination: self.after_termination.clone(),
global_queue_interval: self.global_queue_interval,
event_interval: self.event_interval,
local_queue_capacity: self.local_queue_capacity,
8 changes: 7 additions & 1 deletion tokio/src/runtime/config.rs
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@
any(not(all(tokio_unstable, feature = "full")), target_family = "wasm"),
allow(dead_code)
)]
use crate::runtime::Callback;
use crate::runtime::{Callback, TaskCallback};
use crate::util::RngSeedGenerator;

pub(crate) struct Config {
@@ -21,6 +21,12 @@ pub(crate) struct Config {
/// Callback for a worker unparking itself
pub(crate) after_unpark: Option<Callback>,

/// To run before each task is spawned.
pub(crate) before_spawn: Option<TaskCallback>,

/// To run after each task is terminated.
pub(crate) after_termination: Option<TaskCallback>,

/// The multi-threaded scheduler includes a per-worker LIFO slot used to
/// store the last scheduled task. This can improve certain usage patterns,
/// especially message passing between tasks. However, this LIFO slot is not
4 changes: 4 additions & 0 deletions tokio/src/runtime/mod.rs
Original file line number Diff line number Diff line change
@@ -379,6 +379,10 @@ cfg_rt! {
pub use dump::Dump;
}

mod task_hooks;
pub (crate) use task_hooks::{TaskHooks, TaskCallback};
pub use task_hooks::TaskMeta;

mod handle;
pub use handle::{EnterGuard, Handle, TryCurrentError};

26 changes: 24 additions & 2 deletions tokio/src/runtime/scheduler/current_thread/mod.rs
Original file line number Diff line number Diff line change
@@ -3,8 +3,12 @@ use crate::loom::sync::atomic::AtomicBool;
use crate::loom::sync::Arc;
use crate::runtime::driver::{self, Driver};
use crate::runtime::scheduler::{self, Defer, Inject};
use crate::runtime::task::{self, JoinHandle, OwnedTasks, Schedule, Task};
use crate::runtime::{blocking, context, Config, MetricsBatch, SchedulerMetrics, WorkerMetrics};
use crate::runtime::task::{
self, JoinHandle, OwnedTasks, Schedule, Task, TaskHarnessScheduleHooks,
};
use crate::runtime::{
blocking, context, Config, MetricsBatch, SchedulerMetrics, TaskHooks, TaskMeta, WorkerMetrics,
};
use crate::sync::notify::Notify;
use crate::util::atomic_cell::AtomicCell;
use crate::util::{waker_ref, RngSeedGenerator, Wake, WakerRef};
@@ -41,6 +45,9 @@ pub(crate) struct Handle {

/// Current random number generator seed
pub(crate) seed_generator: RngSeedGenerator,

/// User-supplied hooks to invoke for things
pub(crate) scheduler_hooks: TaskHooks,
}

/// Data required for executing the scheduler. The struct is passed around to
@@ -131,6 +138,10 @@ impl CurrentThread {
.unwrap_or(DEFAULT_GLOBAL_QUEUE_INTERVAL);

let handle = Arc::new(Handle {
scheduler_hooks: TaskHooks {
task_spawn_callback: config.before_spawn.clone(),
task_terminate_callback: config.after_termination.clone(),
},
shared: Shared {
inject: Inject::new(),
owned: OwnedTasks::new(1),
@@ -436,6 +447,11 @@ impl Handle {
{
let (handle, notified) = me.shared.owned.bind(future, me.clone(), id);

me.scheduler_hooks.spawn(&TaskMeta {
id,
_phantom: Default::default(),
});

if let Some(notified) = notified {
me.schedule(notified);
}
@@ -600,6 +616,12 @@ impl Schedule for Arc<Handle> {
});
}

fn hooks(&self) -> TaskHarnessScheduleHooks {
TaskHarnessScheduleHooks {
task_terminate_callback: self.scheduler_hooks.task_terminate_callback.clone(),
}
}

cfg_unstable! {
fn unhandled_panic(&self) {
use crate::runtime::UnhandledPanic;
11 changes: 10 additions & 1 deletion tokio/src/runtime/scheduler/mod.rs
Original file line number Diff line number Diff line change
@@ -25,7 +25,7 @@ cfg_rt_multi_thread! {
}
}

use crate::runtime::driver;
use crate::runtime::{driver, TaskHooks};

#[derive(Debug, Clone)]
pub(crate) enum Handle {
@@ -151,6 +151,15 @@ cfg_rt! {
}
}

pub(crate) fn hooks(&self) -> &TaskHooks {
match self {
Handle::CurrentThread(h) => &h.scheduler_hooks,
#[cfg(tokio_unstable)]
Handle::MultiThread(h) => &h.scheduler_hooks,
Handle::MultiThreadAlt(h) => &h.scheduler_hooks,
}
}

cfg_rt_multi_thread! {
cfg_unstable! {
pub(crate) fn expect_multi_thread_alt(&self) -> &Arc<multi_thread_alt::Handle> {
9 changes: 9 additions & 0 deletions tokio/src/runtime/scheduler/multi_thread/handle.rs
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@ use crate::runtime::scheduler::multi_thread::worker;
use crate::runtime::{
blocking, driver,
task::{self, JoinHandle},
TaskHooks, TaskMeta,
};
use crate::util::RngSeedGenerator;

@@ -28,6 +29,9 @@ pub(crate) struct Handle {

/// Current random number generator seed
pub(crate) seed_generator: RngSeedGenerator,

/// User-supplied hooks to invoke for things
pub(crate) scheduler_hooks: TaskHooks,
}

impl Handle {
@@ -51,6 +55,11 @@ impl Handle {
{
let (handle, notified) = me.shared.owned.bind(future, me.clone(), id);

me.scheduler_hooks.spawn(&TaskMeta {
id,
_phantom: Default::default(),
});

me.schedule_option_task_without_yield(notified);

handle
14 changes: 12 additions & 2 deletions tokio/src/runtime/scheduler/multi_thread/worker.rs
Original file line number Diff line number Diff line change
@@ -58,15 +58,15 @@
use crate::loom::sync::{Arc, Mutex};
use crate::runtime;
use crate::runtime::context;
use crate::runtime::scheduler::multi_thread::{
idle, queue, Counters, Handle, Idle, Overflow, Parker, Stats, TraceStatus, Unparker,
};
use crate::runtime::scheduler::{inject, Defer, Lock};
use crate::runtime::task::OwnedTasks;
use crate::runtime::task::{OwnedTasks, TaskHarnessScheduleHooks};
use crate::runtime::{
blocking, coop, driver, scheduler, task, Config, SchedulerMetrics, WorkerMetrics,
};
use crate::runtime::{context, TaskHooks};
use crate::util::atomic_cell::AtomicCell;
use crate::util::rand::{FastRand, RngSeedGenerator};

@@ -284,6 +284,10 @@ pub(super) fn create(

let remotes_len = remotes.len();
let handle = Arc::new(Handle {
scheduler_hooks: TaskHooks {
task_spawn_callback: config.before_spawn.clone(),
task_terminate_callback: config.after_termination.clone(),
},
shared: Shared {
remotes: remotes.into_boxed_slice(),
inject,
@@ -1037,6 +1041,12 @@ impl task::Schedule for Arc<Handle> {
self.schedule_task(task, false);
}

fn hooks(&self) -> TaskHarnessScheduleHooks {
TaskHarnessScheduleHooks {
task_terminate_callback: self.scheduler_hooks.task_terminate_callback.clone(),
}
}

fn yield_now(&self, task: Notified) {
self.schedule_task(task, true);
}
9 changes: 9 additions & 0 deletions tokio/src/runtime/scheduler/multi_thread_alt/handle.rs
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@ use crate::runtime::scheduler::multi_thread_alt::worker;
use crate::runtime::{
blocking, driver,
task::{self, JoinHandle},
TaskHooks, TaskMeta,
};
use crate::util::RngSeedGenerator;

@@ -26,6 +27,9 @@ pub(crate) struct Handle {

/// Current random number generator seed
pub(crate) seed_generator: RngSeedGenerator,

/// User-supplied hooks to invoke for things
pub(crate) scheduler_hooks: TaskHooks,
}

impl Handle {
@@ -50,6 +54,11 @@ impl Handle {
{
let (handle, notified) = me.shared.owned.bind(future, me.clone(), id);

me.scheduler_hooks.spawn(&TaskMeta {
id,
_phantom: Default::default(),
});

if let Some(notified) = notified {
me.shared.schedule_task(notified, false);
}
14 changes: 12 additions & 2 deletions tokio/src/runtime/scheduler/multi_thread_alt/worker.rs
Original file line number Diff line number Diff line change
@@ -58,14 +58,14 @@
use crate::loom::sync::{Arc, Condvar, Mutex, MutexGuard};
use crate::runtime;
use crate::runtime::context;
use crate::runtime::driver::Driver;
use crate::runtime::scheduler::multi_thread_alt::{
idle, queue, stats, Counters, Handle, Idle, Overflow, Stats, TraceStatus,
};
use crate::runtime::scheduler::{self, inject, Lock};
use crate::runtime::task::OwnedTasks;
use crate::runtime::task::{OwnedTasks, TaskHarnessScheduleHooks};
use crate::runtime::{blocking, coop, driver, task, Config, SchedulerMetrics, WorkerMetrics};
use crate::runtime::{context, TaskHooks};
use crate::util::atomic_cell::AtomicCell;
use crate::util::rand::{FastRand, RngSeedGenerator};

@@ -303,6 +303,10 @@ pub(super) fn create(
let (inject, inject_synced) = inject::Shared::new();

let handle = Arc::new(Handle {
scheduler_hooks: TaskHooks {
task_spawn_callback: config.before_spawn.clone(),
task_terminate_callback: config.after_termination.clone(),
},
shared: Shared {
remotes: remotes.into_boxed_slice(),
inject,
@@ -1556,6 +1560,12 @@ impl task::Schedule for Arc<Handle> {
self.shared.schedule_task(task, false);
}

fn hooks(&self) -> TaskHarnessScheduleHooks {
TaskHarnessScheduleHooks {
task_terminate_callback: self.scheduler_hooks.task_terminate_callback.clone(),
}
}

fn yield_now(&self, task: Notified) {
self.shared.schedule_task(task, true);
}
9 changes: 6 additions & 3 deletions tokio/src/runtime/task/core.rs
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@ use crate::loom::cell::UnsafeCell;
use crate::runtime::context;
use crate::runtime::task::raw::{self, Vtable};
use crate::runtime::task::state::State;
use crate::runtime::task::{Id, Schedule};
use crate::runtime::task::{Id, Schedule, TaskHarnessScheduleHooks};
use crate::util::linked_list;

use std::num::NonZeroU64;
@@ -185,6 +185,8 @@ pub(super) struct Trailer {
pub(super) owned: linked_list::Pointers<Header>,
/// Consumer task waiting on completion of this task.
pub(super) waker: UnsafeCell<Option<Waker>>,
/// Optional hooks needed in the harness.
pub(super) hooks: TaskHarnessScheduleHooks,
}

generate_addr_of_methods! {
@@ -226,6 +228,7 @@ impl<T: Future, S: Schedule> Cell<T, S> {
let tracing_id = future.id();
let vtable = raw::vtable::<T, S>();
let result = Box::new(Cell {
trailer: Trailer::new(scheduler.hooks()),
header: new_header(
state,
vtable,
@@ -239,7 +242,6 @@ impl<T: Future, S: Schedule> Cell<T, S> {
},
task_id,
},
trailer: Trailer::new(),
});

#[cfg(debug_assertions)]
@@ -459,10 +461,11 @@ impl Header {
}

impl Trailer {
fn new() -> Self {
fn new(hooks: TaskHarnessScheduleHooks) -> Self {
Trailer {
waker: UnsafeCell::new(None),
owned: linked_list::Pointers::new(),
hooks,
}
}

10 changes: 9 additions & 1 deletion tokio/src/runtime/task/harness.rs
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@ use crate::runtime::task::state::{Snapshot, State};
use crate::runtime::task::waker::waker_ref;
use crate::runtime::task::{Id, JoinError, Notified, RawTask, Schedule, Task};

use crate::runtime::TaskMeta;
use std::any::Any;
use std::mem;
use std::mem::ManuallyDrop;
@@ -313,9 +314,16 @@ where

let snapshot = self.state().transition_to_complete();

// We catch panics here in case dropping the future or waking the
// We catch panics here in case dropping the future, invoking a hook or waking the
// JoinHandle panics.
let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| {
if let Some(f) = self.trailer().hooks.task_terminate_callback.as_ref() {
f(&TaskMeta {
id: self.core().task_id,
_phantom: Default::default(),
})
}

if !snapshot.is_join_interested() {
// The `JoinHandle` is not interested in the output of
// this task. It is our responsibility to drop the
9 changes: 9 additions & 0 deletions tokio/src/runtime/task/mod.rs
Original file line number Diff line number Diff line change
@@ -210,6 +210,7 @@ use crate::future::Future;
use crate::util::linked_list;
use crate::util::sharded_list;

use crate::runtime::TaskCallback;
use std::marker::PhantomData;
use std::ptr::NonNull;
use std::{fmt, mem};
@@ -255,6 +256,12 @@ unsafe impl<S> Sync for UnownedTask<S> {}
/// Task result sent back.
pub(crate) type Result<T> = std::result::Result<T, JoinError>;

/// Hooks for scheduling tasks which are needed in the task harness.
#[derive(Clone)]
pub(crate) struct TaskHarnessScheduleHooks {
pub(crate) task_terminate_callback: Option<TaskCallback>,
}

pub(crate) trait Schedule: Sync + Sized + 'static {
/// The task has completed work and is ready to be released. The scheduler
/// should release it immediately and return it. The task module will batch
@@ -266,6 +273,8 @@ pub(crate) trait Schedule: Sync + Sized + 'static {
/// Schedule the task
fn schedule(&self, task: Notified<Self>);

fn hooks(&self) -> TaskHarnessScheduleHooks;

/// Schedule the task to run in the near future, yielding the thread to
/// other tasks.
fn yield_now(&self, task: Notified<Self>) {
35 changes: 35 additions & 0 deletions tokio/src/runtime/task_hooks.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
use std::marker::PhantomData;

#[derive(Clone)]
pub(crate) struct TaskHooks {
pub(crate) task_spawn_callback: Option<TaskCallback>,
pub(crate) task_terminate_callback: Option<TaskCallback>,
}

impl TaskHooks {
pub(crate) fn spawn(&self, meta: &TaskMeta<'_>) {
if let Some(f) = self.task_spawn_callback.as_ref() {
f(meta)
}
}
}

/// Task metadata supplied to user-provided hooks for task events.
#[allow(missing_debug_implementations)]
pub struct TaskMeta<'a> {
/// The opaque ID of the task.
#[cfg(tokio_unstable)]
pub(crate) id: super::task::Id,
pub(crate) _phantom: PhantomData<&'a ()>,
}

impl<'a> TaskMeta<'a> {
/// Return the opaque ID of the task.
#[cfg(tokio_unstable)]
pub fn id(&self) -> super::task::Id {
self.id
}
}

/// Runs on specific task-related events
pub(crate) type TaskCallback = std::sync::Arc<dyn Fn(&TaskMeta<'_>) + Send + Sync>;
8 changes: 7 additions & 1 deletion tokio/src/runtime/tests/mod.rs
Original file line number Diff line number Diff line change
@@ -6,7 +6,7 @@ use self::noop_scheduler::NoopSchedule;
use self::unowned_wrapper::unowned;

mod noop_scheduler {
use crate::runtime::task::{self, Task};
use crate::runtime::task::{self, Task, TaskHarnessScheduleHooks};

/// `task::Schedule` implementation that does nothing, for testing.
pub(crate) struct NoopSchedule;
@@ -19,6 +19,12 @@ mod noop_scheduler {
fn schedule(&self, _task: task::Notified<Self>) {
unreachable!();
}

fn hooks(&self) -> TaskHarnessScheduleHooks {
TaskHarnessScheduleHooks {
task_terminate_callback: None,
}
}
}
}

8 changes: 7 additions & 1 deletion tokio/src/runtime/tests/queue.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::runtime::scheduler::multi_thread::{queue, Stats};
use crate::runtime::task::{self, Schedule, Task};
use crate::runtime::task::{self, Schedule, Task, TaskHarnessScheduleHooks};

use std::cell::RefCell;
use std::thread;
@@ -284,4 +284,10 @@ impl Schedule for Runtime {
fn schedule(&self, _task: task::Notified<Self>) {
unreachable!();
}

fn hooks(&self) -> TaskHarnessScheduleHooks {
TaskHarnessScheduleHooks {
task_terminate_callback: None,
}
}
}
9 changes: 8 additions & 1 deletion tokio/src/task/local.rs
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@ use crate::loom::cell::UnsafeCell;
use crate::loom::sync::{Arc, Mutex};
#[cfg(tokio_unstable)]
use crate::runtime;
use crate::runtime::task::{self, JoinHandle, LocalOwnedTasks, Task};
use crate::runtime::task::{self, JoinHandle, LocalOwnedTasks, Task, TaskHarnessScheduleHooks};
use crate::runtime::{context, ThreadId, BOX_FUTURE_THRESHOLD};
use crate::sync::AtomicWaker;
use crate::util::RcCell;
@@ -1071,6 +1071,13 @@ impl task::Schedule for Arc<Shared> {
Shared::schedule(self, task);
}

// localset does not currently support task hooks
fn hooks(&self) -> TaskHarnessScheduleHooks {
TaskHarnessScheduleHooks {
task_terminate_callback: None,
}
}

cfg_unstable! {
fn unhandled_panic(&self) {
use crate::runtime::UnhandledPanic;
76 changes: 76 additions & 0 deletions tokio/tests/task_hooks.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#![allow(unknown_lints, unexpected_cfgs)]
#![warn(rust_2018_idioms)]
#![cfg(all(feature = "full", tokio_unstable, target_has_atomic = "64"))]

use std::collections::HashSet;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};

use tokio::runtime::Builder;

const TASKS: usize = 8;
const ITERATIONS: usize = 64;
/// Assert that the spawn task hook always fires when set.
#[test]
fn spawn_task_hook_fires() {
let count = Arc::new(AtomicUsize::new(0));
let count2 = Arc::clone(&count);

let ids = Arc::new(Mutex::new(HashSet::new()));
let ids2 = Arc::clone(&ids);

let runtime = Builder::new_current_thread()
.on_task_spawn(move |data| {
ids2.lock().unwrap().insert(data.id());

count2.fetch_add(1, Ordering::SeqCst);
})
.build()
.unwrap();

for _ in 0..TASKS {
runtime.spawn(std::future::pending::<()>());
}

let count_realized = count.load(Ordering::SeqCst);
assert_eq!(
TASKS, count_realized,
"Total number of spawned task hook invocations was incorrect, expected {TASKS}, got {}",
count_realized
);

let count_ids_realized = ids.lock().unwrap().len();

assert_eq!(
TASKS, count_ids_realized,
"Total number of spawned task hook invocations was incorrect, expected {TASKS}, got {}",
count_realized
);
}

/// Assert that the terminate task hook always fires when set.
#[test]
fn terminate_task_hook_fires() {
let count = Arc::new(AtomicUsize::new(0));
let count2 = Arc::clone(&count);

let runtime = Builder::new_current_thread()
.on_task_terminate(move |_data| {
count2.fetch_add(1, Ordering::SeqCst);
})
.build()
.unwrap();

for _ in 0..TASKS {
runtime.spawn(std::future::ready(()));
}

runtime.block_on(async {
// tick the runtime a bunch to close out tasks
for _ in 0..ITERATIONS {
tokio::task::yield_now().await;
}
});

assert_eq!(TASKS, count.load(Ordering::SeqCst));
}

0 comments on commit de28721

Please sign in to comment.