Skip to content

Commit

Permalink
Add ThreadPool::broadcast
Browse files Browse the repository at this point in the history
A broadcast runs the closure on every thread in the pool, then collects
the results.  It's scheduled somewhat like a very soft interrupt -- it
won't preempt a thread's local work, but will run before it goes to
steal from any other threads.

This can be used when you want to precisely split your work per-thread,
or to set or retrieve some thread-local data in the pool, e.g. rayon-rs#483.
  • Loading branch information
cuviper committed May 9, 2019
1 parent 939d3ee commit e0f5400
Show file tree
Hide file tree
Showing 6 changed files with 273 additions and 6 deletions.
1 change: 1 addition & 0 deletions rayon-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ pub use join::{join, join_context};
pub use scope::{scope, Scope};
pub use scope::{scope_fifo, ScopeFifo};
pub use spawn::{spawn, spawn_fifo};
pub use thread_pool::broadcast;
pub use thread_pool::current_thread_has_pending_tasks;
pub use thread_pool::current_thread_index;
pub use thread_pool::ThreadPool;
Expand Down
3 changes: 3 additions & 0 deletions rayon-core/src/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ pub(super) enum Event {
InjectJobs {
count: usize,
},
BroadcastJobs {
count: usize,
},
Join {
worker: usize,
},
Expand Down
128 changes: 122 additions & 6 deletions rayon-core/src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use std::ptr;
#[allow(deprecated)]
use std::sync::atomic::ATOMIC_USIZE_INIT;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Once, ONCE_INIT};
use std::sync::{Arc, Mutex, Once, ONCE_INIT};
use std::thread;
use std::usize;
use unwind;
Expand All @@ -28,6 +28,7 @@ pub(super) struct Registry {
thread_infos: Vec<ThreadInfo>,
sleep: Sleep,
injected_jobs: SegQueue<JobRef>,
broadcasts: Mutex<Vec<Worker<JobRef>>>,
panic_handler: Option<Box<PanicHandler>>,
start_handler: Option<Box<StartHandler>>,
exit_handler: Option<Box<ExitHandler>>,
Expand Down Expand Up @@ -116,10 +117,14 @@ impl Registry {
})
.unzip();

let (broadcasts, broadcast_stealers): (Vec<_>, Vec<_>) =
(0..n_threads).map(|_| deque::fifo()).unzip();

