diff --git a/crates/tasks/src/pool.rs b/crates/tasks/src/pool.rs index 1f72c13d985..88252ae5ba1 100644 --- a/crates/tasks/src/pool.rs +++ b/crates/tasks/src/pool.rs @@ -8,7 +8,7 @@ use std::{ pin::Pin, sync::{ atomic::{AtomicUsize, Ordering}, - Arc, + Arc, OnceLock, }, task::{ready, Context, Poll}, thread, @@ -168,34 +168,50 @@ thread_local! { /// /// The pool supports multiple init/clear cycles, allowing reuse of the same threads with /// different state configurations. +/// +/// The underlying rayon pool is created lazily on first access. #[derive(Debug)] pub struct WorkerPool { - pool: rayon::ThreadPool, + pool: OnceLock, + num_threads: usize, + thread_name_prefix: &'static str, } impl WorkerPool { - /// Creates a new `WorkerPool` with the given number of threads. - pub fn new(num_threads: usize) -> Result { - Self::from_builder(rayon::ThreadPoolBuilder::new().num_threads(num_threads)) + /// Creates a new lazy `WorkerPool` with the given number of threads and a thread name prefix. + /// + /// The underlying rayon pool is not created until the first method that requires it is called. + /// Thread names follow the pattern `"{prefix}-{index:02}"`. + pub const fn new(num_threads: usize, thread_name_prefix: &'static str) -> Self { + Self { pool: OnceLock::new(), num_threads, thread_name_prefix } } - /// Creates a new `WorkerPool` from a [`rayon::ThreadPoolBuilder`]. - /// - /// Installs a panic handler that logs panics instead of aborting the process. - pub fn from_builder( - builder: rayon::ThreadPoolBuilder, - ) -> Result { - Ok(Self { pool: build_pool_with_panic_handler(builder)? }) + /// Returns a reference to the underlying rayon pool, creating it on first access. + fn pool(&self) -> &rayon::ThreadPool { + self.pool.get_or_init(|| { + let prefix = self.thread_name_prefix; + build_pool_with_panic_handler( + rayon::ThreadPoolBuilder::new() + .num_threads(self.num_threads) + .thread_name(move |i| format!("{prefix}-{i:02}")), + ) + .unwrap_or_else(|err| panic!("failed to build {prefix} worker pool: {err}")) + }) + } + + /// Returns `true` if the underlying rayon pool has been initialized. + pub fn is_initialized(&self) -> bool { + self.pool.get().is_some() } /// Returns the total number of threads in the underlying rayon pool. pub fn current_num_threads(&self) -> usize { - self.pool.current_num_threads() + self.pool().current_num_threads() } /// Initializes per-thread [`Worker`] state on every thread in the pool. pub fn init(&self, f: impl Fn(Option<&mut T>) -> T + Sync) { - self.broadcast(self.pool.current_num_threads(), |worker| { + self.broadcast(self.pool().current_num_threads(), |worker| { worker.init::(&f); }); } @@ -206,14 +222,14 @@ impl WorkerPool { /// Use this to initialize or re-initialize per-thread state via [`Worker::init`]. /// Only `num_threads` threads execute the closure; the rest skip it. pub fn broadcast(&self, num_threads: usize, f: impl Fn(&mut Worker) + Sync) { - if num_threads >= self.pool.current_num_threads() { + if num_threads >= self.pool().current_num_threads() { // Fast path: run on every thread, no atomic coordination needed. - self.pool.broadcast(|_| { + self.pool().broadcast(|_| { WORKER.with_borrow_mut(|worker| f(worker)); }); } else { let remaining = AtomicUsize::new(num_threads); - self.pool.broadcast(|_| { + self.pool().broadcast(|_| { // Atomically claim a slot; threads that can't decrement skip the closure. let mut current = remaining.load(Ordering::Relaxed); loop { @@ -237,7 +253,7 @@ impl WorkerPool { /// Clears the state on every thread in the pool. pub fn clear(&self) { - self.pool.broadcast(|_| { + self.pool().broadcast(|_| { WORKER.with_borrow_mut(Worker::clear); }); } @@ -248,7 +264,7 @@ impl WorkerPool { /// Each thread can access its own [`Worker`] via the provided reference or through additional /// [`WorkerPool::with_worker`] calls. pub fn install(&self, f: impl FnOnce(&Worker) -> R + Send) -> R { - self.pool.install(|| WORKER.with_borrow(|worker| f(worker))) + self.pool().install(|| WORKER.with_borrow(|worker| f(worker))) } /// Runs a closure on the pool without worker state access. @@ -256,19 +272,19 @@ impl WorkerPool { /// Like [`install`](Self::install) but for closures that don't need per-thread [`Worker`] /// state. pub fn install_fn(&self, f: impl FnOnce() -> R + Send) -> R { - self.pool.install(f) + self.pool().install(f) } /// Spawns a closure on the pool. pub fn spawn(&self, f: impl FnOnce() + Send + 'static) { - self.pool.spawn(f); + self.pool().spawn(f); } /// Executes `f` on this pool using [`rayon::in_place_scope`], which converts the calling /// thread into a worker for the duration — tasks spawned inside the scope run on the pool /// and the call blocks until all of them complete. pub fn in_place_scope<'scope, R>(&self, f: impl FnOnce(&rayon::Scope<'scope>) -> R) -> R { - self.pool.in_place_scope(f) + self.pool().in_place_scope(f) } /// Access the current thread's [`Worker`] from within an [`install`](Self::install) closure. @@ -398,7 +414,7 @@ mod tests { #[test] fn worker_pool_init_and_access() { - let pool = WorkerPool::new(2).unwrap(); + let pool = WorkerPool::new(2, "test"); pool.broadcast(2, |worker| { worker.init::>(|_| vec![1, 2, 3]); @@ -415,7 +431,7 @@ mod tests { #[test] fn worker_pool_reinit_reuses_resources() { - let pool = WorkerPool::new(1).unwrap(); + let pool = WorkerPool::new(1, "test"); pool.broadcast(1, |worker| { worker.init::>(|existing| { @@ -441,7 +457,7 @@ mod tests { #[test] fn worker_pool_clear_and_reinit() { - let pool = WorkerPool::new(1).unwrap(); + let pool = WorkerPool::new(1, "test"); pool.broadcast(1, |worker| { worker.init::(|_| 42); @@ -464,7 +480,7 @@ mod tests { fn worker_pool_par_iter_with_worker() { use rayon::prelude::*; - let pool = WorkerPool::new(2).unwrap(); + let pool = WorkerPool::new(2, "test"); pool.broadcast(2, |worker| { worker.init::(|_| 10); diff --git a/crates/tasks/src/runtime.rs b/crates/tasks/src/runtime.rs index 1e3a6c14cb9..2e7c6684b64 100644 --- a/crates/tasks/src/runtime.rs +++ b/crates/tasks/src/runtime.rs @@ -16,8 +16,6 @@ use crate::{ }; use futures_util::{future::select, Future, FutureExt, TryFutureExt}; #[cfg(feature = "rayon")] -use std::sync::OnceLock; -#[cfg(feature = "rayon")] use std::{num::NonZeroUsize, thread::available_parallelism}; use std::{ pin::pin, @@ -237,34 +235,6 @@ pub enum RuntimeBuildError { RayonBuild(#[from] rayon::ThreadPoolBuildError), } -#[cfg(feature = "rayon")] -#[derive(Debug)] -struct LazyWorkerPool { - pool: OnceLock, - num_threads: usize, - thread_name_prefix: &'static str, -} - -#[cfg(feature = "rayon")] -impl LazyWorkerPool { - const fn new(num_threads: usize, thread_name_prefix: &'static str) -> Self { - Self { pool: OnceLock::new(), num_threads, thread_name_prefix } - } - - fn get(&self) -> &WorkerPool { - let num_threads = self.num_threads; - let thread_name_prefix = self.thread_name_prefix; - self.pool.get_or_init(|| { - WorkerPool::from_builder( - rayon::ThreadPoolBuilder::new() - .num_threads(num_threads) - .thread_name(move |i| format!("{thread_name_prefix}-{i:02}")), - ) - .unwrap_or_else(|err| panic!("failed to build {thread_name_prefix} worker pool: {err}")) - }) - } -} - // ── RuntimeInner ────────────────────────────────────────────────────── struct RuntimeInner { @@ -303,7 +273,7 @@ struct RuntimeInner { prewarming_pool: WorkerPool, /// BAL streaming pool (BAL hashed state streaming). #[cfg(feature = "rayon")] - bal_streaming_pool: LazyWorkerPool, + bal_streaming_pool: WorkerPool, /// Named single-thread worker map. Each unique name gets a dedicated OS thread /// that is reused across all tasks submitted under that name. worker_map: WorkerMap, @@ -392,7 +362,7 @@ impl Runtime { /// Get the BAL streaming pool. #[cfg(feature = "rayon")] pub fn bal_streaming_pool(&self) -> &WorkerPool { - self.0.bal_streaming_pool.get() + &self.0.bal_streaming_pool } } @@ -837,30 +807,20 @@ impl RuntimeBuilder { let proof_storage_worker_threads = config.rayon.proof_storage_worker_threads.unwrap_or(default_threads * 2); - let proof_storage_worker_pool = WorkerPool::from_builder( - rayon::ThreadPoolBuilder::new() - .num_threads(proof_storage_worker_threads) - .thread_name(|i| format!("proof-strg-{i:02}")), - )?; + let proof_storage_worker_pool = + WorkerPool::new(proof_storage_worker_threads, "proof-strg"); let proof_account_worker_threads = config.rayon.proof_account_worker_threads.unwrap_or(default_threads * 2); - let proof_account_worker_pool = WorkerPool::from_builder( - rayon::ThreadPoolBuilder::new() - .num_threads(proof_account_worker_threads) - .thread_name(|i| format!("proof-acct-{i:02}")), - )?; + let proof_account_worker_pool = + WorkerPool::new(proof_account_worker_threads, "proof-acct"); let prewarming_threads = config.rayon.prewarming_threads.unwrap_or(default_threads); - let prewarming_pool = WorkerPool::from_builder( - rayon::ThreadPoolBuilder::new() - .num_threads(prewarming_threads) - .thread_name(|i| format!("prewarm-{i:02}")), - )?; + let prewarming_pool = WorkerPool::new(prewarming_threads, "prewarm"); let bal_streaming_threads = config.rayon.bal_streaming_threads.unwrap_or(default_threads); - let bal_streaming_pool = LazyWorkerPool::new(bal_streaming_threads, "bal-stream"); + let bal_streaming_pool = WorkerPool::new(bal_streaming_threads, "bal-stream"); debug!( default_threads, @@ -871,7 +831,7 @@ impl RuntimeBuilder { prewarming_threads, bal_streaming_threads, max_blocking_tasks = config.rayon.max_blocking_tasks, - "Initialized rayon thread pools and configured lazy BAL streaming pool" + "Configured lazy rayon worker pools" ); ( @@ -962,12 +922,19 @@ mod tests { #[cfg(feature = "rayon")] #[test] - fn test_bal_streaming_pool_is_lazy() { + fn test_worker_pools_are_lazy() { let runtime = Runtime::test(); - assert!(runtime.0.bal_streaming_pool.pool.get().is_none()); + // Worker pools are lazy — not initialized until first access. + assert!(!runtime.0.bal_streaming_pool.is_initialized()); + assert!(!runtime.0.proof_storage_worker_pool.is_initialized()); + // Accessing them triggers initialization and returns the configured thread count. assert_eq!(runtime.bal_streaming_pool().current_num_threads(), 2); - assert!(runtime.0.bal_streaming_pool.pool.get().is_some()); + assert!(runtime.0.bal_streaming_pool.is_initialized()); + + assert_eq!(runtime.proof_storage_worker_pool().current_num_threads(), 2); + assert_eq!(runtime.proof_account_worker_pool().current_num_threads(), 2); + assert_eq!(runtime.prewarming_pool().current_num_threads(), 2); } }