Skip to content

Commit

Permalink
Add Scope::spawn_broadcast
Browse files Browse the repository at this point in the history
  • Loading branch information
cuviper committed Nov 16, 2022
1 parent 817c4cc commit c7a3172
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 14 deletions.
22 changes: 10 additions & 12 deletions rayon-core/src/broadcast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,13 @@ pub struct BroadcastContext<'a> {
}

impl<'a> BroadcastContext<'a> {
fn new(worker: &WorkerThread) -> BroadcastContext<'_> {
BroadcastContext {
worker,
pub(super) fn with<R>(f: impl FnOnce(BroadcastContext<'_>) -> R) -> R {
let worker_thread = WorkerThread::current();
assert!(!worker_thread.is_null());
f(BroadcastContext {
worker: unsafe { &*worker_thread },
_marker: PhantomData,
}
})
}

/// Our index amongst the broadcast threads (ranges from `0..self.num_threads()`).
Expand Down Expand Up @@ -98,10 +100,9 @@ where
OP: Fn(BroadcastContext<'_>) -> R + Sync,
R: Send,
{
let f = move |injected| {
let worker_thread = WorkerThread::current();
assert!(injected && !worker_thread.is_null());
op(BroadcastContext::new(&*worker_thread))
let f = move |injected: bool| {
debug_assert!(injected);
BroadcastContext::with(&op)
};

let n_threads = registry.num_threads();
Expand Down Expand Up @@ -130,10 +131,7 @@ where
let job = ArcJob::new({
let registry = Arc::clone(registry);
move || {
let worker_thread = WorkerThread::current();
assert!(!worker_thread.is_null());
let ctx = BroadcastContext::new(&*worker_thread);
match unwind::halt_unwinding(|| op(ctx)) {
match unwind::halt_unwinding(|| BroadcastContext::with(&op)) {
Ok(()) => {}
Err(err) => {
registry.handle_panic(err);
Expand Down
48 changes: 47 additions & 1 deletion rayon-core/src/scope/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
//! [`in_place_scope()`]: fn.in_place_scope.html
//! [`join()`]: ../join/join.fn.html

use crate::job::{HeapJob, JobFifo};
use crate::broadcast::BroadcastContext;
use crate::job::{ArcJob, HeapJob, JobFifo};
use crate::latch::{CountLatch, CountLockLatch, Latch};
use crate::registry::{global_registry, in_worker, Registry, WorkerThread};
use crate::unwind;
Expand Down Expand Up @@ -549,6 +550,22 @@ impl<'scope> Scope<'scope> {
self.base.registry.inject_or_push(job_ref);
}
}

/// Spawns a job into every thread of the fork-join scope `self`. This job will
/// execute on each thread sometime before the fork-join scope completes. The
/// job is specified as a closure, and this closure receives its own reference
/// to the scope `self` as argument, as well as a `BroadcastContext`.
pub fn spawn_broadcast<BODY>(&self, body: BODY)
where
BODY: Fn(&Scope<'scope>, BroadcastContext<'_>) + Send + Sync + 'scope,
{
let job = ArcJob::new(move || {
let body = &body;
self.base
.execute_job(move || BroadcastContext::with(move |ctx| body(self, ctx)))
});
unsafe { self.base.inject_broadcast(job) }
}
}

impl<'scope> ScopeFifo<'scope> {
Expand Down Expand Up @@ -593,6 +610,22 @@ impl<'scope> ScopeFifo<'scope> {
}
}
}

/// Spawns a job into every thread of the fork-join scope `self`. This job will
/// execute on each thread sometime before the fork-join scope completes. The
/// job is specified as a closure, and this closure receives its own reference
/// to the scope `self` as argument, as well as a `BroadcastContext`.
pub fn spawn_broadcast<BODY>(&self, body: BODY)
where
BODY: Fn(&ScopeFifo<'scope>, BroadcastContext<'_>) + Send + Sync + 'scope,
{
let job = ArcJob::new(move || {
let body = &body;
self.base
.execute_job(move || BroadcastContext::with(move |ctx| body(self, ctx)))
});
unsafe { self.base.inject_broadcast(job) }
}
}

impl<'scope> ScopeBase<'scope> {
Expand All @@ -615,6 +648,19 @@ impl<'scope> ScopeBase<'scope> {
self.job_completed_latch.increment();
}

unsafe fn inject_broadcast<FUNC>(&self, job: Arc<ArcJob<FUNC>>)
where
FUNC: Fn() + Send + Sync,
{
let n_threads = self.registry.num_threads();
let job_refs = (0..n_threads).map(|_| {
self.increment();
ArcJob::as_job_ref(&job)
});

self.registry.inject_broadcast(job_refs);
}

/// Executes `func` as a job, either aborting or executing as
/// appropriate.
fn complete<FUNC, R>(&self, owner: Option<&WorkerThread>, func: FUNC) -> R
Expand Down
88 changes: 87 additions & 1 deletion rayon-core/src/scope/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use rand_xorshift::XorShiftRng;
use std::cmp;
use std::iter::once;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Mutex;
use std::sync::{Barrier, Mutex};
use std::vec;

#[test]
Expand Down Expand Up @@ -513,3 +513,89 @@ fn mixed_lifetime_scope_fifo() {
increment(&[&counter; 100]);
assert_eq!(counter.into_inner(), 100);
}

#[test]
fn scope_spawn_broadcast() {
let sum = AtomicUsize::new(0);
let n = scope(|s| {
s.spawn_broadcast(|_, ctx| {
sum.fetch_add(ctx.index(), Ordering::Relaxed);
});
crate::current_num_threads()
});
assert_eq!(sum.into_inner(), n * (n - 1) / 2);
}

#[test]
fn scope_fifo_spawn_broadcast() {
let sum = AtomicUsize::new(0);
let n = scope_fifo(|s| {
s.spawn_broadcast(|_, ctx| {
sum.fetch_add(ctx.index(), Ordering::Relaxed);
});
crate::current_num_threads()
});
assert_eq!(sum.into_inner(), n * (n - 1) / 2);
}

#[test]
fn scope_spawn_broadcast_nested() {
let sum = AtomicUsize::new(0);
let n = scope(|s| {
s.spawn_broadcast(|s, _| {
s.spawn_broadcast(|_, ctx| {
sum.fetch_add(ctx.index(), Ordering::Relaxed);
});
});
crate::current_num_threads()
});
assert_eq!(sum.into_inner(), n * n * (n - 1) / 2);
}

#[test]
fn scope_spawn_broadcast_barrier() {
let barrier = Barrier::new(8);
let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
pool.in_place_scope(|s| {
s.spawn_broadcast(|_, _| {
barrier.wait();
});
barrier.wait();
});
}

#[test]
fn scope_spawn_broadcast_panic_one() {
let count = AtomicUsize::new(0);
let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
let result = crate::unwind::halt_unwinding(|| {
pool.scope(|s| {
s.spawn_broadcast(|_, ctx| {
count.fetch_add(1, Ordering::Relaxed);
if ctx.index() == 3 {
panic!("Hello, world!");
}
});
});
});
assert_eq!(count.into_inner(), 7);
assert!(result.is_err(), "broadcast panic should propagate!");
}

#[test]
fn scope_spawn_broadcast_panic_many() {
let count = AtomicUsize::new(0);
let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
let result = crate::unwind::halt_unwinding(|| {
pool.scope(|s| {
s.spawn_broadcast(|_, ctx| {
count.fetch_add(1, Ordering::Relaxed);
if ctx.index() % 2 == 0 {
panic!("Hello, world!");
}
});
});
});
assert_eq!(count.into_inner(), 7);
assert!(result.is_err(), "broadcast panic should propagate!");
}

0 comments on commit c7a3172

Please sign in to comment.