Add ThreadPool::spawn_broadcast
cuviper committed Nov 16, 2022
1 parent eb6c6ef commit 817c4cc
Showing 8 changed files with 229 additions and 19 deletions.
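A minimal usage sketch of the new method (not part of this commit, mirroring the `spawn_broadcast_pool` test added below): the closure runs once on every worker thread, and because the call returns immediately, a channel is one way to observe the per-thread runs. The 4-thread pool size and the `crossbeam-channel` dependency are assumptions of the sketch, not requirements of the API.

use rayon_core::ThreadPoolBuilder;

fn main() {
    let pool = ThreadPoolBuilder::new().num_threads(4).build().unwrap();
    let (tx, rx) = crossbeam_channel::unbounded();

    // Runs the closure once on each worker thread of `pool` and returns
    // immediately, without waiting for any of them to finish.
    pool.spawn_broadcast(move |ctx| tx.send(ctx.index()).unwrap());

    // The job owns the only sender; once every thread has run the closure,
    // the job is dropped, the channel closes, and this iterator ends.
    let mut indices: Vec<usize> = rx.into_iter().collect();
    indices.sort_unstable();
    assert_eq!(indices, vec![0, 1, 2, 3]);
}

Unlike `broadcast`, there is no return value and no blocking, so any results have to come back through shared state or a channel as above.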
57 changes: 56 additions & 1 deletion rayon-core/src/broadcast/mod.rs
@@ -1,6 +1,7 @@
-use crate::job::StackJob;
+use crate::job::{ArcJob, StackJob};
use crate::registry::{Registry, WorkerThread};
use crate::scope::ScopeLatch;
use crate::unwind;
use std::fmt;
use std::marker::PhantomData;
use std::sync::Arc;
@@ -25,6 +26,22 @@ where
unsafe { broadcast_in(op, &Registry::current()) }
}

/// Spawns an asynchronous task on every thread in this thread-pool. This task
/// will run in the implicit, global scope, which means that it may outlast the
/// current stack frame -- therefore, it cannot capture any references onto the
/// stack (you will likely need a `move` closure).
///
/// For more information, see the [`ThreadPool::spawn_broadcast()`][m] method.
///
/// [m]: struct.ThreadPool.html#method.spawn_broadcast
pub fn spawn_broadcast<OP>(op: OP)
where
OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static,
{
// We assert that current registry has not terminated.
unsafe { spawn_broadcast_in(op, &Registry::current()) }
}

/// Provides context to a closure called by `broadcast`.
pub struct BroadcastContext<'a> {
worker: &'a WorkerThread,
@@ -99,3 +116,41 @@ where
latch.wait(current_thread);
jobs.into_iter().map(|job| job.into_result()).collect()
}

/// Execute `op` on every thread in the pool. It will be executed on each
/// thread when they have nothing else to do locally, before they try to
/// steal work from other threads. This function returns immediately after
/// injecting the jobs.
///
/// Unsafe because `registry` must not yet have terminated.
pub(super) unsafe fn spawn_broadcast_in<OP>(op: OP, registry: &Arc<Registry>)
where
OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static,
{
let job = ArcJob::new({
let registry = Arc::clone(registry);
move || {
let worker_thread = WorkerThread::current();
assert!(!worker_thread.is_null());
let ctx = BroadcastContext::new(&*worker_thread);
match unwind::halt_unwinding(|| op(ctx)) {
Ok(()) => {}
Err(err) => {
registry.handle_panic(err);
}
}
registry.terminate(); // (*) permit registry to terminate now
}
});

let n_threads = registry.num_threads();
let job_refs = (0..n_threads).map(|_| {
// Ensure that registry cannot terminate until this job has executed
// on each thread. This ref is decremented at the (*) above.
registry.increment_terminate_count();

ArcJob::as_job_ref(&job)
});

registry.inject_broadcast(job_refs);
}
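The `increment_terminate_count` / `terminate` pairing above is what keeps the pool alive until the broadcast has run on every thread. A toy model of that accounting (illustrative only, not rayon's actual `Registry` API):

use std::sync::atomic::{AtomicUsize, Ordering};

// One keep-alive reference is taken per injected job ref; each worker
// releases its reference after running the job, as at the (*) above.
struct KeepAlive(AtomicUsize);

impl KeepAlive {
    fn acquire(&self) {
        self.0.fetch_add(1, Ordering::SeqCst);
    }
    fn release(&self) {
        self.0.fetch_sub(1, Ordering::SeqCst);
    }
    fn may_terminate(&self) -> bool {
        self.0.load(Ordering::SeqCst) == 0
    }
}

fn main() {
    let registry = KeepAlive(AtomicUsize::new(0));
    let n_threads = 4;

    // Injection: one reference per thread the broadcast will run on.
    for _ in 0..n_threads {
        registry.acquire();
    }
    assert!(!registry.may_terminate());

    // Each worker releases once after executing its copy of the job.
    for _ in 0..n_threads {
        registry.release();
    }
    assert!(registry.may_terminate());
}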
115 changes: 113 additions & 2 deletions rayon-core/src/broadcast/test.rs
@@ -2,27 +2,61 @@

use crate::ThreadPoolBuilder;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::{thread, time};

#[test]
fn broadcast_global() {
let v = crate::broadcast(|ctx| ctx.index());
assert!(v.into_iter().eq(0..crate::current_num_threads()));
}

#[test]
fn spawn_broadcast_global() {
let (tx, rx) = crossbeam_channel::unbounded();
crate::spawn_broadcast(move |ctx| tx.send(ctx.index()).unwrap());

let mut v: Vec<_> = rx.into_iter().collect();
v.sort_unstable();
assert!(v.into_iter().eq(0..crate::current_num_threads()));
}

#[test]
fn broadcast_pool() {
let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
let v = pool.broadcast(|ctx| ctx.index());
assert!(v.into_iter().eq(0..7));
}

#[test]
fn spawn_broadcast_pool() {
let (tx, rx) = crossbeam_channel::unbounded();
let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
pool.spawn_broadcast(move |ctx| tx.send(ctx.index()).unwrap());

let mut v: Vec<_> = rx.into_iter().collect();
v.sort_unstable();
assert!(v.into_iter().eq(0..7));
}

#[test]
fn broadcast_self() {
let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
let v = pool.install(|| crate::broadcast(|ctx| ctx.index()));
assert!(v.into_iter().eq(0..7));
}

#[test]
fn spawn_broadcast_self() {
let (tx, rx) = crossbeam_channel::unbounded();
let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
pool.spawn(|| crate::spawn_broadcast(move |ctx| tx.send(ctx.index()).unwrap()));

let mut v: Vec<_> = rx.into_iter().collect();
v.sort_unstable();
assert!(v.into_iter().eq(0..7));
}

#[test]
fn broadcast_mutual() {
let count = AtomicUsize::new(0);
@@ -39,9 +73,24 @@ fn broadcast_mutual() {
}

#[test]
-fn broadcast_mutual_sleepy() {
-    use std::{thread, time};
+fn spawn_broadcast_mutual() {
let (tx, rx) = crossbeam_channel::unbounded();
let pool1 = Arc::new(ThreadPoolBuilder::new().num_threads(3).build().unwrap());
let pool2 = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
pool1.spawn({
let pool1 = Arc::clone(&pool1);
move || {
pool2.spawn_broadcast(move |_| {
let tx = tx.clone();
pool1.spawn_broadcast(move |_| tx.send(()).unwrap())
})
}
});
assert_eq!(rx.into_iter().count(), 3 * 7);
}

#[test]
fn broadcast_mutual_sleepy() {
let count = AtomicUsize::new(0);
let pool1 = ThreadPoolBuilder::new().num_threads(3).build().unwrap();
let pool2 = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
@@ -58,6 +107,28 @@ fn broadcast_mutual_sleepy() {
assert_eq!(count.into_inner(), 3 * 7);
}

#[test]
fn spawn_broadcast_mutual_sleepy() {
let (tx, rx) = crossbeam_channel::unbounded();
let pool1 = Arc::new(ThreadPoolBuilder::new().num_threads(3).build().unwrap());
let pool2 = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
pool1.spawn({
let pool1 = Arc::clone(&pool1);
move || {
thread::sleep(time::Duration::from_secs(1));
pool2.spawn_broadcast(move |_| {
let tx = tx.clone();
thread::sleep(time::Duration::from_secs(1));
pool1.spawn_broadcast(move |_| {
thread::sleep(time::Duration::from_millis(100));
tx.send(()).unwrap();
})
})
}
});
assert_eq!(rx.into_iter().count(), 3 * 7);
}

#[test]
fn broadcast_panic_one() {
let count = AtomicUsize::new(0);
@@ -74,6 +145,26 @@ fn broadcast_panic_one() {
assert!(result.is_err(), "broadcast panic should propagate!");
}

#[test]
fn spawn_broadcast_panic_one() {
let (tx, rx) = crossbeam_channel::unbounded();
let (panic_tx, panic_rx) = crossbeam_channel::unbounded();
let pool = ThreadPoolBuilder::new()
.num_threads(7)
.panic_handler(move |e| panic_tx.send(e).unwrap())
.build()
.unwrap();
pool.spawn_broadcast(move |ctx| {
tx.send(()).unwrap();
if ctx.index() == 3 {
panic!("Hello, world!");
}
});
drop(pool); // including panic_tx
assert_eq!(rx.into_iter().count(), 7);
assert_eq!(panic_rx.into_iter().count(), 1);
}

#[test]
fn broadcast_panic_many() {
let count = AtomicUsize::new(0);
@@ -89,3 +180,23 @@ fn broadcast_panic_many() {
assert_eq!(count.into_inner(), 7);
assert!(result.is_err(), "broadcast panic should propagate!");
}

#[test]
fn spawn_broadcast_panic_many() {
let (tx, rx) = crossbeam_channel::unbounded();
let (panic_tx, panic_rx) = crossbeam_channel::unbounded();
let pool = ThreadPoolBuilder::new()
.num_threads(7)
.panic_handler(move |e| panic_tx.send(e).unwrap())
.build()
.unwrap();
pool.spawn_broadcast(move |ctx| {
tx.send(()).unwrap();
if ctx.index() % 2 == 0 {
panic!("Hello, world!");
}
});
drop(pool); // including panic_tx
assert_eq!(rx.into_iter().count(), 7);
assert_eq!(panic_rx.into_iter().count(), 4);
}
40 changes: 38 additions & 2 deletions rayon-core/src/job.rs
@@ -4,6 +4,7 @@ use crossbeam_deque::{Injector, Steal};
use std::any::Any;
use std::cell::UnsafeCell;
use std::mem;
use std::sync::Arc;

pub(super) enum JobResult<T> {
None,
@@ -133,8 +134,8 @@ impl<BODY> HeapJob<BODY>
where
BODY: FnOnce() + Send,
{
-pub(super) fn new(job: BODY) -> Self {
-    HeapJob { job }
+pub(super) fn new(job: BODY) -> Box<Self> {
+    Box::new(HeapJob { job })
}

/// Creates a `JobRef` from this job -- note that this hides all
@@ -155,6 +156,41 @@ where
}
}

/// Represents a job stored in an `Arc` -- like `HeapJob`, but may
/// be turned into multiple `JobRef`s and called multiple times.
pub(super) struct ArcJob<BODY>
where
BODY: Fn() + Send + Sync,
{
job: BODY,
}

impl<BODY> ArcJob<BODY>
where
BODY: Fn() + Send + Sync,
{
pub(super) fn new(job: BODY) -> Arc<Self> {
Arc::new(ArcJob { job })
}

/// Creates a `JobRef` from this job -- note that this hides all
/// lifetimes, so it is up to you to ensure that this JobRef
/// doesn't outlive any data that it closes over.
pub(super) unsafe fn as_job_ref(this: &Arc<Self>) -> JobRef {
JobRef::new(Arc::into_raw(Arc::clone(this)))
}
}

impl<BODY> Job for ArcJob<BODY>
where
BODY: Fn() + Send + Sync,
{
unsafe fn execute(this: *const ()) {
let this = Arc::from_raw(this as *mut Self);
(this.job)();
}
}

impl<T> JobResult<T> {
fn call(func: impl FnOnce(bool) -> T) -> Self {
match unwind::halt_unwinding(|| func(true)) {
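`ArcJob` is what lets a single broadcast closure be handed to every worker: the shared closure becomes one type-erased pointer per thread, and each pointer is consumed exactly once. A standalone sketch of the same `Arc::into_raw` / `Arc::from_raw` pattern (the `SharedJob` name and API are illustrative, not rayon's internals):

use std::sync::Arc;

struct SharedJob<F: Fn() + Send + Sync>(F);

impl<F: Fn() + Send + Sync> SharedJob<F> {
    fn new(f: F) -> Arc<Self> {
        Arc::new(SharedJob(f))
    }

    // Leak one extra strong reference as a raw pointer; `execute` below
    // reclaims exactly that reference, so the counts stay balanced.
    fn as_raw(this: &Arc<Self>) -> *const Self {
        Arc::into_raw(Arc::clone(this))
    }

    // Safety: `ptr` must come from `as_raw` and be passed here exactly once.
    unsafe fn execute(ptr: *const Self) {
        let this = Arc::from_raw(ptr);
        (this.0)();
    }
}

fn main() {
    let job = SharedJob::new(|| println!("ran"));
    // One pointer per "worker thread", like the `job_refs` loop above.
    let refs: Vec<_> = (0..4).map(|_| SharedJob::as_raw(&job)).collect();
    for r in refs {
        unsafe { SharedJob::execute(r) }; // prints "ran" four times in total
    }
}

Because every worker consumes one such reference, the closure (and anything it captured, such as a channel sender) is dropped only after the last thread has run it.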
2 changes: 1 addition & 1 deletion rayon-core/src/lib.rs
@@ -76,7 +76,7 @@ mod unwind;
mod compile_fail;
mod test;

-pub use self::broadcast::{broadcast, BroadcastContext};
+pub use self::broadcast::{broadcast, spawn_broadcast, BroadcastContext};
pub use self::join::{join, join_context};
pub use self::registry::ThreadBuilder;
pub use self::scope::{in_place_scope, scope, Scope};
12 changes: 4 additions & 8 deletions rayon-core/src/scope/mod.rs
@@ -540,10 +540,8 @@ impl<'scope> Scope<'scope> {
{
self.base.increment();
unsafe {
-let job_ref = Box::new(HeapJob::new(move || {
-    self.base.execute_job(move || body(self))
-}))
-.into_job_ref();
+let job_ref =
+    HeapJob::new(move || self.base.execute_job(move || body(self))).into_job_ref();

// Since `Scope` implements `Sync`, we can't be sure that we're still in a
// thread of this pool, so we can't just push to the local worker thread.
@@ -581,10 +579,8 @@ impl<'scope> ScopeFifo<'scope> {
{
self.base.increment();
unsafe {
-let job_ref = Box::new(HeapJob::new(move || {
-    self.base.execute_job(move || body(self))
-}))
-.into_job_ref();
+let job_ref =
+    HeapJob::new(move || self.base.execute_job(move || body(self))).into_job_ref();

// If we're in the pool, use our scope's private fifo for this thread to execute
// in a locally-FIFO order. Otherwise, just use the pool's global injector.
4 changes: 2 additions & 2 deletions rayon-core/src/spawn/mod.rs
@@ -91,7 +91,7 @@
// executed. This ref is decremented at the (*) below.
registry.increment_terminate_count();

-Box::new(HeapJob::new({
+HeapJob::new({
let registry = Arc::clone(registry);
move || {
match unwind::halt_unwinding(func) {
@@ -102,7 +102,7 @@
}
registry.terminate(); // (*) permit registry to terminate now
}
-}))
+})
.into_job_ref()
}

16 changes: 14 additions & 2 deletions rayon-core/src/thread_pool/mod.rs
@@ -3,7 +3,7 @@
//!
//! [`ThreadPool`]: struct.ThreadPool.html

-use crate::broadcast::{broadcast_in, BroadcastContext};
+use crate::broadcast::{self, BroadcastContext};
use crate::join;
use crate::registry::{Registry, ThreadSpawn, WorkerThread};
use crate::scope::{do_in_place_scope, do_in_place_scope_fifo};
@@ -151,7 +151,7 @@ impl ThreadPool {
R: Send,
{
// We assert that `self.registry` has not terminated.
-unsafe { broadcast_in(op, &self.registry) }
+unsafe { broadcast::broadcast_in(op, &self.registry) }
}

/// Returns the (current) number of threads in the thread pool.
@@ -320,6 +320,18 @@ impl ThreadPool {
// We assert that `self.registry` has not terminated.
unsafe { spawn::spawn_fifo_in(op, &self.registry) }
}

/// Spawns an asynchronous task on every thread in this thread-pool. This task
/// will run in the implicit, global scope, which means that it may outlast the
/// current stack frame -- therefore, it cannot capture any references onto the
/// stack (you will likely need a `move` closure).
pub fn spawn_broadcast<OP>(&self, op: OP)
where
OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static,
{
// We assert that `self.registry` has not terminated.
unsafe { broadcast::spawn_broadcast_in(op, &self.registry) }
}
}

impl Drop for ThreadPool {
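Because a spawned broadcast has no caller waiting on it, a panic inside the closure cannot propagate the way `broadcast`'s panics do; it is delivered to the pool's panic handler instead, as the `spawn_broadcast_panic_one` test above exercises. A sketch of that pattern (handler body, message, and thread count are illustrative):

use rayon_core::ThreadPoolBuilder;

fn main() {
    let (panic_tx, panic_rx) = crossbeam_channel::unbounded();
    let pool = ThreadPoolBuilder::new()
        .num_threads(7)
        // Panics from fire-and-forget work are routed here.
        .panic_handler(move |payload| panic_tx.send(payload).unwrap())
        .build()
        .unwrap();

    pool.spawn_broadcast(|ctx| {
        if ctx.index() == 3 {
            panic!("broadcast worker 3 failed");
        }
    });

    // Dropping the pool lets it shut down once the broadcast has run on
    // every thread; the handler (and its `panic_tx`) goes with it, so the
    // channel closes and the count below is exact.
    drop(pool);
    assert_eq!(panic_rx.into_iter().count(), 1);
}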
