diff --git a/crates/bevy_tasks/src/task_pool.rs b/crates/bevy_tasks/src/task_pool.rs index 2ef244407af77..099f96e93d006 100644 --- a/crates/bevy_tasks/src/task_pool.rs +++ b/crates/bevy_tasks/src/task_pool.rs @@ -2,11 +2,11 @@ use std::{ future::Future, marker::PhantomData, mem, - pin::Pin, sync::Arc, thread::{self, JoinHandle}, }; +use async_task::FallibleTask; use concurrent_queue::ConcurrentQueue; use futures_lite::{future, pin, FutureExt}; @@ -248,8 +248,8 @@ impl TaskPool { let task_scope_executor = &async_executor::Executor::default(); let task_scope_executor: &'env async_executor::Executor = unsafe { mem::transmute(task_scope_executor) }; - let spawned: ConcurrentQueue> = ConcurrentQueue::unbounded(); - let spawned_ref: &'env ConcurrentQueue> = + let spawned: ConcurrentQueue> = ConcurrentQueue::unbounded(); + let spawned_ref: &'env ConcurrentQueue> = unsafe { mem::transmute(&spawned) }; let scope = Scope { @@ -267,10 +267,10 @@ impl TaskPool { if spawned.is_empty() { Vec::new() } else { - let get_results = async move { - let mut results = Vec::with_capacity(spawned.len()); - while let Ok(task) = spawned.pop() { - results.push(task.await); + let get_results = async { + let mut results = Vec::with_capacity(spawned_ref.len()); + while let Ok(task) = spawned_ref.pop() { + results.push(task.await.unwrap()); } results @@ -279,23 +279,8 @@ impl TaskPool { // Pin the futures on the stack. pin!(get_results); - // SAFETY: This function blocks until all futures complete, so we do not read/write - // the data from futures outside of the 'scope lifetime. However, - // rust has no way of knowing this so we must convert to 'static - // here to appease the compiler as it is unable to validate safety. - let get_results: Pin<&mut (dyn Future> + 'static + Send)> = get_results; - let get_results: Pin<&'static mut (dyn Future> + 'static + Send)> = - unsafe { mem::transmute(get_results) }; - - // The thread that calls scope() will participate in driving tasks in the pool - // forward until the tasks that are spawned by this scope() call - // complete. (If the caller of scope() happens to be a thread in - // this thread pool, and we only have one thread in the pool, then - // simply calling future::block_on(spawned) would deadlock.) - let mut spawned = task_scope_executor.spawn(get_results); - loop { - if let Some(result) = future::block_on(future::poll_once(&mut spawned)) { + if let Some(result) = future::block_on(future::poll_once(&mut get_results)) { break result; }; @@ -378,7 +363,7 @@ impl Drop for TaskPool { pub struct Scope<'scope, 'env: 'scope, T> { executor: &'scope async_executor::Executor<'scope>, task_scope_executor: &'scope async_executor::Executor<'scope>, - spawned: &'scope ConcurrentQueue>, + spawned: &'scope ConcurrentQueue>, // make `Scope` invariant over 'scope and 'env scope: PhantomData<&'scope mut &'scope ()>, env: PhantomData<&'env mut &'env ()>, @@ -394,7 +379,7 @@ impl<'scope, 'env, T: Send + 'scope> Scope<'scope, 'env, T> { /// /// For more information, see [`TaskPool::scope`]. pub fn spawn + 'scope + Send>(&self, f: Fut) { - let task = self.executor.spawn(f); + let task = self.executor.spawn(f).fallible(); // ConcurrentQueue only errors when closed or full, but we never // close and use an unbouded queue, so it is safe to unwrap self.spawned.push(task).unwrap(); @@ -407,13 +392,26 @@ impl<'scope, 'env, T: Send + 'scope> Scope<'scope, 'env, T> { /// /// For more information, see [`TaskPool::scope`]. pub fn spawn_on_scope + 'scope + Send>(&self, f: Fut) { - let task = self.task_scope_executor.spawn(f); + let task = self.task_scope_executor.spawn(f).fallible(); // ConcurrentQueue only errors when closed or full, but we never // close and use an unbouded queue, so it is safe to unwrap self.spawned.push(task).unwrap(); } } +impl<'scope, 'env, T> Drop for Scope<'scope, 'env, T> +where + T: 'scope, +{ + fn drop(&mut self) { + future::block_on(async { + while let Ok(task) = self.spawned.pop() { + task.cancel().await; + } + }); + } +} + #[cfg(test)] #[allow(clippy::disallowed_types)] mod tests {