Skip to content

Commit

Permalink
Consolidate together Bevy's TaskPools
Browse files Browse the repository at this point in the history
  • Loading branch information
james7132 committed Feb 24, 2024
1 parent 9d420b4 commit d6cbbbb
Show file tree
Hide file tree
Showing 19 changed files with 142 additions and 266 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5027,7 +5027,7 @@ current changes on git with [previous release tags][git_tag_comparison].
- [Fix confusing near and far fields in Camera][4457]
- [Allow minimising window if using a 2d camera][4527]
- [WGSL: use correct syntax for matrix access][5039]
- [Gltf: do not import `IoTaskPool` in wasm][5038]
- [Gltf: do not import `ComputeTaskPool` in wasm][5038]
- [Fix skinned mesh normal handling in mesh shader][5095]
- [Don't panic when `StandardMaterial` `normal_map` hasn't loaded yet][5307]
- [Fix incorrect rotation in `Transform::rotate_around`][5300]
Expand Down
8 changes: 4 additions & 4 deletions crates/bevy_asset/src/processor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::{
};
use bevy_ecs::prelude::*;
use bevy_log::{debug, error, trace, warn};
use bevy_tasks::IoTaskPool;
use bevy_tasks::ComputeTaskPool;
use bevy_utils::{BoxedFuture, HashMap, HashSet};
use futures_io::ErrorKind;
use futures_lite::{AsyncReadExt, AsyncWriteExt, StreamExt};
Expand Down Expand Up @@ -165,7 +165,7 @@ impl AssetProcessor {
pub fn process_assets(&self) {
let start_time = std::time::Instant::now();
debug!("Processing Assets");
IoTaskPool::get().scope(|scope| {
ComputeTaskPool::get().scope(|scope| {
scope.spawn(async move {
self.initialize().await.unwrap();
for source in self.sources().iter_processed() {
Expand Down Expand Up @@ -315,7 +315,7 @@ impl AssetProcessor {
#[cfg(any(target_arch = "wasm32", not(feature = "multi-threaded")))]
error!("AddFolder event cannot be handled in single threaded mode (or WASM) yet.");
#[cfg(all(not(target_arch = "wasm32"), feature = "multi-threaded"))]
IoTaskPool::get().scope(|scope| {
ComputeTaskPool::get().scope(|scope| {
scope.spawn(async move {
self.process_assets_internal(scope, source, path)
.await
Expand Down Expand Up @@ -457,7 +457,7 @@ impl AssetProcessor {
loop {
let mut check_reprocess_queue =
std::mem::take(&mut self.data.asset_infos.write().await.check_reprocess_queue);
IoTaskPool::get().scope(|scope| {
ComputeTaskPool::get().scope(|scope| {
for path in check_reprocess_queue.drain(..) {
let processor = self.clone();
let source = self.get_source(path.source()).unwrap();
Expand Down
4 changes: 2 additions & 2 deletions crates/bevy_asset/src/server/loaders.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{
};
use async_broadcast::RecvError;
use bevy_log::{error, warn};
use bevy_tasks::IoTaskPool;
use bevy_tasks::ComputeTaskPool;
use bevy_utils::{HashMap, TypeIdMap};
use std::{any::TypeId, sync::Arc};
use thiserror::Error;
Expand Down Expand Up @@ -78,7 +78,7 @@ impl AssetLoaders {
match maybe_loader {
MaybeAssetLoader::Ready(_) => unreachable!(),
MaybeAssetLoader::Pending { sender, .. } => {
IoTaskPool::get()
ComputeTaskPool::get()
.spawn(async move {
let _ = sender.broadcast(loader).await;
})
Expand Down
10 changes: 5 additions & 5 deletions crates/bevy_asset/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use crate::{
};
use bevy_ecs::prelude::*;
use bevy_log::{error, info};
use bevy_tasks::IoTaskPool;
use bevy_tasks::ComputeTaskPool;
use bevy_utils::{CowArc, HashSet};
use crossbeam_channel::{Receiver, Sender};
use futures_lite::StreamExt;
Expand Down Expand Up @@ -296,7 +296,7 @@ impl AssetServer {
if should_load {
let owned_handle = Some(handle.clone().untyped());
let server = self.clone();
IoTaskPool::get()
ComputeTaskPool::get()
.spawn(async move {
if let Err(err) = server.load_internal(owned_handle, path, false, None).await {
error!("{}", err);
Expand Down Expand Up @@ -366,7 +366,7 @@ impl AssetServer {
let id = handle.id().untyped();

let server = self.clone();
IoTaskPool::get()
ComputeTaskPool::get()
.spawn(async move {
let path_clone = path.clone();
match server.load_untyped_async(path).await {
Expand Down Expand Up @@ -551,7 +551,7 @@ impl AssetServer {
pub fn reload<'a>(&self, path: impl Into<AssetPath<'a>>) {
let server = self.clone();
let path = path.into().into_owned();
IoTaskPool::get()
ComputeTaskPool::get()
.spawn(async move {
let mut reloaded = false;

Expand Down Expand Up @@ -690,7 +690,7 @@ impl AssetServer {

let path = path.into_owned();
let server = self.clone();
IoTaskPool::get()
ComputeTaskPool::get()
.spawn(async move {
let Ok(source) = server.get_source(path.source()) else {
error!(
Expand Down
21 changes: 2 additions & 19 deletions crates/bevy_core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,7 @@ fn register_math_types(app: &mut App) {
.register_type::<Vec<bevy_math::Vec3>>();
}

/// Setup of default task pools: [`AsyncComputeTaskPool`](bevy_tasks::AsyncComputeTaskPool),
/// [`ComputeTaskPool`](bevy_tasks::ComputeTaskPool), [`IoTaskPool`](bevy_tasks::IoTaskPool).
/// Setup of default task pool: [`ComputeTaskPool`](bevy_tasks::ComputeTaskPool).
#[derive(Default)]
pub struct TaskPoolPlugin {
/// Options for the [`TaskPool`](bevy_tasks::TaskPool) created at application start.
Expand Down Expand Up @@ -175,39 +174,23 @@ pub fn update_frame_count(mut frame_count: ResMut<FrameCount>) {
#[cfg(test)]
mod tests {
use super::*;
use bevy_tasks::prelude::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool};
use bevy_tasks::prelude::ComputeTaskPool;

// Checks that a task queued with `spawn_local` on each task pool
// (`AsyncComputeTaskPool`, `ComputeTaskPool`, `IoTaskPool`) has already run
// by the time `app.run()` returns.
#[test]
fn runs_spawn_local_tasks() {
// Build an app with the plugins under test; the pools are fetched via `get()` below.
let mut app = App::new();
app.add_plugins((TaskPoolPlugin::default(), TypeRegistrationPlugin));

// Each task reports completion by sending `()` over its own unbounded channel.
// The task handle is detached, so the test never awaits it directly.
let (async_tx, async_rx) = crossbeam_channel::unbounded();
AsyncComputeTaskPool::get()
.spawn_local(async move {
async_tx.send(()).unwrap();
})
.detach();

// Same pattern for the compute pool.
let (compute_tx, compute_rx) = crossbeam_channel::unbounded();
ComputeTaskPool::get()
.spawn_local(async move {
compute_tx.send(()).unwrap();
})
.detach();

// Same pattern for the IO pool.
let (io_tx, io_rx) = crossbeam_channel::unbounded();
IoTaskPool::get()
.spawn_local(async move {
io_tx.send(()).unwrap();
})
.detach();

// Running the app is what is expected to drive the locally spawned tasks.
app.run();

// `try_recv` does not wait: each `unwrap` fails the test if the corresponding
// task has not sent its message by now.
async_rx.try_recv().unwrap();
compute_rx.try_recv().unwrap();
io_rx.try_recv().unwrap();
}

#[test]
Expand Down
117 changes: 7 additions & 110 deletions crates/bevy_core/src/task_pool_options.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,6 @@
use bevy_tasks::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool, TaskPoolBuilder};
use bevy_tasks::{ComputeTaskPool, TaskPoolBuilder};
use bevy_utils::tracing::trace;

/// Defines a simple way to determine how many threads to use given the number of remaining cores
/// and number of total cores
#[derive(Clone, Debug)]
pub struct TaskPoolThreadAssignmentPolicy {
    /// Force using at least this many threads
    pub min_threads: usize,
    /// Under no circumstance use more than this many threads for this pool
    pub max_threads: usize,
    /// Target using this percentage of total cores, clamped by min_threads and max_threads. It is
    /// permitted to use 1.0 to try to use all remaining threads
    pub percent: f32,
}

impl TaskPoolThreadAssignmentPolicy {
    /// Determine the number of threads to use for this task pool.
    ///
    /// The target is `total_threads * percent` (rounded), limited to `remaining_threads`, then
    /// raised to at least `min_threads` and finally capped at `max_threads`.
    ///
    /// If a policy is misconfigured with `min_threads > max_threads`, `max_threads` wins, in
    /// line with its "under no circumstance" contract, instead of panicking.
    ///
    /// # Panics
    ///
    /// Panics if `percent` is negative or NaN.
    fn get_number_of_threads(&self, remaining_threads: usize, total_threads: usize) -> usize {
        assert!(
            self.percent >= 0.0,
            "TaskPoolThreadAssignmentPolicy::percent must be non-negative, got {}",
            self.percent
        );

        // Float-to-int `as` casts saturate, so a huge `percent` cannot overflow here.
        let target = (total_threads as f32 * self.percent).round() as usize;

        // Limit ourselves to the number of cores available.
        let available = target.min(remaining_threads);

        // Bound by min_threads, then max_threads. (This may result in us using more threads than
        // are available; this is intended. An example case where this might happen is a device
        // with <= 2 threads.)
        // `usize::clamp` is avoided because it panics when min_threads > max_threads.
        available.max(self.min_threads).min(self.max_threads)
    }
}

/// Helper for configuring and creating the default task pools. For end-users who want full control,
/// set up [`TaskPoolPlugin`](super::TaskPoolPlugin)
#[derive(Clone, Debug)]
Expand All @@ -40,13 +11,6 @@ pub struct TaskPoolOptions {
/// If the number of physical cores is greater than max_total_threads, force using
/// max_total_threads
pub max_total_threads: usize,

/// Used to determine number of IO threads to allocate
pub io: TaskPoolThreadAssignmentPolicy,
/// Used to determine number of async compute threads to allocate
pub async_compute: TaskPoolThreadAssignmentPolicy,
/// Used to determine number of compute threads to allocate
pub compute: TaskPoolThreadAssignmentPolicy,
}

impl Default for TaskPoolOptions {
Expand All @@ -55,27 +19,6 @@ impl Default for TaskPoolOptions {
// By default, use however many cores are available on the system
min_total_threads: 1,
max_total_threads: usize::MAX,

// Use 25% of cores for IO, at least 1, no more than 4
io: TaskPoolThreadAssignmentPolicy {
min_threads: 1,
max_threads: 4,
percent: 0.25,
},

// Use 25% of cores for async compute, at least 1, no more than 4
async_compute: TaskPoolThreadAssignmentPolicy {
min_threads: 1,
max_threads: 4,
percent: 0.25,
},

// Use all remaining cores for compute (at least 1)
compute: TaskPoolThreadAssignmentPolicy {
min_threads: 1,
max_threads: usize::MAX,
percent: 1.0, // This 1.0 here means "whatever is left over"
},
}
}
}
Expand All @@ -96,57 +39,11 @@ impl TaskPoolOptions {
.clamp(self.min_total_threads, self.max_total_threads);
trace!("Assigning {} cores to default task pools", total_threads);

let mut remaining_threads = total_threads;

{
// Determine the number of IO threads we will use
let io_threads = self
.io
.get_number_of_threads(remaining_threads, total_threads);

trace!("IO Threads: {}", io_threads);
remaining_threads = remaining_threads.saturating_sub(io_threads);

IoTaskPool::get_or_init(|| {
TaskPoolBuilder::default()
.num_threads(io_threads)
.thread_name("IO Task Pool".to_string())
.build()
});
}

{
// Determine the number of async compute threads we will use
let async_compute_threads = self
.async_compute
.get_number_of_threads(remaining_threads, total_threads);

trace!("Async Compute Threads: {}", async_compute_threads);
remaining_threads = remaining_threads.saturating_sub(async_compute_threads);

AsyncComputeTaskPool::get_or_init(|| {
TaskPoolBuilder::default()
.num_threads(async_compute_threads)
.thread_name("Async Compute Task Pool".to_string())
.build()
});
}

{
// Determine the number of compute threads we will use
// This is intentionally last so that an end user can specify 1.0 as the percent
let compute_threads = self
.compute
.get_number_of_threads(remaining_threads, total_threads);

trace!("Compute Threads: {}", compute_threads);

ComputeTaskPool::get_or_init(|| {
TaskPoolBuilder::default()
.num_threads(compute_threads)
.thread_name("Compute Task Pool".to_string())
.build()
});
}
ComputeTaskPool::get_or_init(|| {
TaskPoolBuilder::default()
.num_threads(total_threads)
.thread_name("Compute Task Pool".to_string())
.build()
});
}
}
4 changes: 2 additions & 2 deletions crates/bevy_gltf/src/loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ use bevy_render::{
};
use bevy_scene::Scene;
#[cfg(not(target_arch = "wasm32"))]
use bevy_tasks::IoTaskPool;
use bevy_tasks::ComputeTaskPool;
use bevy_transform::components::Transform;
use bevy_utils::{
smallvec::{smallvec, SmallVec},
Expand Down Expand Up @@ -348,7 +348,7 @@ async fn load_gltf<'a, 'b, 'c>(
}
} else {
#[cfg(not(target_arch = "wasm32"))]
IoTaskPool::get()
ComputeTaskPool::get()
.scope(|scope| {
gltf.textures().for_each(|gltf_texture| {
let parent_path = load_context.path().parent().unwrap();
Expand Down
1 change: 0 additions & 1 deletion crates/bevy_pbr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ bevy_derive = { path = "../bevy_derive", version = "0.14.0-dev" }

# other
bitflags = "2.3"
fixedbitset = "0.4"
# direct dependency required for derive macro
bytemuck = { version = "1", features = ["derive"] }
radsort = "0.1"
Expand Down
22 changes: 18 additions & 4 deletions crates/bevy_pbr/src/render/mesh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use bevy_core_pipeline::{
deferred::{AlphaMask3dDeferred, Opaque3dDeferred},
};
use bevy_derive::{Deref, DerefMut};
use bevy_ecs::entity::EntityHashMap;
use bevy_ecs::entity::{EntityHashMap, EntityHasher};
use bevy_ecs::{
prelude::*,
query::ROQueryItem,
Expand All @@ -36,7 +36,10 @@ use bevy_render::{
};
use bevy_transform::components::GlobalTransform;
use bevy_utils::{tracing::error, Entry, HashMap, Hashed};
use std::cell::Cell;
use std::{
cell::Cell,
hash::{Hash, Hasher},
};
use thread_local::ThreadLocal;

#[cfg(debug_assertions)]
Expand Down Expand Up @@ -268,7 +271,7 @@ pub struct RenderMeshInstances(EntityHashMap<RenderMeshInstance>);

pub fn extract_meshes(
mut render_mesh_instances: ResMut<RenderMeshInstances>,
mut thread_local_queues: Local<ThreadLocal<Cell<Vec<(Entity, RenderMeshInstance)>>>>,
mut thread_local_queues: Local<ThreadLocal<Cell<Vec<(u64, Entity, RenderMeshInstance)>>>>,
meshes_query: Extract<
Query<(
Entity,
Expand Down Expand Up @@ -316,9 +319,12 @@ pub fn extract_meshes(
previous_transform: (&previous_transform).into(),
flags: flags.bits(),
};
let mut hasher = EntityHasher::default();
entity.hash(&mut hasher);
let tls = thread_local_queues.get_or_default();
let mut queue = tls.take();
queue.push((
hasher.finish(),
entity,
RenderMeshInstance {
mesh_asset_id: handle.id(),
Expand All @@ -332,9 +338,17 @@ pub fn extract_meshes(
},
);

let render_mesh_instances = render_mesh_instances.bypass_change_detection();
render_mesh_instances.clear();
for queue in thread_local_queues.iter_mut() {
render_mesh_instances.extend(queue.get_mut().drain(..));
let queue = queue.get_mut();
render_mesh_instances.reserve(queue.len());
for (hash, entity, instance) in queue.drain(..) {
render_mesh_instances
.raw_entry_mut()
.from_key_hashed_nocheck(hash, &entity)
.insert(entity, instance);
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion crates/bevy_render/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ impl Plugin for RenderPlugin {
};
// In wasm, spawn a task and detach it for execution
#[cfg(target_arch = "wasm32")]
bevy_tasks::IoTaskPool::get()
bevy_tasks::ComputeTaskPool::get()
.spawn_local(async_renderer)
.detach();
// Otherwise, just block for it to complete
Expand Down
Loading

0 comments on commit d6cbbbb

Please sign in to comment.