diff --git a/rayon-core/src/broadcast/mod.rs b/rayon-core/src/broadcast/mod.rs index 05de7dfa3..bece9b990 100644 --- a/rayon-core/src/broadcast/mod.rs +++ b/rayon-core/src/broadcast/mod.rs @@ -51,11 +51,13 @@ pub struct BroadcastContext<'a> { } impl<'a> BroadcastContext<'a> { - fn new(worker: &WorkerThread) -> BroadcastContext<'_> { - BroadcastContext { - worker, + pub(super) fn with(f: impl FnOnce(BroadcastContext<'_>) -> R) -> R { + let worker_thread = WorkerThread::current(); + assert!(!worker_thread.is_null()); + f(BroadcastContext { + worker: unsafe { &*worker_thread }, _marker: PhantomData, - } + }) } /// Our index amongst the broadcast threads (ranges from `0..self.num_threads()`). @@ -98,10 +100,9 @@ where OP: Fn(BroadcastContext<'_>) -> R + Sync, R: Send, { - let f = move |injected| { - let worker_thread = WorkerThread::current(); - assert!(injected && !worker_thread.is_null()); - op(BroadcastContext::new(&*worker_thread)) + let f = move |injected: bool| { + debug_assert!(injected); + BroadcastContext::with(&op) }; let n_threads = registry.num_threads(); @@ -130,10 +131,7 @@ where let job = ArcJob::new({ let registry = Arc::clone(registry); move || { - let worker_thread = WorkerThread::current(); - assert!(!worker_thread.is_null()); - let ctx = BroadcastContext::new(&*worker_thread); - match unwind::halt_unwinding(|| op(ctx)) { + match unwind::halt_unwinding(|| BroadcastContext::with(&op)) { Ok(()) => {} Err(err) => { registry.handle_panic(err); diff --git a/rayon-core/src/scope/mod.rs b/rayon-core/src/scope/mod.rs index f9e76e85a..7eadf1f19 100644 --- a/rayon-core/src/scope/mod.rs +++ b/rayon-core/src/scope/mod.rs @@ -5,7 +5,8 @@ //! [`in_place_scope()`]: fn.in_place_scope.html //! [`join()`]: ../join/join.fn.html -use crate::job::{HeapJob, JobFifo}; +use crate::broadcast::BroadcastContext; +use crate::job::{ArcJob, HeapJob, JobFifo}; use crate::latch::{CountLatch, CountLockLatch, Latch}; use crate::registry::{global_registry, in_worker, Registry, WorkerThread}; use crate::unwind; @@ -549,6 +550,22 @@ impl<'scope> Scope<'scope> { self.base.registry.inject_or_push(job_ref); } } + + /// Spawns a job into every thread of the fork-join scope `self`. This job will + /// execute on each thread sometime before the fork-join scope completes. The + /// job is specified as a closure, and this closure receives its own reference + /// to the scope `self` as argument, as well as a `BroadcastContext`. + pub fn spawn_broadcast(&self, body: BODY) + where + BODY: Fn(&Scope<'scope>, BroadcastContext<'_>) + Send + Sync + 'scope, + { + let job = ArcJob::new(move || { + let body = &body; + self.base + .execute_job(move || BroadcastContext::with(move |ctx| body(self, ctx))) + }); + unsafe { self.base.inject_broadcast(job) } + } } impl<'scope> ScopeFifo<'scope> { @@ -593,6 +610,22 @@ impl<'scope> ScopeFifo<'scope> { } } } + + /// Spawns a job into every thread of the fork-join scope `self`. This job will + /// execute on each thread sometime before the fork-join scope completes. The + /// job is specified as a closure, and this closure receives its own reference + /// to the scope `self` as argument, as well as a `BroadcastContext`. + pub fn spawn_broadcast(&self, body: BODY) + where + BODY: Fn(&ScopeFifo<'scope>, BroadcastContext<'_>) + Send + Sync + 'scope, + { + let job = ArcJob::new(move || { + let body = &body; + self.base + .execute_job(move || BroadcastContext::with(move |ctx| body(self, ctx))) + }); + unsafe { self.base.inject_broadcast(job) } + } } impl<'scope> ScopeBase<'scope> { @@ -615,6 +648,19 @@ impl<'scope> ScopeBase<'scope> { self.job_completed_latch.increment(); } + unsafe fn inject_broadcast(&self, job: Arc>) + where + FUNC: Fn() + Send + Sync, + { + let n_threads = self.registry.num_threads(); + let job_refs = (0..n_threads).map(|_| { + self.increment(); + ArcJob::as_job_ref(&job) + }); + + self.registry.inject_broadcast(job_refs); + } + /// Executes `func` as a job, either aborting or executing as /// appropriate. fn complete(&self, owner: Option<&WorkerThread>, func: FUNC) -> R diff --git a/rayon-core/src/scope/test.rs b/rayon-core/src/scope/test.rs index de06c7b70..00dd18c92 100644 --- a/rayon-core/src/scope/test.rs +++ b/rayon-core/src/scope/test.rs @@ -6,7 +6,7 @@ use rand_xorshift::XorShiftRng; use std::cmp; use std::iter::once; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Mutex; +use std::sync::{Barrier, Mutex}; use std::vec; #[test] @@ -513,3 +513,89 @@ fn mixed_lifetime_scope_fifo() { increment(&[&counter; 100]); assert_eq!(counter.into_inner(), 100); } + +#[test] +fn scope_spawn_broadcast() { + let sum = AtomicUsize::new(0); + let n = scope(|s| { + s.spawn_broadcast(|_, ctx| { + sum.fetch_add(ctx.index(), Ordering::Relaxed); + }); + crate::current_num_threads() + }); + assert_eq!(sum.into_inner(), n * (n - 1) / 2); +} + +#[test] +fn scope_fifo_spawn_broadcast() { + let sum = AtomicUsize::new(0); + let n = scope_fifo(|s| { + s.spawn_broadcast(|_, ctx| { + sum.fetch_add(ctx.index(), Ordering::Relaxed); + }); + crate::current_num_threads() + }); + assert_eq!(sum.into_inner(), n * (n - 1) / 2); +} + +#[test] +fn scope_spawn_broadcast_nested() { + let sum = AtomicUsize::new(0); + let n = scope(|s| { + s.spawn_broadcast(|s, _| { + s.spawn_broadcast(|_, ctx| { + sum.fetch_add(ctx.index(), Ordering::Relaxed); + }); + }); + crate::current_num_threads() + }); + assert_eq!(sum.into_inner(), n * n * (n - 1) / 2); +} + +#[test] +fn scope_spawn_broadcast_barrier() { + let barrier = Barrier::new(8); + let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); + pool.in_place_scope(|s| { + s.spawn_broadcast(|_, _| { + barrier.wait(); + }); + barrier.wait(); + }); +} + +#[test] +fn scope_spawn_broadcast_panic_one() { + let count = AtomicUsize::new(0); + let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); + let result = crate::unwind::halt_unwinding(|| { + pool.scope(|s| { + s.spawn_broadcast(|_, ctx| { + count.fetch_add(1, Ordering::Relaxed); + if ctx.index() == 3 { + panic!("Hello, world!"); + } + }); + }); + }); + assert_eq!(count.into_inner(), 7); + assert!(result.is_err(), "broadcast panic should propagate!"); +} + +#[test] +fn scope_spawn_broadcast_panic_many() { + let count = AtomicUsize::new(0); + let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap(); + let result = crate::unwind::halt_unwinding(|| { + pool.scope(|s| { + s.spawn_broadcast(|_, ctx| { + count.fetch_add(1, Ordering::Relaxed); + if ctx.index() % 2 == 0 { + panic!("Hello, world!"); + } + }); + }); + }); + assert_eq!(count.into_inner(), 7); + assert!(result.is_err(), "broadcast panic should propagate!"); +}