From eb6c6ef23bd55e57b11c33c3f981ba079817792b Mon Sep 17 00:00:00 2001 From: Josh Stone Date: Fri, 10 Jun 2022 14:45:26 -0700 Subject: [PATCH 1/6] Add ThreadPool::broadcast A broadcast runs the closure on every thread in the pool, then collects the results. It's scheduled somewhat like a very soft interrupt -- it won't preempt a thread's local work, but will run before it goes to steal from any other threads. This can be used when you want to precisely split your work per-thread, or to set or retrieve some thread-local data in the pool, e.g. #483. --- rayon-core/src/broadcast/mod.rs | 101 ++++++++++++++++++++++++++++++ rayon-core/src/broadcast/test.rs | 91 +++++++++++++++++++++++++++ rayon-core/src/latch.rs | 11 +++- rayon-core/src/lib.rs | 2 + rayon-core/src/log.rs | 3 + rayon-core/src/registry.rs | 74 ++++++++++++++++++++-- rayon-core/src/scope/mod.rs | 34 +++++----- rayon-core/src/thread_pool/mod.rs | 45 +++++++++++++ src/lib.rs | 1 + 9 files changed, 339 insertions(+), 23 deletions(-) create mode 100644 rayon-core/src/broadcast/mod.rs create mode 100755 rayon-core/src/broadcast/test.rs diff --git a/rayon-core/src/broadcast/mod.rs b/rayon-core/src/broadcast/mod.rs new file mode 100644 index 000000000..0b677b13a --- /dev/null +++ b/rayon-core/src/broadcast/mod.rs @@ -0,0 +1,101 @@ +use crate::job::StackJob; +use crate::registry::{Registry, WorkerThread}; +use crate::scope::ScopeLatch; +use std::fmt; +use std::marker::PhantomData; +use std::sync::Arc; + +mod test; + +/// Executes `op` within every thread in the current threadpool. If this is +/// called from a non-Rayon thread, it will execute in the global threadpool. +/// Any attempts to use `join`, `scope`, or parallel iterators will then operate +/// within that threadpool. When the call has completed on each thread, returns +/// a vector containing all of their return values. +/// +/// For more information, see the [`ThreadPool::broadcast()`][m] method. 
+/// +/// [m]: struct.ThreadPool.html#method.broadcast +pub fn broadcast(op: OP) -> Vec +where + OP: Fn(BroadcastContext<'_>) -> R + Sync, + R: Send, +{ + // We assert that current registry has not terminated. + unsafe { broadcast_in(op, &Registry::current()) } +} + +/// Provides context to a closure called by `broadcast`. +pub struct BroadcastContext<'a> { + worker: &'a WorkerThread, + + /// Make sure to prevent auto-traits like `Send` and `Sync`. + _marker: PhantomData<&'a mut dyn Fn()>, +} + +impl<'a> BroadcastContext<'a> { + fn new(worker: &WorkerThread) -> BroadcastContext<'_> { + BroadcastContext { + worker, + _marker: PhantomData, + } + } + + /// Our index amongst the broadcast threads (ranges from `0..self.num_threads()`). + #[inline] + pub fn index(&self) -> usize { + self.worker.index() + } + + /// The number of threads receiving the broadcast in the thread pool. + /// + /// # Future compatibility note + /// + /// Future versions of Rayon might vary the number of threads over time, but + /// this method will always return the number of threads which are actually + /// receiving your particular `broadcast` call. + #[inline] + pub fn num_threads(&self) -> usize { + self.worker.registry().num_threads() + } +} + +impl<'a> fmt::Debug for BroadcastContext<'a> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("BroadcastContext") + .field("index", &self.index()) + .field("num_threads", &self.num_threads()) + .field("pool_id", &self.worker.registry().id()) + .finish() + } +} + +/// Execute `op` on every thread in the pool. It will be executed on each +/// thread when they have nothing else to do locally, before they try to +/// steal work from other threads. This function will not return until all +/// threads have completed the `op`. +/// +/// Unsafe because `registry` must not yet have terminated. 
+pub(super) unsafe fn broadcast_in(op: OP, registry: &Arc) -> Vec +where + OP: Fn(BroadcastContext<'_>) -> R + Sync, + R: Send, +{ + let f = move |injected| { + let worker_thread = WorkerThread::current(); + assert!(injected && !worker_thread.is_null()); + op(BroadcastContext::new(&*worker_thread)) + }; + + let n_threads = registry.num_threads(); + let current_thread = WorkerThread::current().as_ref(); + let latch = ScopeLatch::with_count(n_threads, current_thread); + let jobs: Vec<_> = (0..n_threads).map(|_| StackJob::new(&f, &latch)).collect(); + let job_refs = jobs.iter().map(|job| job.as_job_ref()); + + registry.inject_broadcast(job_refs); + + // Wait for all jobs to complete, then collect the results, maybe propagating a panic. + latch.wait(current_thread); + jobs.into_iter().map(|job| job.into_result()).collect() +} diff --git a/rayon-core/src/broadcast/test.rs b/rayon-core/src/broadcast/test.rs new file mode 100755 index 000000000..fbe86aec5 --- /dev/null +++ b/rayon-core/src/broadcast/test.rs @@ -0,0 +1,91 @@ +#![cfg(test)] + +use crate::ThreadPoolBuilder; +use std::sync::atomic::{AtomicUsize, Ordering}; + +#[test] +fn broadcast_global() { + let v = crate::broadcast(|ctx| ctx.index()); + assert!(v.into_iter().eq(0..crate::current_num_threads())); +} + +#[test] +fn broadcast_pool() { + let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); + let v = pool.broadcast(|ctx| ctx.index()); + assert!(v.into_iter().eq(0..7)); +} + +#[test] +fn broadcast_self() { + let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); + let v = pool.install(|| crate::broadcast(|ctx| ctx.index())); + assert!(v.into_iter().eq(0..7)); +} + +#[test] +fn broadcast_mutual() { + let count = AtomicUsize::new(0); + let pool1 = ThreadPoolBuilder::new().num_threads(3).build().unwrap(); + let pool2 = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); + pool1.install(|| { + pool2.broadcast(|_| { + pool1.broadcast(|_| { + count.fetch_add(1, Ordering::Relaxed); + 
}) + }) + }); + assert_eq!(count.into_inner(), 3 * 7); +} + +#[test] +fn broadcast_mutual_sleepy() { + use std::{thread, time}; + + let count = AtomicUsize::new(0); + let pool1 = ThreadPoolBuilder::new().num_threads(3).build().unwrap(); + let pool2 = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); + pool1.install(|| { + thread::sleep(time::Duration::from_secs(1)); + pool2.broadcast(|_| { + thread::sleep(time::Duration::from_secs(1)); + pool1.broadcast(|_| { + thread::sleep(time::Duration::from_millis(100)); + count.fetch_add(1, Ordering::Relaxed); + }) + }) + }); + assert_eq!(count.into_inner(), 3 * 7); +} + +#[test] +fn broadcast_panic_one() { + let count = AtomicUsize::new(0); + let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); + let result = crate::unwind::halt_unwinding(|| { + pool.broadcast(|ctx| { + count.fetch_add(1, Ordering::Relaxed); + if ctx.index() == 3 { + panic!("Hello, world!"); + } + }) + }); + assert_eq!(count.into_inner(), 7); + assert!(result.is_err(), "broadcast panic should propagate!"); +} + +#[test] +fn broadcast_panic_many() { + let count = AtomicUsize::new(0); + let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); + let result = crate::unwind::halt_unwinding(|| { + pool.broadcast(|ctx| { + count.fetch_add(1, Ordering::Relaxed); + if ctx.index() % 2 == 0 { + panic!("Hello, world!"); + } + }) + }); + assert_eq!(count.into_inner(), 7); + assert!(result.is_err(), "broadcast panic should propagate!"); +} diff --git a/rayon-core/src/latch.rs b/rayon-core/src/latch.rs index aa8ce2e88..090929374 100644 --- a/rayon-core/src/latch.rs +++ b/rayon-core/src/latch.rs @@ -286,9 +286,14 @@ pub(super) struct CountLatch { impl CountLatch { #[inline] pub(super) fn new() -> CountLatch { + Self::with_count(1) + } + + #[inline] + pub(super) fn with_count(n: usize) -> CountLatch { CountLatch { core_latch: CoreLatch::new(), - counter: AtomicUsize::new(1), + counter: AtomicUsize::new(n), } } @@ -337,10 +342,10 @@ 
pub(super) struct CountLockLatch { impl CountLockLatch { #[inline] - pub(super) fn new() -> CountLockLatch { + pub(super) fn with_count(n: usize) -> CountLockLatch { CountLockLatch { lock_latch: LockLatch::new(), - counter: AtomicUsize::new(1), + counter: AtomicUsize::new(n), } } diff --git a/rayon-core/src/lib.rs b/rayon-core/src/lib.rs index fd3b4ce29..917a76931 100644 --- a/rayon-core/src/lib.rs +++ b/rayon-core/src/lib.rs @@ -62,6 +62,7 @@ mod log; #[macro_use] mod private; +mod broadcast; mod job; mod join; mod latch; @@ -75,6 +76,7 @@ mod unwind; mod compile_fail; mod test; +pub use self::broadcast::{broadcast, BroadcastContext}; pub use self::join::{join, join_context}; pub use self::registry::ThreadBuilder; pub use self::scope::{in_place_scope, scope, Scope}; diff --git a/rayon-core/src/log.rs b/rayon-core/src/log.rs index 0daa028c2..7b6daf0ab 100644 --- a/rayon-core/src/log.rs +++ b/rayon-core/src/log.rs @@ -93,6 +93,9 @@ pub(super) enum Event { /// A job was removed from the global queue. JobUninjected { worker: usize }, + /// A job was broadcasted to N threads. + JobBroadcast { count: usize }, + /// When announcing a job, this was the value of the counters we observed. /// /// No effect on thread state, just a debugging event. diff --git a/rayon-core/src/registry.rs b/rayon-core/src/registry.rs index 02b6a861c..32185ad60 100644 --- a/rayon-core/src/registry.rs +++ b/rayon-core/src/registry.rs @@ -17,7 +17,7 @@ use std::io; use std::mem; use std::ptr; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::{Arc, Once}; +use std::sync::{Arc, Mutex, Once}; use std::thread; use std::usize; @@ -27,6 +27,7 @@ pub struct ThreadBuilder { name: Option, stack_size: Option, worker: Worker, + stealer: Stealer, registry: Arc, index: usize, } @@ -50,7 +51,7 @@ impl ThreadBuilder { /// Executes the main loop for this thread. This will not return until the /// thread pool is dropped. 
pub fn run(self) { - unsafe { main_loop(self.worker, self.registry, self.index) } + unsafe { main_loop(self.worker, self.stealer, self.registry, self.index) } } } @@ -133,6 +134,7 @@ pub(super) struct Registry { thread_infos: Vec, sleep: Sleep, injected_jobs: Injector, + broadcasts: Mutex>>, panic_handler: Option>, start_handler: Option>, exit_handler: Option>, @@ -230,12 +232,21 @@ impl Registry { }) .unzip(); + let (broadcasts, broadcast_stealers): (Vec<_>, Vec<_>) = (0..n_threads) + .map(|_| { + let worker = Worker::new_fifo(); + let stealer = worker.stealer(); + (worker, stealer) + }) + .unzip(); + let logger = Logger::new(n_threads); let registry = Arc::new(Registry { logger: logger.clone(), thread_infos: stealers.into_iter().map(ThreadInfo::new).collect(), sleep: Sleep::new(logger, n_threads), injected_jobs: Injector::new(), + broadcasts: Mutex::new(broadcasts), terminate_count: AtomicUsize::new(1), panic_handler: builder.take_panic_handler(), start_handler: builder.take_start_handler(), @@ -245,12 +256,13 @@ impl Registry { // If we return early or panic, make sure to terminate existing threads. let t1000 = Terminator(&registry); - for (index, worker) in workers.into_iter().enumerate() { + for (index, (worker, stealer)) in workers.into_iter().zip(broadcast_stealers).enumerate() { let thread = ThreadBuilder { name: builder.get_thread_name(index), stack_size: builder.get_stack_size(), registry: Arc::clone(&registry), worker, + stealer, index, }; if let Err(e) = builder.get_spawn_handler().spawn(thread) { @@ -376,7 +388,7 @@ impl Registry { } /// Push a job into the "external jobs" queue; it will be taken by - /// whatever worker has nothing to do. Use this is you know that + /// whatever worker has nothing to do. Use this if you know that /// you are not on a worker of this registry. 
pub(super) fn inject(&self, injected_jobs: &[JobRef]) { self.log(|| JobsInjected { @@ -423,6 +435,40 @@ impl Registry { } } + /// Push a job into each thread's own "external jobs" queue; it will be + /// executed only on that thread, when it has nothing else to do locally, + /// before it tries to steal other work. + /// + /// **Panics** if not given exactly as many jobs as there are threads. + pub(super) fn inject_broadcast(&self, injected_jobs: impl ExactSizeIterator) { + assert_eq!(self.num_threads(), injected_jobs.len()); + self.log(|| JobBroadcast { + count: self.num_threads(), + }); + { + let broadcasts = self.broadcasts.lock().unwrap(); + + // It should not be possible for `state.terminate` to be true + // here. It is only set to true when the user creates (and + // drops) a `ThreadPool`; and, in that case, they cannot be + // calling `inject_broadcast()` later, since they dropped their + // `ThreadPool`. + debug_assert_ne!( + self.terminate_count.load(Ordering::Acquire), + 0, + "inject_broadcast() sees state.terminate as true" + ); + + assert_eq!(broadcasts.len(), injected_jobs.len()); + for (worker, job_ref) in broadcasts.iter().zip(injected_jobs) { + worker.push(job_ref); + } + } + for i in 0..self.num_threads() { + self.sleep.notify_worker_latch_is_set(i); + } + } + /// If already in a worker-thread of this registry, just execute `op`. /// Otherwise, inject `op` in this thread-pool. Either way, block until `op` /// completes and return its return value. 
If `op` panics, that panic will @@ -592,6 +638,9 @@ pub(super) struct WorkerThread { /// the "worker" half of our local deque worker: Worker, + /// the "stealer" half of the worker's broadcast deque + stealer: Stealer, + /// local queue used for `spawn_fifo` indirection fifo: JobFifo, @@ -687,9 +736,16 @@ impl WorkerThread { if popped_job.is_some() { self.log(|| JobPopped { worker: self.index }); + return popped_job; } - popped_job + loop { + match self.stealer.steal() { + Steal::Success(job) => return Some(job), + Steal::Empty => return None, + Steal::Retry => {} + } + } } /// Wait until the latch is set. Try to keep busy by popping and @@ -797,9 +853,15 @@ impl WorkerThread { /// //////////////////////////////////////////////////////////////////////// -unsafe fn main_loop(worker: Worker, registry: Arc, index: usize) { +unsafe fn main_loop( + worker: Worker, + stealer: Stealer, + registry: Arc, + index: usize, +) { let worker_thread = &WorkerThread { worker, + stealer, fifo: JobFifo::new(), index, rng: XorShift64Star::new(), diff --git a/rayon-core/src/scope/mod.rs b/rayon-core/src/scope/mod.rs index f92d1614b..4cfdf88aa 100644 --- a/rayon-core/src/scope/mod.rs +++ b/rayon-core/src/scope/mod.rs @@ -38,7 +38,7 @@ pub struct ScopeFifo<'scope> { fifos: Vec, } -enum ScopeLatch { +pub(super) enum ScopeLatch { /// A latch for scopes created on a rayon thread which will participate in work- /// stealing while it waits for completion. This thread is not necessarily part /// of the same registry as the scope itself! 
@@ -687,14 +687,18 @@ impl<'scope> ScopeBase<'scope> { impl ScopeLatch { fn new(owner: Option<&WorkerThread>) -> Self { + Self::with_count(1, owner) + } + + pub(super) fn with_count(count: usize, owner: Option<&WorkerThread>) -> Self { match owner { Some(owner) => ScopeLatch::Stealing { - latch: CountLatch::new(), + latch: CountLatch::with_count(count), registry: Arc::clone(owner.registry()), worker_index: owner.index(), }, None => ScopeLatch::Blocking { - latch: CountLockLatch::new(), + latch: CountLockLatch::with_count(count), }, } } @@ -706,30 +710,32 @@ impl ScopeLatch { } } - fn set(&self) { + pub(super) fn wait(&self, owner: Option<&WorkerThread>) { match self { ScopeLatch::Stealing { latch, registry, worker_index, - } => latch.set_and_tickle_one(registry, *worker_index), - ScopeLatch::Blocking { latch } => latch.set(), + } => unsafe { + let owner = owner.expect("owner thread"); + debug_assert_eq!(registry.id(), owner.registry().id()); + debug_assert_eq!(*worker_index, owner.index()); + owner.wait_until(latch); + }, + ScopeLatch::Blocking { latch } => latch.wait(), } } +} - fn wait(&self, owner: Option<&WorkerThread>) { +impl Latch for ScopeLatch { + fn set(&self) { match self { ScopeLatch::Stealing { latch, registry, worker_index, - } => unsafe { - let owner = owner.expect("owner thread"); - debug_assert_eq!(registry.id(), owner.registry().id()); - debug_assert_eq!(*worker_index, owner.index()); - owner.wait_until(latch); - }, - ScopeLatch::Blocking { latch } => latch.wait(), + } => latch.set_and_tickle_one(registry, *worker_index), + ScopeLatch::Blocking { latch } => latch.set(), } } } diff --git a/rayon-core/src/thread_pool/mod.rs b/rayon-core/src/thread_pool/mod.rs index 2f5977963..0e86a6b1a 100644 --- a/rayon-core/src/thread_pool/mod.rs +++ b/rayon-core/src/thread_pool/mod.rs @@ -3,6 +3,7 @@ //! //! 
[`ThreadPool`]: struct.ThreadPool.html +use crate::broadcast::{broadcast_in, BroadcastContext}; use crate::join; use crate::registry::{Registry, ThreadSpawn, WorkerThread}; use crate::scope::{do_in_place_scope, do_in_place_scope_fifo}; @@ -109,6 +110,50 @@ impl ThreadPool { self.registry.in_worker(|_, _| op()) } + /// Executes `op` within every thread in the threadpool. Any attempts to use + /// `join`, `scope`, or parallel iterators will then operate within that + /// threadpool. + /// + /// # Warning: thread-local data + /// + /// Because `op` is executing within the Rayon thread-pool, + /// thread-local data from the current thread will not be + /// accessible. + /// + /// # Panics + /// + /// If `op` should panic on one or more threads, exactly one panic + /// will be propagated, only after all threads have completed + /// (or panicked) their own `op`. + /// + /// # Examples + /// + /// ``` + /// # use rayon_core as rayon; + /// use std::sync::atomic::{AtomicUsize, Ordering}; + /// + /// fn main() { + /// let pool = rayon::ThreadPoolBuilder::new().num_threads(5).build().unwrap(); + /// + /// // The argument gives context, including the index of each thread. + /// let v: Vec = pool.broadcast(|ctx| ctx.index() * ctx.index()); + /// assert_eq!(v, &[0, 1, 4, 9, 16]); + /// + /// // The closure can reference the local stack + /// let count = AtomicUsize::new(0); + /// pool.broadcast(|_| count.fetch_add(1, Ordering::Relaxed)); + /// assert_eq!(count.into_inner(), 5); + /// } + /// ``` + pub fn broadcast(&self, op: OP) -> Vec + where + OP: Fn(BroadcastContext<'_>) -> R + Sync, + R: Send, + { + // We assert that `self.registry` has not terminated. + unsafe { broadcast_in(op, &self.registry) } + } + /// Returns the (current) number of threads in the thread pool. 
/// /// # Future compatibility note diff --git a/src/lib.rs b/src/lib.rs index 1c743fd98..49083aff0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -113,6 +113,7 @@ pub use rayon_core::ThreadBuilder; pub use rayon_core::ThreadPool; pub use rayon_core::ThreadPoolBuildError; pub use rayon_core::ThreadPoolBuilder; +pub use rayon_core::{broadcast, BroadcastContext}; pub use rayon_core::{current_num_threads, current_thread_index, max_num_threads}; pub use rayon_core::{in_place_scope, scope, Scope}; pub use rayon_core::{in_place_scope_fifo, scope_fifo, ScopeFifo}; From 817c4ccb990a5ff14000e4ba436dfd96672c501a Mon Sep 17 00:00:00 2001 From: Josh Stone Date: Fri, 10 Jun 2022 16:19:15 -0700 Subject: [PATCH 2/6] Add ThreadPool::spawn_broadcast --- rayon-core/src/broadcast/mod.rs | 57 ++++++++++++++- rayon-core/src/broadcast/test.rs | 115 +++++++++++++++++++++++++++++- rayon-core/src/job.rs | 40 ++++++++++- rayon-core/src/lib.rs | 2 +- rayon-core/src/scope/mod.rs | 12 ++-- rayon-core/src/spawn/mod.rs | 4 +- rayon-core/src/thread_pool/mod.rs | 16 ++++- src/lib.rs | 2 +- 8 files changed, 229 insertions(+), 19 deletions(-) diff --git a/rayon-core/src/broadcast/mod.rs b/rayon-core/src/broadcast/mod.rs index 0b677b13a..05de7dfa3 100644 --- a/rayon-core/src/broadcast/mod.rs +++ b/rayon-core/src/broadcast/mod.rs @@ -1,6 +1,7 @@ -use crate::job::StackJob; +use crate::job::{ArcJob, StackJob}; use crate::registry::{Registry, WorkerThread}; use crate::scope::ScopeLatch; +use crate::unwind; use std::fmt; use std::marker::PhantomData; use std::sync::Arc; @@ -25,6 +26,22 @@ where unsafe { broadcast_in(op, &Registry::current()) } } +/// Spawns an asynchronous task on every thread in this thread-pool. This task +/// will run in the implicit, global scope, which means that it may outlast the +/// current stack frame -- therefore, it cannot capture any references onto the +/// stack (you will likely need a `move` closure). 
+/// +/// For more information, see the [`ThreadPool::spawn_broadcast()`][m] method. +/// +/// [m]: struct.ThreadPool.html#method.spawn_broadcast +pub fn spawn_broadcast(op: OP) +where + OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static, +{ + // We assert that current registry has not terminated. + unsafe { spawn_broadcast_in(op, &Registry::current()) } +} + /// Provides context to a closure called by `broadcast`. pub struct BroadcastContext<'a> { worker: &'a WorkerThread, @@ -99,3 +116,41 @@ where latch.wait(current_thread); jobs.into_iter().map(|job| job.into_result()).collect() } + +/// Execute `op` on every thread in the pool. It will be executed on each +/// thread when they have nothing else to do locally, before they try to +/// steal work from other threads. This function returns immediately after +/// injecting the jobs. +/// +/// Unsafe because `registry` must not yet have terminated. +pub(super) unsafe fn spawn_broadcast_in(op: OP, registry: &Arc) +where + OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static, +{ + let job = ArcJob::new({ + let registry = Arc::clone(registry); + move || { + let worker_thread = WorkerThread::current(); + assert!(!worker_thread.is_null()); + let ctx = BroadcastContext::new(&*worker_thread); + match unwind::halt_unwinding(|| op(ctx)) { + Ok(()) => {} + Err(err) => { + registry.handle_panic(err); + } + } + registry.terminate(); // (*) permit registry to terminate now + } + }); + + let n_threads = registry.num_threads(); + let job_refs = (0..n_threads).map(|_| { + // Ensure that registry cannot terminate until this job has executed + // on each thread. This ref is decremented at the (*) above. 
+ registry.increment_terminate_count(); + + ArcJob::as_job_ref(&job) + }); + + registry.inject_broadcast(job_refs); +} diff --git a/rayon-core/src/broadcast/test.rs b/rayon-core/src/broadcast/test.rs index fbe86aec5..2a3be9a63 100755 --- a/rayon-core/src/broadcast/test.rs +++ b/rayon-core/src/broadcast/test.rs @@ -2,6 +2,8 @@ use crate::ThreadPoolBuilder; use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::{thread, time}; #[test] fn broadcast_global() { @@ -9,6 +11,16 @@ fn broadcast_global() { assert!(v.into_iter().eq(0..crate::current_num_threads())); } +#[test] +fn spawn_broadcast_global() { + let (tx, rx) = crossbeam_channel::unbounded(); + crate::spawn_broadcast(move |ctx| tx.send(ctx.index()).unwrap()); + + let mut v: Vec<_> = rx.into_iter().collect(); + v.sort_unstable(); + assert!(v.into_iter().eq(0..crate::current_num_threads())); +} + #[test] fn broadcast_pool() { let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); @@ -16,6 +28,17 @@ fn broadcast_pool() { assert!(v.into_iter().eq(0..7)); } +#[test] +fn spawn_broadcast_pool() { + let (tx, rx) = crossbeam_channel::unbounded(); + let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); + pool.spawn_broadcast(move |ctx| tx.send(ctx.index()).unwrap()); + + let mut v: Vec<_> = rx.into_iter().collect(); + v.sort_unstable(); + assert!(v.into_iter().eq(0..7)); +} + #[test] fn broadcast_self() { let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); @@ -23,6 +46,17 @@ fn broadcast_self() { assert!(v.into_iter().eq(0..7)); } +#[test] +fn spawn_broadcast_self() { + let (tx, rx) = crossbeam_channel::unbounded(); + let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); + pool.spawn(|| crate::spawn_broadcast(move |ctx| tx.send(ctx.index()).unwrap())); + + let mut v: Vec<_> = rx.into_iter().collect(); + v.sort_unstable(); + assert!(v.into_iter().eq(0..7)); +} + #[test] fn broadcast_mutual() { let count = AtomicUsize::new(0); @@ -39,9 
+73,24 @@ fn broadcast_mutual() { } #[test] -fn broadcast_mutual_sleepy() { - use std::{thread, time}; +fn spawn_broadcast_mutual() { + let (tx, rx) = crossbeam_channel::unbounded(); + let pool1 = Arc::new(ThreadPoolBuilder::new().num_threads(3).build().unwrap()); + let pool2 = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); + pool1.spawn({ + let pool1 = Arc::clone(&pool1); + move || { + pool2.spawn_broadcast(move |_| { + let tx = tx.clone(); + pool1.spawn_broadcast(move |_| tx.send(()).unwrap()) + }) + } + }); + assert_eq!(rx.into_iter().count(), 3 * 7); +} +#[test] +fn broadcast_mutual_sleepy() { let count = AtomicUsize::new(0); let pool1 = ThreadPoolBuilder::new().num_threads(3).build().unwrap(); let pool2 = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); @@ -58,6 +107,28 @@ fn broadcast_mutual_sleepy() { assert_eq!(count.into_inner(), 3 * 7); } +#[test] +fn spawn_broadcast_mutual_sleepy() { + let (tx, rx) = crossbeam_channel::unbounded(); + let pool1 = Arc::new(ThreadPoolBuilder::new().num_threads(3).build().unwrap()); + let pool2 = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); + pool1.spawn({ + let pool1 = Arc::clone(&pool1); + move || { + thread::sleep(time::Duration::from_secs(1)); + pool2.spawn_broadcast(move |_| { + let tx = tx.clone(); + thread::sleep(time::Duration::from_secs(1)); + pool1.spawn_broadcast(move |_| { + thread::sleep(time::Duration::from_millis(100)); + tx.send(()).unwrap(); + }) + }) + } + }); + assert_eq!(rx.into_iter().count(), 3 * 7); +} + #[test] fn broadcast_panic_one() { let count = AtomicUsize::new(0); @@ -74,6 +145,26 @@ fn broadcast_panic_one() { assert!(result.is_err(), "broadcast panic should propagate!"); } +#[test] +fn spawn_broadcast_panic_one() { + let (tx, rx) = crossbeam_channel::unbounded(); + let (panic_tx, panic_rx) = crossbeam_channel::unbounded(); + let pool = ThreadPoolBuilder::new() + .num_threads(7) + .panic_handler(move |e| panic_tx.send(e).unwrap()) + .build() + .unwrap(); + 
pool.spawn_broadcast(move |ctx| { + tx.send(()).unwrap(); + if ctx.index() == 3 { + panic!("Hello, world!"); + } + }); + drop(pool); // including panic_tx + assert_eq!(rx.into_iter().count(), 7); + assert_eq!(panic_rx.into_iter().count(), 1); +} + #[test] fn broadcast_panic_many() { let count = AtomicUsize::new(0); @@ -89,3 +180,23 @@ fn broadcast_panic_many() { assert_eq!(count.into_inner(), 7); assert!(result.is_err(), "broadcast panic should propagate!"); } + +#[test] +fn spawn_broadcast_panic_many() { + let (tx, rx) = crossbeam_channel::unbounded(); + let (panic_tx, panic_rx) = crossbeam_channel::unbounded(); + let pool = ThreadPoolBuilder::new() + .num_threads(7) + .panic_handler(move |e| panic_tx.send(e).unwrap()) + .build() + .unwrap(); + pool.spawn_broadcast(move |ctx| { + tx.send(()).unwrap(); + if ctx.index() % 2 == 0 { + panic!("Hello, world!"); + } + }); + drop(pool); // including panic_tx + assert_eq!(rx.into_iter().count(), 7); + assert_eq!(panic_rx.into_iter().count(), 4); +} diff --git a/rayon-core/src/job.rs b/rayon-core/src/job.rs index c89d9a7b7..b099d1735 100644 --- a/rayon-core/src/job.rs +++ b/rayon-core/src/job.rs @@ -4,6 +4,7 @@ use crossbeam_deque::{Injector, Steal}; use std::any::Any; use std::cell::UnsafeCell; use std::mem; +use std::sync::Arc; pub(super) enum JobResult { None, @@ -133,8 +134,8 @@ impl HeapJob where BODY: FnOnce() + Send, { - pub(super) fn new(job: BODY) -> Self { - HeapJob { job } + pub(super) fn new(job: BODY) -> Box { + Box::new(HeapJob { job }) } /// Creates a `JobRef` from this job -- note that this hides all @@ -155,6 +156,41 @@ where } } +/// Represents a job stored in an `Arc` -- like `HeapJob`, but may +/// be turned into multiple `JobRef`s and called multiple times. 
+pub(super) struct ArcJob +where + BODY: Fn() + Send + Sync, +{ + job: BODY, +} + +impl ArcJob +where + BODY: Fn() + Send + Sync, +{ + pub(super) fn new(job: BODY) -> Arc { + Arc::new(ArcJob { job }) + } + + /// Creates a `JobRef` from this job -- note that this hides all + /// lifetimes, so it is up to you to ensure that this JobRef + /// doesn't outlive any data that it closes over. + pub(super) unsafe fn as_job_ref(this: &Arc) -> JobRef { + JobRef::new(Arc::into_raw(Arc::clone(this))) + } +} + +impl Job for ArcJob +where + BODY: Fn() + Send + Sync, +{ + unsafe fn execute(this: *const ()) { + let this = Arc::from_raw(this as *mut Self); + (this.job)(); + } +} + impl JobResult { fn call(func: impl FnOnce(bool) -> T) -> Self { match unwind::halt_unwinding(|| func(true)) { diff --git a/rayon-core/src/lib.rs b/rayon-core/src/lib.rs index 917a76931..b31a2d7e0 100644 --- a/rayon-core/src/lib.rs +++ b/rayon-core/src/lib.rs @@ -76,7 +76,7 @@ mod unwind; mod compile_fail; mod test; -pub use self::broadcast::{broadcast, BroadcastContext}; +pub use self::broadcast::{broadcast, spawn_broadcast, BroadcastContext}; pub use self::join::{join, join_context}; pub use self::registry::ThreadBuilder; pub use self::scope::{in_place_scope, scope, Scope}; diff --git a/rayon-core/src/scope/mod.rs b/rayon-core/src/scope/mod.rs index 4cfdf88aa..f9e76e85a 100644 --- a/rayon-core/src/scope/mod.rs +++ b/rayon-core/src/scope/mod.rs @@ -540,10 +540,8 @@ impl<'scope> Scope<'scope> { { self.base.increment(); unsafe { - let job_ref = Box::new(HeapJob::new(move || { - self.base.execute_job(move || body(self)) - })) - .into_job_ref(); + let job_ref = + HeapJob::new(move || self.base.execute_job(move || body(self))).into_job_ref(); // Since `Scope` implements `Sync`, we can't be sure that we're still in a // thread of this pool, so we can't just push to the local worker thread. 
@@ -581,10 +579,8 @@ impl<'scope> ScopeFifo<'scope> { { self.base.increment(); unsafe { - let job_ref = Box::new(HeapJob::new(move || { - self.base.execute_job(move || body(self)) - })) - .into_job_ref(); + let job_ref = + HeapJob::new(move || self.base.execute_job(move || body(self))).into_job_ref(); // If we're in the pool, use our scope's private fifo for this thread to execute // in a locally-FIFO order. Otherwise, just use the pool's global injector. diff --git a/rayon-core/src/spawn/mod.rs b/rayon-core/src/spawn/mod.rs index 3bdb3db0e..dc5725941 100644 --- a/rayon-core/src/spawn/mod.rs +++ b/rayon-core/src/spawn/mod.rs @@ -91,7 +91,7 @@ where // executed. This ref is decremented at the (*) below. registry.increment_terminate_count(); - Box::new(HeapJob::new({ + HeapJob::new({ let registry = Arc::clone(registry); move || { match unwind::halt_unwinding(func) { @@ -102,7 +102,7 @@ where } registry.terminate(); // (*) permit registry to terminate now } - })) + }) .into_job_ref() } diff --git a/rayon-core/src/thread_pool/mod.rs b/rayon-core/src/thread_pool/mod.rs index 0e86a6b1a..98ea0bdde 100644 --- a/rayon-core/src/thread_pool/mod.rs +++ b/rayon-core/src/thread_pool/mod.rs @@ -3,7 +3,7 @@ //! //! [`ThreadPool`]: struct.ThreadPool.html -use crate::broadcast::{broadcast_in, BroadcastContext}; +use crate::broadcast::{self, BroadcastContext}; use crate::join; use crate::registry::{Registry, ThreadSpawn, WorkerThread}; use crate::scope::{do_in_place_scope, do_in_place_scope_fifo}; @@ -151,7 +151,7 @@ impl ThreadPool { R: Send, { // We assert that `self.registry` has not terminated. - unsafe { broadcast_in(op, &self.registry) } + unsafe { broadcast::broadcast_in(op, &self.registry) } } /// Returns the (current) number of threads in the thread pool. @@ -320,6 +320,18 @@ impl ThreadPool { // We assert that `self.registry` has not terminated. unsafe { spawn::spawn_fifo_in(op, &self.registry) } } + + /// Spawns an asynchronous task on every thread in this thread-pool. 
This task + /// will run in the implicit, global scope, which means that it may outlast the + /// current stack frame -- therefore, it cannot capture any references onto the + /// stack (you will likely need a `move` closure). + pub fn spawn_broadcast(&self, op: OP) + where + OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static, + { + // We assert that `self.registry` has not terminated. + unsafe { broadcast::spawn_broadcast_in(op, &self.registry) } + } } impl Drop for ThreadPool { diff --git a/src/lib.rs b/src/lib.rs index 49083aff0..25a5e16a7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -113,7 +113,7 @@ pub use rayon_core::ThreadBuilder; pub use rayon_core::ThreadPool; pub use rayon_core::ThreadPoolBuildError; pub use rayon_core::ThreadPoolBuilder; -pub use rayon_core::{broadcast, BroadcastContext}; +pub use rayon_core::{broadcast, spawn_broadcast, BroadcastContext}; pub use rayon_core::{current_num_threads, current_thread_index, max_num_threads}; pub use rayon_core::{in_place_scope, scope, Scope}; pub use rayon_core::{in_place_scope_fifo, scope_fifo, ScopeFifo}; From c7a3172850f9359d9ffb5811433ff8e2417fd1c5 Mon Sep 17 00:00:00 2001 From: Josh Stone Date: Fri, 10 Jun 2022 17:22:50 -0700 Subject: [PATCH 3/6] Add Scope::spawn_broadcast --- rayon-core/src/broadcast/mod.rs | 22 ++++----- rayon-core/src/scope/mod.rs | 48 +++++++++++++++++- rayon-core/src/scope/test.rs | 88 ++++++++++++++++++++++++++++++++- 3 files changed, 144 insertions(+), 14 deletions(-) diff --git a/rayon-core/src/broadcast/mod.rs b/rayon-core/src/broadcast/mod.rs index 05de7dfa3..bece9b990 100644 --- a/rayon-core/src/broadcast/mod.rs +++ b/rayon-core/src/broadcast/mod.rs @@ -51,11 +51,13 @@ pub struct BroadcastContext<'a> { } impl<'a> BroadcastContext<'a> { - fn new(worker: &WorkerThread) -> BroadcastContext<'_> { - BroadcastContext { - worker, + pub(super) fn with(f: impl FnOnce(BroadcastContext<'_>) -> R) -> R { + let worker_thread = WorkerThread::current(); + assert!(!worker_thread.is_null()); 
+ f(BroadcastContext { + worker: unsafe { &*worker_thread }, _marker: PhantomData, - } + }) } /// Our index amongst the broadcast threads (ranges from `0..self.num_threads()`). @@ -98,10 +100,9 @@ where OP: Fn(BroadcastContext<'_>) -> R + Sync, R: Send, { - let f = move |injected| { - let worker_thread = WorkerThread::current(); - assert!(injected && !worker_thread.is_null()); - op(BroadcastContext::new(&*worker_thread)) + let f = move |injected: bool| { + debug_assert!(injected); + BroadcastContext::with(&op) }; let n_threads = registry.num_threads(); @@ -130,10 +131,7 @@ where let job = ArcJob::new({ let registry = Arc::clone(registry); move || { - let worker_thread = WorkerThread::current(); - assert!(!worker_thread.is_null()); - let ctx = BroadcastContext::new(&*worker_thread); - match unwind::halt_unwinding(|| op(ctx)) { + match unwind::halt_unwinding(|| BroadcastContext::with(&op)) { Ok(()) => {} Err(err) => { registry.handle_panic(err); diff --git a/rayon-core/src/scope/mod.rs b/rayon-core/src/scope/mod.rs index f9e76e85a..7eadf1f19 100644 --- a/rayon-core/src/scope/mod.rs +++ b/rayon-core/src/scope/mod.rs @@ -5,7 +5,8 @@ //! [`in_place_scope()`]: fn.in_place_scope.html //! [`join()`]: ../join/join.fn.html -use crate::job::{HeapJob, JobFifo}; +use crate::broadcast::BroadcastContext; +use crate::job::{ArcJob, HeapJob, JobFifo}; use crate::latch::{CountLatch, CountLockLatch, Latch}; use crate::registry::{global_registry, in_worker, Registry, WorkerThread}; use crate::unwind; @@ -549,6 +550,22 @@ impl<'scope> Scope<'scope> { self.base.registry.inject_or_push(job_ref); } } + + /// Spawns a job into every thread of the fork-join scope `self`. This job will + /// execute on each thread sometime before the fork-join scope completes. The + /// job is specified as a closure, and this closure receives its own reference + /// to the scope `self` as argument, as well as a `BroadcastContext`. 
+ pub fn spawn_broadcast(&self, body: BODY) + where + BODY: Fn(&Scope<'scope>, BroadcastContext<'_>) + Send + Sync + 'scope, + { + let job = ArcJob::new(move || { + let body = &body; + self.base + .execute_job(move || BroadcastContext::with(move |ctx| body(self, ctx))) + }); + unsafe { self.base.inject_broadcast(job) } + } } impl<'scope> ScopeFifo<'scope> { @@ -593,6 +610,22 @@ impl<'scope> ScopeFifo<'scope> { } } } + + /// Spawns a job into every thread of the fork-join scope `self`. This job will + /// execute on each thread sometime before the fork-join scope completes. The + /// job is specified as a closure, and this closure receives its own reference + /// to the scope `self` as argument, as well as a `BroadcastContext`. + pub fn spawn_broadcast(&self, body: BODY) + where + BODY: Fn(&ScopeFifo<'scope>, BroadcastContext<'_>) + Send + Sync + 'scope, + { + let job = ArcJob::new(move || { + let body = &body; + self.base + .execute_job(move || BroadcastContext::with(move |ctx| body(self, ctx))) + }); + unsafe { self.base.inject_broadcast(job) } + } } impl<'scope> ScopeBase<'scope> { @@ -615,6 +648,19 @@ impl<'scope> ScopeBase<'scope> { self.job_completed_latch.increment(); } + unsafe fn inject_broadcast(&self, job: Arc>) + where + FUNC: Fn() + Send + Sync, + { + let n_threads = self.registry.num_threads(); + let job_refs = (0..n_threads).map(|_| { + self.increment(); + ArcJob::as_job_ref(&job) + }); + + self.registry.inject_broadcast(job_refs); + } + /// Executes `func` as a job, either aborting or executing as /// appropriate. 
fn complete(&self, owner: Option<&WorkerThread>, func: FUNC) -> R diff --git a/rayon-core/src/scope/test.rs b/rayon-core/src/scope/test.rs index de06c7b70..00dd18c92 100644 --- a/rayon-core/src/scope/test.rs +++ b/rayon-core/src/scope/test.rs @@ -6,7 +6,7 @@ use rand_xorshift::XorShiftRng; use std::cmp; use std::iter::once; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Mutex; +use std::sync::{Barrier, Mutex}; use std::vec; #[test] @@ -513,3 +513,89 @@ fn mixed_lifetime_scope_fifo() { increment(&[&counter; 100]); assert_eq!(counter.into_inner(), 100); } + +#[test] +fn scope_spawn_broadcast() { + let sum = AtomicUsize::new(0); + let n = scope(|s| { + s.spawn_broadcast(|_, ctx| { + sum.fetch_add(ctx.index(), Ordering::Relaxed); + }); + crate::current_num_threads() + }); + assert_eq!(sum.into_inner(), n * (n - 1) / 2); +} + +#[test] +fn scope_fifo_spawn_broadcast() { + let sum = AtomicUsize::new(0); + let n = scope_fifo(|s| { + s.spawn_broadcast(|_, ctx| { + sum.fetch_add(ctx.index(), Ordering::Relaxed); + }); + crate::current_num_threads() + }); + assert_eq!(sum.into_inner(), n * (n - 1) / 2); +} + +#[test] +fn scope_spawn_broadcast_nested() { + let sum = AtomicUsize::new(0); + let n = scope(|s| { + s.spawn_broadcast(|s, _| { + s.spawn_broadcast(|_, ctx| { + sum.fetch_add(ctx.index(), Ordering::Relaxed); + }); + }); + crate::current_num_threads() + }); + assert_eq!(sum.into_inner(), n * n * (n - 1) / 2); +} + +#[test] +fn scope_spawn_broadcast_barrier() { + let barrier = Barrier::new(8); + let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); + pool.in_place_scope(|s| { + s.spawn_broadcast(|_, _| { + barrier.wait(); + }); + barrier.wait(); + }); +} + +#[test] +fn scope_spawn_broadcast_panic_one() { + let count = AtomicUsize::new(0); + let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); + let result = crate::unwind::halt_unwinding(|| { + pool.scope(|s| { + s.spawn_broadcast(|_, ctx| { + count.fetch_add(1, 
Ordering::Relaxed); + if ctx.index() == 3 { + panic!("Hello, world!"); + } + }); + }); + }); + assert_eq!(count.into_inner(), 7); + assert!(result.is_err(), "broadcast panic should propagate!"); +} + +#[test] +fn scope_spawn_broadcast_panic_many() { + let count = AtomicUsize::new(0); + let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); + let result = crate::unwind::halt_unwinding(|| { + pool.scope(|s| { + s.spawn_broadcast(|_, ctx| { + count.fetch_add(1, Ordering::Relaxed); + if ctx.index() % 2 == 0 { + panic!("Hello, world!"); + } + }); + }); + }); + assert_eq!(count.into_inner(), 7); + assert!(result.is_err(), "broadcast panic should propagate!"); +} From 812ca025aedddea8a4c7d8477146527b71b33e19 Mon Sep 17 00:00:00 2001 From: Josh Stone Date: Sat, 11 Jun 2022 10:29:54 -0700 Subject: [PATCH 4/6] Simplify calls that use the panic_handler --- rayon-core/src/broadcast/mod.rs | 8 +------- rayon-core/src/registry.rs | 30 +++++++----------------------- rayon-core/src/spawn/mod.rs | 7 +------ 3 files changed, 9 insertions(+), 36 deletions(-) diff --git a/rayon-core/src/broadcast/mod.rs b/rayon-core/src/broadcast/mod.rs index bece9b990..bbf6f3e0d 100644 --- a/rayon-core/src/broadcast/mod.rs +++ b/rayon-core/src/broadcast/mod.rs @@ -1,7 +1,6 @@ use crate::job::{ArcJob, StackJob}; use crate::registry::{Registry, WorkerThread}; use crate::scope::ScopeLatch; -use crate::unwind; use std::fmt; use std::marker::PhantomData; use std::sync::Arc; @@ -131,12 +130,7 @@ where let job = ArcJob::new({ let registry = Arc::clone(registry); move || { - match unwind::halt_unwinding(|| BroadcastContext::with(&op)) { - Ok(()) => {} - Err(err) => { - registry.handle_panic(err); - } - } + registry.catch_unwind(|| BroadcastContext::with(&op)); registry.terminate(); // (*) permit registry to terminate now } }); diff --git a/rayon-core/src/registry.rs b/rayon-core/src/registry.rs index 32185ad60..69c43d95d 100644 --- a/rayon-core/src/registry.rs +++ b/rayon-core/src/registry.rs 
@@ -8,7 +8,6 @@ use crate::{ ErrorKind, ExitHandler, PanicHandler, StartHandler, ThreadPoolBuildError, ThreadPoolBuilder, }; use crossbeam_deque::{Injector, Steal, Stealer, Worker}; -use std::any::Any; use std::cell::Cell; use std::collections::hash_map::DefaultHasher; use std::fmt; @@ -332,19 +331,14 @@ impl Registry { self.thread_infos.len() } - pub(super) fn handle_panic(&self, err: Box<dyn Any + Send>) { - match self.panic_handler { - Some(ref handler) => { - // If the customizable panic handler itself panics, - // then we abort. - let abort_guard = unwind::AbortIfPanic; + pub(super) fn catch_unwind(&self, f: impl FnOnce()) { + if let Err(err) = unwind::halt_unwinding(f) { + // If there is no handler, or if that handler itself panics, then we abort. + let abort_guard = unwind::AbortIfPanic; + if let Some(ref handler) = self.panic_handler { handler(err); mem::forget(abort_guard); } - None => { - // Default panic handler aborts. - let _ = unwind::AbortIfPanic; // let this drop. - } } } @@ -880,12 +874,7 @@ unsafe fn main_loop( // Inform a user callback that we started a thread. if let Some(ref handler) = registry.start_handler { - match unwind::halt_unwinding(|| handler(index)) { - Ok(()) => {} - Err(err) => { - registry.handle_panic(err); - } - } + registry.catch_unwind(|| handler(index)); } let my_terminate_latch = &registry.thread_infos[index].terminate; @@ -908,12 +897,7 @@ unsafe fn main_loop( // Inform a user callback that we exited a thread. if let Some(ref handler) = registry.exit_handler { - match unwind::halt_unwinding(|| handler(index)) { - Ok(()) => {} - Err(err) => { - registry.handle_panic(err); - } - } + registry.catch_unwind(|| handler(index)); // We're already exiting the thread, there's nothing else to do.
} } diff --git a/rayon-core/src/spawn/mod.rs b/rayon-core/src/spawn/mod.rs index dc5725941..827e36e61 100644 --- a/rayon-core/src/spawn/mod.rs +++ b/rayon-core/src/spawn/mod.rs @@ -94,12 +94,7 @@ where HeapJob::new({ let registry = Arc::clone(registry); move || { - match unwind::halt_unwinding(func) { - Ok(()) => {} - Err(err) => { - registry.handle_panic(err); - } - } + registry.catch_unwind(func); registry.terminate(); // (*) permit registry to terminate now } }) From bd7b61ca8bf2ec472c74d221adfc4f8b22d2d090 Mon Sep 17 00:00:00 2001 From: Josh Stone Date: Sat, 11 Jun 2022 11:56:25 -0700 Subject: [PATCH 5/6] Add more internal enforcement of static/scope lifetimes --- rayon-core/src/broadcast/mod.rs | 2 +- rayon-core/src/job.rs | 16 +++++ rayon-core/src/scope/mod.rs | 104 ++++++++++++++++++++++---------- rayon-core/src/spawn/mod.rs | 2 +- 4 files changed, 90 insertions(+), 34 deletions(-) diff --git a/rayon-core/src/broadcast/mod.rs b/rayon-core/src/broadcast/mod.rs index bbf6f3e0d..452aa71b6 100644 --- a/rayon-core/src/broadcast/mod.rs +++ b/rayon-core/src/broadcast/mod.rs @@ -141,7 +141,7 @@ where // on each thread. This ref is decremented at the (*) above. registry.increment_terminate_count(); - ArcJob::as_job_ref(&job) + ArcJob::as_static_job_ref(&job) }); registry.inject_broadcast(job_refs); diff --git a/rayon-core/src/job.rs b/rayon-core/src/job.rs index b099d1735..b7a3dae18 100644 --- a/rayon-core/src/job.rs +++ b/rayon-core/src/job.rs @@ -144,6 +144,14 @@ where pub(super) unsafe fn into_job_ref(self: Box) -> JobRef { JobRef::new(Box::into_raw(self)) } + + /// Creates a static `JobRef` from this job. + pub(super) fn into_static_job_ref(self: Box) -> JobRef + where + BODY: 'static, + { + unsafe { self.into_job_ref() } + } } impl Job for HeapJob @@ -179,6 +187,14 @@ where pub(super) unsafe fn as_job_ref(this: &Arc) -> JobRef { JobRef::new(Arc::into_raw(Arc::clone(this))) } + + /// Creates a static `JobRef` from this job. 
+ pub(super) fn as_static_job_ref(this: &Arc) -> JobRef + where + BODY: 'static, + { + unsafe { Self::as_job_ref(this) } + } } impl Job for ArcJob diff --git a/rayon-core/src/scope/mod.rs b/rayon-core/src/scope/mod.rs index 7eadf1f19..25cda832e 100644 --- a/rayon-core/src/scope/mod.rs +++ b/rayon-core/src/scope/mod.rs @@ -6,7 +6,7 @@ //! [`join()`]: ../join/join.fn.html use crate::broadcast::BroadcastContext; -use crate::job::{ArcJob, HeapJob, JobFifo}; +use crate::job::{ArcJob, HeapJob, JobFifo, JobRef}; use crate::latch::{CountLatch, CountLockLatch, Latch}; use crate::registry::{global_registry, in_worker, Registry, WorkerThread}; use crate::unwind; @@ -539,16 +539,18 @@ impl<'scope> Scope<'scope> { where BODY: FnOnce(&Scope<'scope>) + Send + 'scope, { - self.base.increment(); - unsafe { - let job_ref = - HeapJob::new(move || self.base.execute_job(move || body(self))).into_job_ref(); + let scope_ptr = ScopePtr(self); + let job = HeapJob::new(move || { + // SAFETY: this job will execute before the scope ends. + let scope = unsafe { scope_ptr.as_ref() }; + scope.base.execute_job(move || body(scope)) + }); + let job_ref = self.base.heap_job_ref(job); - // Since `Scope` implements `Sync`, we can't be sure that we're still in a - // thread of this pool, so we can't just push to the local worker thread. - // Also, this might be an in-place scope. - self.base.registry.inject_or_push(job_ref); - } + // Since `Scope` implements `Sync`, we can't be sure that we're still in a + // thread of this pool, so we can't just push to the local worker thread. + // Also, this might be an in-place scope. + self.base.registry.inject_or_push(job_ref); } /// Spawns a job into every thread of the fork-join scope `self`. This job will @@ -559,12 +561,15 @@ impl<'scope> Scope<'scope> { where BODY: Fn(&Scope<'scope>, BroadcastContext<'_>) + Send + Sync + 'scope, { + let scope_ptr = ScopePtr(self); let job = ArcJob::new(move || { + // SAFETY: this job will execute before the scope ends. 
+ let scope = unsafe { scope_ptr.as_ref() }; let body = &body; - self.base - .execute_job(move || BroadcastContext::with(move |ctx| body(self, ctx))) + let func = move || BroadcastContext::with(move |ctx| body(scope, ctx)); + scope.base.execute_job(func); }); - unsafe { self.base.inject_broadcast(job) } + self.base.inject_broadcast(job) } } @@ -594,20 +599,23 @@ impl<'scope> ScopeFifo<'scope> { where BODY: FnOnce(&ScopeFifo<'scope>) + Send + 'scope, { - self.base.increment(); - unsafe { - let job_ref = - HeapJob::new(move || self.base.execute_job(move || body(self))).into_job_ref(); - - // If we're in the pool, use our scope's private fifo for this thread to execute - // in a locally-FIFO order. Otherwise, just use the pool's global injector. - match self.base.registry.current_thread() { - Some(worker) => { - let fifo = &self.fifos[worker.index()]; - worker.push(fifo.push(job_ref)); - } - None => self.base.registry.inject(&[job_ref]), + let scope_ptr = ScopePtr(self); + let job = HeapJob::new(move || { + // SAFETY: this job will execute before the scope ends. + let scope = unsafe { scope_ptr.as_ref() }; + scope.base.execute_job(move || body(scope)) + }); + let job_ref = self.base.heap_job_ref(job); + + // If we're in the pool, use our scope's private fifo for this thread to execute + // in a locally-FIFO order. Otherwise, just use the pool's global injector. + match self.base.registry.current_thread() { + Some(worker) => { + let fifo = &self.fifos[worker.index()]; + // SAFETY: this job will execute before the scope ends. + unsafe { worker.push(fifo.push(job_ref)) }; } + None => self.base.registry.inject(&[job_ref]), } } @@ -619,12 +627,15 @@ impl<'scope> ScopeFifo<'scope> { where BODY: Fn(&ScopeFifo<'scope>, BroadcastContext<'_>) + Send + Sync + 'scope, { + let scope_ptr = ScopePtr(self); let job = ArcJob::new(move || { + // SAFETY: this job will execute before the scope ends. 
+ let scope = unsafe { scope_ptr.as_ref() }; let body = &body; - self.base - .execute_job(move || BroadcastContext::with(move |ctx| body(self, ctx))) + let func = move || BroadcastContext::with(move |ctx| body(scope, ctx)); + scope.base.execute_job(func); }); - unsafe { self.base.inject_broadcast(job) } + self.base.inject_broadcast(job) } } @@ -648,12 +659,22 @@ impl<'scope> ScopeBase<'scope> { self.job_completed_latch.increment(); } - unsafe fn inject_broadcast(&self, job: Arc>) + fn heap_job_ref(&self, job: Box>) -> JobRef where - FUNC: Fn() + Send + Sync, + FUNC: FnOnce() + Send + 'scope, + { + unsafe { + self.increment(); + job.into_job_ref() + } + } + + fn inject_broadcast(&self, job: Arc>) + where + FUNC: Fn() + Send + Sync + 'scope, { let n_threads = self.registry.num_threads(); - let job_refs = (0..n_threads).map(|_| { + let job_refs = (0..n_threads).map(|_| unsafe { self.increment(); ArcJob::as_job_ref(&job) }); @@ -817,3 +838,22 @@ impl fmt::Debug for ScopeLatch { } } } + +/// Used to capture a scope `&Self` pointer in jobs, without faking a lifetime. +/// +/// Unsafe code is still required to dereference the pointer, but that's fine in +/// scope jobs that are guaranteed to execute before the scope ends. 
+struct ScopePtr(*const T); + +// SAFETY: !Send for raw pointers is not for safety, just as a lint +unsafe impl Send for ScopePtr {} + +// SAFETY: !Sync for raw pointers is not for safety, just as a lint +unsafe impl Sync for ScopePtr {} + +impl ScopePtr { + // Helper to avoid disjoint captures of `scope_ptr.0` + unsafe fn as_ref(&self) -> &T { + &*self.0 + } +} diff --git a/rayon-core/src/spawn/mod.rs b/rayon-core/src/spawn/mod.rs index 827e36e61..ae1f211ef 100644 --- a/rayon-core/src/spawn/mod.rs +++ b/rayon-core/src/spawn/mod.rs @@ -98,7 +98,7 @@ where registry.terminate(); // (*) permit registry to terminate now } }) - .into_job_ref() + .into_static_job_ref() } /// Fires off a task into the Rayon threadpool in the "static" or From 9ef85cd5d84966bc332eaa408c38be141f52e0d6 Mon Sep 17 00:00:00 2001 From: Josh Stone Date: Fri, 28 Oct 2022 10:29:27 -0700 Subject: [PATCH 6/6] Add some documentation about *when* broadcasts run --- rayon-core/src/thread_pool/mod.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/rayon-core/src/thread_pool/mod.rs b/rayon-core/src/thread_pool/mod.rs index 98ea0bdde..0fc06dd6b 100644 --- a/rayon-core/src/thread_pool/mod.rs +++ b/rayon-core/src/thread_pool/mod.rs @@ -114,6 +114,13 @@ impl ThreadPool { /// `join`, `scope`, or parallel iterators will then operate within that /// threadpool. /// + /// Broadcasts are executed on each thread after they have exhausted their + /// local work queue, before they attempt work-stealing from other threads. + /// The goal of that strategy is to run everywhere in a timely manner + /// *without* being too disruptive to current work. There may be alternative + /// broadcast styles added in the future for more or less aggressive + /// injection, if the need arises. + /// /// # Warning: thread-local data /// /// Because `op` is executing within the Rayon thread-pool,