let registry = Arc::new(Registry {
thread_infos: stealers.into_iter().map(ThreadInfo::new).collect(),
sleep: Sleep::new(),
injected_jobs: SegQueue::new(),
broadcasts: Mutex::new(broadcasts),
terminate_latch: CountLatch::new(),
panic_handler: builder.take_panic_handler(),
start_handler: builder.take_start_handler(),
Expand All @@ -129,7 +134,7 @@ impl Registry {
// If we return early or panic, make sure to terminate existing threads.
let t1000 = Terminator(&registry);

for (index, worker) in workers.into_iter().enumerate() {
for (index, (worker, stealer)) in workers.into_iter().zip(broadcast_stealers).enumerate() {
let registry = registry.clone();
let mut b = thread::Builder::new();
if let Some(name) = builder.get_thread_name(index) {
Expand All @@ -138,7 +143,8 @@ impl Registry {
if let Some(stack_size) = builder.get_stack_size() {
b = b.stack_size(stack_size);
}
if let Err(e) = b.spawn(move || unsafe { main_loop(worker, registry, index) }) {
if let Err(e) = b.spawn(move || unsafe { main_loop(worker, stealer, registry, index) })
{
return Err(ThreadPoolBuildError::new(ErrorKind::IOError(e)));
}
}
Expand Down Expand Up @@ -307,7 +313,7 @@ impl Registry {
}

/// Push a job into the "external jobs" queue; it will be taken by
/// whatever worker has nothing to do. Use this is you know that
/// whatever worker has nothing to do. Use this if you know that
/// you are not on a worker of this registry.
pub(super) fn inject(&self, injected_jobs: &[JobRef]) {
log!(InjectJobs {
Expand Down Expand Up @@ -340,6 +346,99 @@ impl Registry {
job
}

/// Push a job into each thread's own "external jobs" queue; it will be
/// executed only on that thread, when it has nothing else to do locally,
/// before it tries to steal other work.
///
/// **Panics** if not given exactly as many jobs as there are threads.
pub(super) fn inject_all(&self, injected_jobs: &[JobRef]) {
assert_eq!(self.num_threads(), injected_jobs.len());
log!(BroadcastJobs {
count: injected_jobs.len()
});
{
let broadcasts = self.broadcasts.lock().unwrap();

// It should not be possible for `state.terminate` to be true
// here. It is only set to true when the user creates (and
// drops) a `ThreadPool`; and, in that case, they cannot be
// calling `inject_all()` later, since they dropped their
// `ThreadPool`.
assert!(
!self.terminate_latch.probe(),
"inject_all() sees state.terminate as true"
);

assert_eq!(broadcasts.len(), injected_jobs.len());
for (worker, &job_ref) in broadcasts.iter().zip(injected_jobs) {
worker.push(job_ref);
}
}
self.sleep.tickle(usize::MAX);
}

/// Execute `op` on every thread in the pool. It will be executed on each
/// thread when they have nothing else to do locally, before they try to
/// steal work from other threads. This function will not return until all
/// threads have completed the `op`.
pub(super) fn broadcast<OP, R>(&self, op: OP) -> Vec<R>
where
OP: Fn(&WorkerThread) -> R + Sync,
R: Send,
{
unsafe {
if let Some(current_thread) = WorkerThread::current().as_ref() {
if current_thread.registry().id() == self.id() {
// broadcasting within in our own pool
self.broadcast_jobs(op, SpinLatch::new, |latch| {
current_thread.wait_until(latch)
})
} else {
// broadcasting from a different pool
let sleep = &current_thread.registry().sleep;
self.broadcast_jobs(
op,
|| TickleLatch::new(SpinLatch::new(), sleep),
|latch| current_thread.wait_until(latch),
)
}
} else {
// broadcasting from outside any pool
self.broadcast_jobs(op, LockLatch::new, LockLatch::wait)
}
}
}

/// Common broadcast helper with different kinds of latches
unsafe fn broadcast_jobs<OP, R, L, New, Wait>(&self, op: OP, latch: New, wait: Wait) -> Vec<R>
where
OP: Fn(&WorkerThread) -> R + Sync,
R: Send,
L: Latch + Sync,
New: Fn() -> L,
Wait: Fn(&L),
{
let f = |injected| {
let worker_thread = WorkerThread::current();
assert!(injected && !worker_thread.is_null());
op(&*worker_thread)
};

let n_threads = self.thread_infos.len();
let jobs: Vec<_> = (0..n_threads).map(|_| StackJob::new(&f, latch())).collect();
let job_refs: Vec<_> = jobs.iter().map(|job| job.as_job_ref()).collect();

self.inject_all(&job_refs);

// Let all jobs have a chance to complete.
for job in &jobs {
wait(&job.latch);
}

// Collect the results, maybe propagating a panic.
jobs.into_iter().map(|job| job.into_result()).collect()
}

/// If already in a worker-thread of this registry, just execute `op`.
/// Otherwise, inject `op` in this thread-pool. Either way, block until `op`
/// completes and return its return value. If `op` panics, that panic will
Expand Down Expand Up @@ -478,6 +577,9 @@ pub(super) struct WorkerThread {
/// the "worker" half of our local deque
worker: Worker<JobRef>,

/// the "stealer" half of the worker's broadcast deque
stealer: Stealer<JobRef>,

/// local queue used for `spawn_fifo` indirection
fifo: JobFifo,

Expand Down Expand Up @@ -551,11 +653,19 @@ impl WorkerThread {
pub(super) unsafe fn take_local_job(&self) -> Option<JobRef> {
loop {
match self.worker.pop() {
Pop::Empty => return None,
Pop::Empty => break,
Pop::Data(d) => return Some(d),
Pop::Retry => {}
}
}

loop {
match self.stealer.steal() {
Steal::Empty => return None,
Steal::Data(d) => return Some(d),
Steal::Retry => {}
}
}
}

/// Wait until the latch is set. Try to keep busy by popping and
Expand Down Expand Up @@ -655,9 +765,15 @@ impl WorkerThread {

/// ////////////////////////////////////////////////////////////////////////

unsafe fn main_loop(worker: Worker<JobRef>, registry: Arc<Registry>, index: usize) {
unsafe fn main_loop(
worker: Worker<JobRef>,
stealer: Stealer<JobRef>,
registry: Arc<Registry>,
index: usize,
) {
let worker_thread = WorkerThread {
worker,
stealer,
fifo: JobFifo::new(),
index,
rng: XorShift64Star::new(),
Expand Down
59 changes: 59 additions & 0 deletions rayon-core/src/thread_pool/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,49 @@ impl ThreadPool {
self.registry.in_worker(|_, _| op())
}

/// Executes `op` within every thread in the threadpool. Any attempts to use
/// `join`, `scope`, or parallel iterators will then operate within that
/// threadpool.
///
/// # Warning: thread-local data
///
/// Because `op` is executing within the Rayon thread-pool,
/// thread-local data from the current thread will not be
/// accessible.
///
/// # Panics
///
/// If `op` should panic on one or more threads, exactly one panic
/// will be propagated, only after all threads have completed
/// (or panicked) their own `op`.
///
/// # Examples
///
/// ```
/// # use rayon_core as rayon;
/// use std::sync::atomic::{AtomicUsize, Ordering};
///
/// fn main() {
/// let pool = rayon::ThreadPoolBuilder::new().num_threads(5).build().unwrap();
///
/// // The argument is the index of each thread
/// let v: Vec<usize> = pool.broadcast(|i| i * i);
/// assert_eq!(v, &[0, 1, 4, 9, 16]);
///
/// // The closure can reference the local stack
/// let count = AtomicUsize::new(0);
/// pool.broadcast(|_| count.fetch_add(1, Ordering::Relaxed));
/// assert_eq!(count.into_inner(), 5);
/// }
/// ```
pub fn broadcast<OP, R>(&self, op: OP) -> Vec<R>
where
OP: Fn(usize) -> R + Sync,
R: Send,
{
self.registry.broadcast(|worker| op(worker.index()))
}

/// Returns the (current) number of threads in the thread pool.
///
/// # Future compatibility note
Expand Down Expand Up @@ -330,3 +373,19 @@ pub fn current_thread_has_pending_tasks() -> Option<bool> {
Some(!curr.local_deque_is_empty())
}
}

/// Executes `op` within every thread in the current threadpool. If this is
/// called from a non-Rayon thread, it will execute in the global threadpool.
/// Any attempts to use `join`, `scope`, or parallel iterators will then operate
/// within that threadpool.
///
/// For more information, see the [`ThreadPool::broadcast()`][m] method.
///
/// [m]: struct.ThreadPool.html#method.broadcast
pub fn broadcast<OP, R>(op: OP) -> Vec<R>
where
OP: Fn(usize) -> R + Sync,
R: Send,
{
Registry::current().broadcast(|worker| op(worker.index()))
}
87 changes: 87 additions & 0 deletions rayon-core/src/thread_pool/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,90 @@ fn spawn_fifo_order() {
let expected: Vec<i32> = (0..10).collect(); // FIFO -> natural order
assert_eq!(vec, expected);
}

#[test]
fn broadcast_global() {
let v = ::broadcast(|i| i);
assert!(v.into_iter().eq(0..::current_num_threads()));
}

#[test]
fn broadcast_pool() {
let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
let v = pool.broadcast(|i| i);
assert!(v.into_iter().eq(0..7));
}

#[test]
fn broadcast_self() {
let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
let v = pool.install(|| ::broadcast(|i| i));
assert!(v.into_iter().eq(0..7));
}

#[test]
fn broadcast_mutual() {
let count = AtomicUsize::new(0);
let pool1 = ThreadPoolBuilder::new().num_threads(3).build().unwrap();
let pool2 = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
pool1.install(|| {
pool2.broadcast(|_| {
pool1.broadcast(|_| {
count.fetch_add(1, Ordering::Relaxed);
})
})
});
assert_eq!(count.into_inner(), 3 * 7);
}

#[test]
fn broadcast_mutual_sleepy() {
use std::{thread, time};

let count = AtomicUsize::new(0);
let pool1 = ThreadPoolBuilder::new().num_threads(3).build().unwrap();
let pool2 = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
pool1.install(|| {
thread::sleep(time::Duration::from_secs(1));
pool2.broadcast(|_| {
thread::sleep(time::Duration::from_secs(1));
pool1.broadcast(|_| {
thread::sleep(time::Duration::from_millis(100));
count.fetch_add(1, Ordering::Relaxed);
})
})
});
assert_eq!(count.into_inner(), 3 * 7);
}

#[test]
fn broadcast_panic_one() {
let count = AtomicUsize::new(0);
let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
let result = unwind::halt_unwinding(|| {
pool.broadcast(|i| {
count.fetch_add(1, Ordering::Relaxed);
if i == 3 {
panic!("Hello, world!");
}
})
});
assert_eq!(count.into_inner(), 7);
assert!(result.is_err(), "broadcast panic should propagate!");
}

#[test]
fn broadcast_panic_many() {
let count = AtomicUsize::new(0);
let pool = ThreadPoolBuilder::new().num_threads(7).build().unwrap();
let result = unwind::halt_unwinding(|| {
pool.broadcast(|i| {
count.fetch_add(1, Ordering::Relaxed);
if i % 2 == 0 {
panic!("Hello, world!");
}
})
});
assert_eq!(count.into_inner(), 7);
assert!(result.is_err(), "broadcast panic should propagate!");
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ mod par_either;

mod compile_fail;

pub use rayon_core::broadcast;
pub use rayon_core::FnContext;
pub use rayon_core::ThreadPool;
pub use rayon_core::ThreadPoolBuildError;
Expand Down

0 comments on commit e0f5400

Please sign in to comment.