Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize TaskPools for use in static variables. #12990

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion benches/benches/bevy_ecs/iteration/heavy_compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pub fn heavy_compute(c: &mut Criterion) {
group.warm_up_time(std::time::Duration::from_millis(500));
group.measurement_time(std::time::Duration::from_secs(4));
group.bench_function("base", |b| {
ComputeTaskPool::get_or_init(TaskPool::default);
ComputeTaskPool::get_or_default();

let mut world = World::default();

Expand Down
2 changes: 1 addition & 1 deletion benches/benches/bevy_ecs/iteration/par_iter_simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ fn insert_if_bit_enabled<const B: u16>(entity: &mut EntityWorldMut, i: u16) {

impl<'w> Benchmark<'w> {
pub fn new(fragment: u16) -> Self {
ComputeTaskPool::get_or_init(TaskPool::default);
ComputeTaskPool::get_or_default();

let mut world = World::new();

Expand Down
2 changes: 1 addition & 1 deletion crates/bevy_asset/src/processor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ impl AssetProcessor {
#[cfg(all(not(target_arch = "wasm32"), feature = "multi-threaded"))]
async fn process_assets_internal<'scope>(
&'scope self,
scope: &'scope bevy_tasks::Scope<'scope, '_, ()>,
scope: &'scope bevy_tasks::StaticScope<'scope, '_, ()>,
source: &'scope AssetSource,
path: PathBuf,
) -> Result<(), AssetReaderError> {
Expand Down
60 changes: 40 additions & 20 deletions crates/bevy_core/src/task_pool_options.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use bevy_tasks::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool, TaskPoolBuilder};
use bevy_utils::tracing::trace;
use bevy_tasks::{
AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool, TaskPoolBuilder, TaskPoolInitializationError,
};
use bevy_utils::tracing::{trace, warn};

/// Defines a simple way to determine how many threads to use given the number of remaining cores
/// and number of total cores
Expand Down Expand Up @@ -80,6 +82,18 @@ impl Default for TaskPoolOptions {
}
}

/// Reports the outcome of initializing the task pool called `name`.
///
/// * `Ok(())` is ignored.
/// * [`TaskPoolInitializationError::AlreadyInitialized`] is logged as a warning and
///   otherwise ignored.
/// * Any other error is treated as fatal.
///
/// # Panics
///
/// Panics if `res` holds an error other than
/// [`TaskPoolInitializationError::AlreadyInitialized`].
fn handle_initialization_error(name: &str, res: Result<(), TaskPoolInitializationError>) {
    match res {
        Ok(()) => {}
        Err(TaskPoolInitializationError::AlreadyInitialized) => {
            warn!("{} already initialized.", name);
        }
        Err(err) => {
            // Name the pool so the failing one can be identified from the panic message alone.
            panic!("Error while initializing {}: {}", name, err);
        }
    }
}

impl TaskPoolOptions {
/// Create a configuration that forces using the given number of threads.
pub fn with_num_threads(thread_count: usize) -> Self {
Expand Down Expand Up @@ -107,12 +121,14 @@ impl TaskPoolOptions {
trace!("IO Threads: {}", io_threads);
remaining_threads = remaining_threads.saturating_sub(io_threads);

IoTaskPool::get_or_init(|| {
TaskPoolBuilder::default()
.num_threads(io_threads)
.thread_name("IO Task Pool".to_string())
.build()
});
handle_initialization_error(
"IO Task Pool",
IoTaskPool::get().init(
TaskPoolBuilder::default()
.num_threads(io_threads)
.thread_name("IO Task Pool".to_string()),
),
);
}

{
Expand All @@ -124,12 +140,14 @@ impl TaskPoolOptions {
trace!("Async Compute Threads: {}", async_compute_threads);
remaining_threads = remaining_threads.saturating_sub(async_compute_threads);

AsyncComputeTaskPool::get_or_init(|| {
TaskPoolBuilder::default()
.num_threads(async_compute_threads)
.thread_name("Async Compute Task Pool".to_string())
.build()
});
// The label matches the pool's thread name so log output refers to the pool consistently.
handle_initialization_error(
"Async Compute Task Pool",
AsyncComputeTaskPool::get().init(
TaskPoolBuilder::default()
.num_threads(async_compute_threads)
.thread_name("Async Compute Task Pool".to_string()),
),
);
}

{
Expand All @@ -141,12 +159,14 @@ impl TaskPoolOptions {

trace!("Compute Threads: {}", compute_threads);

ComputeTaskPool::get_or_init(|| {
TaskPoolBuilder::default()
.num_threads(compute_threads)
.thread_name("Compute Task Pool".to_string())
.build()
});
handle_initialization_error(
"Compute Task Pool",
ComputeTaskPool::get().init(
TaskPoolBuilder::default()
.num_threads(compute_threads)
.thread_name("Compute Task Pool".to_string()),
),
);
}
}
}
6 changes: 3 additions & 3 deletions crates/bevy_ecs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ mod tests {
system::Resource,
world::{EntityRef, Mut, World},
};
use bevy_tasks::{ComputeTaskPool, TaskPool};
use bevy_tasks::ComputeTaskPool;
use std::num::NonZeroU32;
use std::{
any::TypeId,
Expand Down Expand Up @@ -405,7 +405,7 @@ mod tests {

#[test]
fn par_for_each_dense() {
ComputeTaskPool::get_or_init(TaskPool::default);
ComputeTaskPool::get_or_default();
let mut world = World::new();
let e1 = world.spawn(A(1)).id();
let e2 = world.spawn(A(2)).id();
Expand All @@ -428,7 +428,7 @@ mod tests {

#[test]
fn par_for_each_sparse() {
ComputeTaskPool::get_or_init(TaskPool::default);
ComputeTaskPool::get_or_default();
let mut world = World::new();
let e1 = world.spawn(SparseStored(1)).id();
let e2 = world.spawn(SparseStored(2)).id();
Expand Down
2 changes: 1 addition & 1 deletion crates/bevy_ecs/src/query/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1312,7 +1312,7 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> {
/// #[derive(Component, PartialEq, Debug)]
/// struct A(usize);
///
/// # bevy_tasks::ComputeTaskPool::get_or_init(|| bevy_tasks::TaskPool::new());
/// # bevy_tasks::ComputeTaskPool::get_or_default();
///
/// let mut world = World::new();
///
Expand Down
20 changes: 8 additions & 12 deletions crates/bevy_ecs/src/schedule/executor/multi_threaded.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{
sync::{Arc, Mutex, MutexGuard},
};

use bevy_tasks::{ComputeTaskPool, Scope, TaskPool, ThreadExecutor};
use bevy_tasks::{ComputeTaskPool, StaticScope, TaskPool, ThreadExecutor};
use bevy_utils::default;
use bevy_utils::syncunsafecell::SyncUnsafeCell;
#[cfg(feature = "trace")]
Expand Down Expand Up @@ -132,7 +132,7 @@ pub struct ExecutorState {
#[derive(Copy, Clone)]
struct Context<'scope, 'env, 'sys> {
environment: &'env Environment<'env, 'sys>,
scope: &'scope Scope<'scope, 'env, ()>,
scope: &'scope StaticScope<'scope, 'env, ()>,
}

impl Default for MultiThreadedExecutor {
Expand Down Expand Up @@ -218,17 +218,13 @@ impl SystemExecutor for MultiThreadedExecutor {

let environment = &Environment::new(self, schedule, world);

ComputeTaskPool::get_or_init(TaskPool::default).scope_with_executor(
false,
thread_executor,
|scope| {
let context = Context { environment, scope };
ComputeTaskPool::get().scope_with_executor(false, thread_executor, |scope| {
let context = Context { environment, scope };

// The first tick won't need to process finished systems, but we still need to run the loop in
// tick_executor() in case a system completes while the first tick still holds the mutex.
context.tick_executor();
},
);
// The first tick won't need to process finished systems, but we still need to run the loop in
// tick_executor() in case a system completes while the first tick still holds the mutex.
context.tick_executor();
});

// End the borrows of self and world in environment by copying out the reference to systems.
let systems = environment.systems;
Expand Down
4 changes: 2 additions & 2 deletions crates/bevy_ecs/src/schedule/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,12 @@ mod tests {
#[test]
#[cfg(not(miri))]
fn parallel_execution() {
use bevy_tasks::{ComputeTaskPool, TaskPool};
use bevy_tasks::ComputeTaskPool;
use std::sync::{Arc, Barrier};

let mut world = World::default();
let mut schedule = Schedule::default();
let thread_count = ComputeTaskPool::get_or_init(TaskPool::default).thread_num();
let thread_count = ComputeTaskPool::get_or_default().thread_num();

let barrier = Arc::new(Barrier::new(thread_count));

Expand Down
7 changes: 6 additions & 1 deletion crates/bevy_tasks/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,19 @@ keywords = ["bevy"]

[features]
multi-threaded = ["dep:async-channel", "dep:async-task", "dep:concurrent-queue"]
trace = ["tracing"]

[dependencies]
futures-lite = "2.0.1"
async-executor = "1.7.2"
async-executor = { git = "https://github.com/james7132/async-executor", branch = "leaked-executor", features = [
"static",
] }
async-channel = { version = "2.2.0", optional = true }
async-io = { version = "2.0.0", optional = true }
async-task = { version = "4.2.0", optional = true }
concurrent-queue = { version = "2.0.0", optional = true }
tracing = { version = "0.1", optional = true }
thiserror = "1.0"

[target.'cfg(target_arch = "wasm32")'.dependencies]
wasm-bindgen-futures = "0.4"
Expand Down
23 changes: 22 additions & 1 deletion crates/bevy_tasks/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,21 @@ pub use slice::{ParallelSlice, ParallelSliceMut};
mod task;
pub use task::Task;

#[cfg(all(not(target_arch = "wasm32"), feature = "multi-threaded"))]
mod static_task_pool;
#[cfg(all(not(target_arch = "wasm32"), feature = "multi-threaded"))]
mod task_pool;
#[cfg(all(not(target_arch = "wasm32"), feature = "multi-threaded"))]
pub use static_task_pool::{StaticScope, StaticTaskPool};
#[cfg(all(not(target_arch = "wasm32"), feature = "multi-threaded"))]
pub use task_pool::{Scope, TaskPool, TaskPoolBuilder};

#[cfg(any(target_arch = "wasm32", not(feature = "multi-threaded")))]
mod single_threaded_task_pool;
#[cfg(any(target_arch = "wasm32", not(feature = "multi-threaded")))]
pub use single_threaded_task_pool::{FakeTask, Scope, TaskPool, TaskPoolBuilder, ThreadExecutor};
pub use single_threaded_task_pool::{
FakeTask, Scope, StaticScope, StaticTaskPool, TaskPool, TaskPoolBuilder, ThreadExecutor,
};

mod usages;
#[cfg(not(target_arch = "wasm32"))]
Expand All @@ -41,6 +47,7 @@ mod iter;
pub use iter::ParallelIterator;

pub use futures_lite;
use thiserror::Error;

#[allow(missing_docs)]
pub mod prelude {
Expand All @@ -55,6 +62,20 @@ pub mod prelude {

use std::num::NonZeroUsize;

/// Potential errors when initializing a [`StaticTaskPool`].
///
/// Each variant's `#[error(...)]` attribute supplies its `Display` message (derived via
/// `thiserror::Error`).
#[derive(Error, Debug)]
pub enum TaskPoolInitializationError {
/// The task pool was already initialized and cannot be changed after initialization.
#[error("The task pool is already initialized.")]
AlreadyInitialized,
/// The task pool would have been initialized with zero threads.
#[error("The task pool would have been initialized with zero threads.")]
ZeroThreads,
/// Initialization failed to spawn a thread.
///
/// Wraps the underlying [`std::io::Error`]; the `#[from]` attribute lets `?` convert an
/// `io::Error` into this variant.
#[error("Failed to spawn thread: {0:?}")]
ThreadSpawnError(#[from] std::io::Error),
}

/// Gets the logical CPU core count available to the current process.
///
/// This is identical to [`std::thread::available_parallelism`], except
Expand Down
32 changes: 27 additions & 5 deletions crates/bevy_tasks/src/single_threaded_task_pool.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
use crate::TaskPoolInitializationError;
use std::sync::Arc;
use std::{cell::RefCell, future::Future, marker::PhantomData, mem, rc::Rc};

thread_local! {
static LOCAL_EXECUTOR: async_executor::LocalExecutor<'static> = async_executor::LocalExecutor::new();
}

/// A [`TaskPool`] optimized for use in static variables.
///
/// In this module it is a plain alias of [`TaskPool`].
pub type StaticTaskPool = TaskPool;

/// A [`StaticTaskPool`] scope for running one or more non-`'static` futures.
///
/// In this module it is a plain alias of [`Scope`].
///
/// For more information, see [`TaskPool::scope`].
pub type StaticScope<'scope, 'env, T> = Scope<'scope, 'env, T>;

/// Used to create a [`TaskPool`].
#[derive(Debug, Default, Clone)]
pub struct TaskPoolBuilder {}
Expand All @@ -25,8 +34,8 @@ impl<'a> ThreadExecutor<'a> {

impl TaskPoolBuilder {
/// Creates a new `TaskPoolBuilder` instance
pub fn new() -> Self {
Self::default()
pub const fn new() -> Self {
Self {}
}

/// No op on the single threaded task pool
Expand All @@ -45,7 +54,7 @@ impl TaskPoolBuilder {
}

/// Creates a new [`TaskPool`]
pub fn build(self) -> TaskPool {
pub const fn build(self) -> TaskPool {
TaskPool::new_internal()
}
}
Expand All @@ -62,15 +71,28 @@ impl TaskPool {
}

/// Create a `TaskPool` with the default configuration.
pub fn new() -> Self {
pub const fn new() -> Self {
TaskPoolBuilder::new().build()
}

#[allow(unused_variables)]
fn new_internal() -> Self {
const fn new_internal() -> Self {
Self {}
}

/// Checks if the threads in the task pool have been started or not.
///
/// This always returns `true` in single threaded builds.
pub fn is_initialized(&self) -> bool {
true
}

/// Initializes the task pool with the provided builder.
///
/// This is always a no-op in single threaded builds: `builder` is ignored and `Ok(())`
/// is returned unconditionally.
// `builder` is intentionally unused here, hence the lint allowance.
#[allow(unused_variables)]
pub fn init(&self, builder: TaskPoolBuilder) -> Result<(), TaskPoolInitializationError> {
Ok(())
}

/// Return the number of threads owned by the task pool
pub fn thread_num(&self) -> usize {
1
Expand Down
Loading
Loading