Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 42 additions & 26 deletions crates/tasks/src/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::{
pin::Pin,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
Arc, OnceLock,
},
task::{ready, Context, Poll},
thread,
Expand Down Expand Up @@ -168,34 +168,50 @@ thread_local! {
///
/// The pool supports multiple init/clear cycles, allowing reuse of the same threads with
/// different state configurations.
///
/// The underlying rayon pool is created lazily on first access.
#[derive(Debug)]
pub struct WorkerPool {
pool: rayon::ThreadPool,
pool: OnceLock<rayon::ThreadPool>,
num_threads: usize,
thread_name_prefix: &'static str,
}

impl WorkerPool {
/// Creates a new `WorkerPool` with the given number of threads.
pub fn new(num_threads: usize) -> Result<Self, rayon::ThreadPoolBuildError> {
Self::from_builder(rayon::ThreadPoolBuilder::new().num_threads(num_threads))
/// Creates a new lazy `WorkerPool` with the given number of threads and a thread name prefix.
///
/// The underlying rayon pool is not created until the first method that requires it is called.
/// Thread names follow the pattern `"{prefix}-{index:02}"`.
pub const fn new(num_threads: usize, thread_name_prefix: &'static str) -> Self {
Self { pool: OnceLock::new(), num_threads, thread_name_prefix }
}

/// Creates a new `WorkerPool` from a [`rayon::ThreadPoolBuilder`].
///
/// Installs a panic handler that logs panics instead of aborting the process.
pub fn from_builder(
builder: rayon::ThreadPoolBuilder,
) -> Result<Self, rayon::ThreadPoolBuildError> {
Ok(Self { pool: build_pool_with_panic_handler(builder)? })
/// Returns a reference to the underlying rayon pool, creating it on first access.
fn pool(&self) -> &rayon::ThreadPool {
self.pool.get_or_init(|| {
let prefix = self.thread_name_prefix;
build_pool_with_panic_handler(
rayon::ThreadPoolBuilder::new()
.num_threads(self.num_threads)
.thread_name(move |i| format!("{prefix}-{i:02}")),
)
.unwrap_or_else(|err| panic!("failed to build {prefix} worker pool: {err}"))
})
}

/// Returns `true` if the underlying rayon pool has been initialized.
pub fn is_initialized(&self) -> bool {
self.pool.get().is_some()
}

/// Returns the total number of threads in the underlying rayon pool.
pub fn current_num_threads(&self) -> usize {
self.pool.current_num_threads()
self.pool().current_num_threads()
}

/// Initializes per-thread [`Worker`] state on every thread in the pool.
pub fn init<T: 'static>(&self, f: impl Fn(Option<&mut T>) -> T + Sync) {
self.broadcast(self.pool.current_num_threads(), |worker| {
self.broadcast(self.pool().current_num_threads(), |worker| {
worker.init::<T>(&f);
});
}
Expand All @@ -206,14 +222,14 @@ impl WorkerPool {
/// Use this to initialize or re-initialize per-thread state via [`Worker::init`].
/// Only `num_threads` threads execute the closure; the rest skip it.
pub fn broadcast(&self, num_threads: usize, f: impl Fn(&mut Worker) + Sync) {
if num_threads >= self.pool.current_num_threads() {
if num_threads >= self.pool().current_num_threads() {
// Fast path: run on every thread, no atomic coordination needed.
self.pool.broadcast(|_| {
self.pool().broadcast(|_| {
WORKER.with_borrow_mut(|worker| f(worker));
});
} else {
let remaining = AtomicUsize::new(num_threads);
self.pool.broadcast(|_| {
self.pool().broadcast(|_| {
// Atomically claim a slot; threads that can't decrement skip the closure.
let mut current = remaining.load(Ordering::Relaxed);
loop {
Expand All @@ -237,7 +253,7 @@ impl WorkerPool {

/// Clears the state on every thread in the pool.
pub fn clear(&self) {
self.pool.broadcast(|_| {
self.pool().broadcast(|_| {
WORKER.with_borrow_mut(Worker::clear);
});
}
Expand All @@ -248,27 +264,27 @@ impl WorkerPool {
/// Each thread can access its own [`Worker`] via the provided reference or through additional
/// [`WorkerPool::with_worker`] calls.
pub fn install<R: Send>(&self, f: impl FnOnce(&Worker) -> R + Send) -> R {
self.pool.install(|| WORKER.with_borrow(|worker| f(worker)))
self.pool().install(|| WORKER.with_borrow(|worker| f(worker)))
}

/// Runs a closure on the pool without worker state access.
///
/// Like [`install`](Self::install) but for closures that don't need per-thread [`Worker`]
/// state.
pub fn install_fn<R: Send>(&self, f: impl FnOnce() -> R + Send) -> R {
self.pool.install(f)
self.pool().install(f)
}

/// Spawns a closure on the pool.
pub fn spawn(&self, f: impl FnOnce() + Send + 'static) {
self.pool.spawn(f);
self.pool().spawn(f);
}

/// Executes `f` on this pool using [`rayon::in_place_scope`], which converts the calling
/// thread into a worker for the duration — tasks spawned inside the scope run on the pool
/// and the call blocks until all of them complete.
pub fn in_place_scope<'scope, R>(&self, f: impl FnOnce(&rayon::Scope<'scope>) -> R) -> R {
self.pool.in_place_scope(f)
self.pool().in_place_scope(f)
}

/// Access the current thread's [`Worker`] from within an [`install`](Self::install) closure.
Expand Down Expand Up @@ -398,7 +414,7 @@ mod tests {

#[test]
fn worker_pool_init_and_access() {
let pool = WorkerPool::new(2).unwrap();
let pool = WorkerPool::new(2, "test");

pool.broadcast(2, |worker| {
worker.init::<Vec<u8>>(|_| vec![1, 2, 3]);
Expand All @@ -415,7 +431,7 @@ mod tests {

#[test]
fn worker_pool_reinit_reuses_resources() {
let pool = WorkerPool::new(1).unwrap();
let pool = WorkerPool::new(1, "test");

pool.broadcast(1, |worker| {
worker.init::<Vec<u8>>(|existing| {
Expand All @@ -441,7 +457,7 @@ mod tests {

#[test]
fn worker_pool_clear_and_reinit() {
let pool = WorkerPool::new(1).unwrap();
let pool = WorkerPool::new(1, "test");

pool.broadcast(1, |worker| {
worker.init::<u64>(|_| 42);
Expand All @@ -464,7 +480,7 @@ mod tests {
fn worker_pool_par_iter_with_worker() {
use rayon::prelude::*;

let pool = WorkerPool::new(2).unwrap();
let pool = WorkerPool::new(2, "test");

pool.broadcast(2, |worker| {
worker.init::<u64>(|_| 10);
Expand Down
71 changes: 19 additions & 52 deletions crates/tasks/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ use crate::{
};
use futures_util::{future::select, Future, FutureExt, TryFutureExt};
#[cfg(feature = "rayon")]
use std::sync::OnceLock;
#[cfg(feature = "rayon")]
use std::{num::NonZeroUsize, thread::available_parallelism};
use std::{
pin::pin,
Expand Down Expand Up @@ -237,34 +235,6 @@ pub enum RuntimeBuildError {
RayonBuild(#[from] rayon::ThreadPoolBuildError),
}

#[cfg(feature = "rayon")]
#[derive(Debug)]
struct LazyWorkerPool {
pool: OnceLock<WorkerPool>,
num_threads: usize,
thread_name_prefix: &'static str,
}

#[cfg(feature = "rayon")]
impl LazyWorkerPool {
const fn new(num_threads: usize, thread_name_prefix: &'static str) -> Self {
Self { pool: OnceLock::new(), num_threads, thread_name_prefix }
}

fn get(&self) -> &WorkerPool {
let num_threads = self.num_threads;
let thread_name_prefix = self.thread_name_prefix;
self.pool.get_or_init(|| {
WorkerPool::from_builder(
rayon::ThreadPoolBuilder::new()
.num_threads(num_threads)
.thread_name(move |i| format!("{thread_name_prefix}-{i:02}")),
)
.unwrap_or_else(|err| panic!("failed to build {thread_name_prefix} worker pool: {err}"))
})
}
}

// ── RuntimeInner ──────────────────────────────────────────────────────

struct RuntimeInner {
Expand Down Expand Up @@ -303,7 +273,7 @@ struct RuntimeInner {
prewarming_pool: WorkerPool,
/// BAL streaming pool (BAL hashed state streaming).
#[cfg(feature = "rayon")]
bal_streaming_pool: LazyWorkerPool,
bal_streaming_pool: WorkerPool,
/// Named single-thread worker map. Each unique name gets a dedicated OS thread
/// that is reused across all tasks submitted under that name.
worker_map: WorkerMap,
Expand Down Expand Up @@ -392,7 +362,7 @@ impl Runtime {
/// Get the BAL streaming pool.
#[cfg(feature = "rayon")]
pub fn bal_streaming_pool(&self) -> &WorkerPool {
self.0.bal_streaming_pool.get()
&self.0.bal_streaming_pool
}
}

Expand Down Expand Up @@ -837,30 +807,20 @@ impl RuntimeBuilder {

let proof_storage_worker_threads =
config.rayon.proof_storage_worker_threads.unwrap_or(default_threads * 2);
let proof_storage_worker_pool = WorkerPool::from_builder(
rayon::ThreadPoolBuilder::new()
.num_threads(proof_storage_worker_threads)
.thread_name(|i| format!("proof-strg-{i:02}")),
)?;
let proof_storage_worker_pool =
WorkerPool::new(proof_storage_worker_threads, "proof-strg");

let proof_account_worker_threads =
config.rayon.proof_account_worker_threads.unwrap_or(default_threads * 2);
let proof_account_worker_pool = WorkerPool::from_builder(
rayon::ThreadPoolBuilder::new()
.num_threads(proof_account_worker_threads)
.thread_name(|i| format!("proof-acct-{i:02}")),
)?;
let proof_account_worker_pool =
WorkerPool::new(proof_account_worker_threads, "proof-acct");

let prewarming_threads = config.rayon.prewarming_threads.unwrap_or(default_threads);
let prewarming_pool = WorkerPool::from_builder(
rayon::ThreadPoolBuilder::new()
.num_threads(prewarming_threads)
.thread_name(|i| format!("prewarm-{i:02}")),
)?;
let prewarming_pool = WorkerPool::new(prewarming_threads, "prewarm");

let bal_streaming_threads =
config.rayon.bal_streaming_threads.unwrap_or(default_threads);
let bal_streaming_pool = LazyWorkerPool::new(bal_streaming_threads, "bal-stream");
let bal_streaming_pool = WorkerPool::new(bal_streaming_threads, "bal-stream");

debug!(
default_threads,
Expand All @@ -871,7 +831,7 @@ impl RuntimeBuilder {
prewarming_threads,
bal_streaming_threads,
max_blocking_tasks = config.rayon.max_blocking_tasks,
"Initialized rayon thread pools and configured lazy BAL streaming pool"
"Configured lazy rayon worker pools"
);

(
Expand Down Expand Up @@ -962,12 +922,19 @@ mod tests {

#[cfg(feature = "rayon")]
#[test]
fn test_bal_streaming_pool_is_lazy() {
fn test_worker_pools_are_lazy() {
let runtime = Runtime::test();

assert!(runtime.0.bal_streaming_pool.pool.get().is_none());
// Worker pools are lazy — not initialized until first access.
assert!(!runtime.0.bal_streaming_pool.is_initialized());
assert!(!runtime.0.proof_storage_worker_pool.is_initialized());

// Accessing them triggers initialization and returns the configured thread count.
assert_eq!(runtime.bal_streaming_pool().current_num_threads(), 2);
assert!(runtime.0.bal_streaming_pool.pool.get().is_some());
assert!(runtime.0.bal_streaming_pool.is_initialized());

assert_eq!(runtime.proof_storage_worker_pool().current_num_threads(), 2);
assert_eq!(runtime.proof_account_worker_pool().current_num_threads(), 2);
assert_eq!(runtime.prewarming_pool().current_num_threads(), 2);
}
}
Loading