diff --git a/crates/bevy_tasks/Cargo.toml b/crates/bevy_tasks/Cargo.toml index fd16e8b3e7de1..a476ca78e2776 100644 --- a/crates/bevy_tasks/Cargo.toml +++ b/crates/bevy_tasks/Cargo.toml @@ -54,8 +54,8 @@ crossbeam-queue = { version = "0.3", default-features = false, features = [ ] } [target.'cfg(target_arch = "wasm32")'.dependencies] -pin-project = { version = "1" } -futures-channel = { version = "0.3", default-features = false } +pin-project = "1" +async-channel = "2.3.0" [target.'cfg(not(all(target_has_atomic = "8", target_has_atomic = "16", target_has_atomic = "32", target_has_atomic = "64", target_has_atomic = "ptr")))'.dependencies] async-task = { version = "4.4.0", default-features = false, features = [ @@ -72,6 +72,7 @@ atomic-waker = { version = "1", default-features = false, features = [ futures-lite = { version = "2.0.1", default-features = false, features = [ "std", ] } +async-channel = "2.3.0" [lints] workspace = true diff --git a/crates/bevy_tasks/src/lib.rs b/crates/bevy_tasks/src/lib.rs index 1aa4598243a0c..ddb014bb9867b 100644 --- a/crates/bevy_tasks/src/lib.rs +++ b/crates/bevy_tasks/src/lib.rs @@ -71,7 +71,13 @@ use alloc::boxed::Box; /// An owned and dynamically typed Future used when you can't statically type your result or need to add some indirection. pub type BoxedFuture<'a, T> = core::pin::Pin + 'a>>; +// Modules +mod executor; pub mod futures; +mod iter; +mod slice; +mod task; +mod usages; cfg::async_executor! { if {} else { @@ -79,24 +85,21 @@ cfg::async_executor! { } } -mod executor; - -mod slice; +// Exports +pub use iter::ParallelIterator; pub use slice::{ParallelSlice, ParallelSliceMut}; +pub use task::Task; +pub use usages::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool}; -cfg::web! { - if { - #[path = "wasm_task.rs"] - mod task; - } else { - mod task; +pub use futures_lite; +pub use futures_lite::future::poll_once; +cfg::web! { + if {} else { pub use usages::tick_global_task_pools_on_main_thread; } } -pub use task::Task; - cfg::multi_threaded! { if { mod task_pool; @@ -111,10 +114,6 @@ cfg::multi_threaded! { } } -mod usages; -pub use futures_lite::future::poll_once; -pub use usages::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool}; - cfg::switch! { cfg::async_io => { pub use async_io::block_on; @@ -147,11 +146,6 @@ cfg::switch! { } } -mod iter; -pub use iter::ParallelIterator; - -pub use futures_lite; - /// The tasks prelude. /// /// This includes the most common types in this crate, re-exported for your convenience. diff --git a/crates/bevy_tasks/src/single_threaded_task_pool.rs b/crates/bevy_tasks/src/single_threaded_task_pool.rs index d7f74c2d66c86..d81e43b4e91b9 100644 --- a/crates/bevy_tasks/src/single_threaded_task_pool.rs +++ b/crates/bevy_tasks/src/single_threaded_task_pool.rs @@ -1,26 +1,25 @@ use alloc::{string::String, vec::Vec}; use bevy_platform::sync::Arc; -use core::{cell::RefCell, future::Future, marker::PhantomData, mem}; +use core::{cell::{RefCell, Cell}, future::Future, marker::PhantomData, mem}; -use crate::Task; +use crate::executor::LocalExecutor; +use crate::{block_on, Task}; crate::cfg::std! { if { use std::thread_local; - use crate::executor::LocalExecutor; + + use crate::executor::LocalExecutor as Executor; thread_local! { - static LOCAL_EXECUTOR: LocalExecutor<'static> = const { LocalExecutor::new() }; + static LOCAL_EXECUTOR: Executor<'static> = const { Executor::new() }; } - - type ScopeResult = alloc::rc::Rc>>; } else { - use bevy_platform::sync::{Mutex, PoisonError}; - use crate::executor::Executor as LocalExecutor; - static LOCAL_EXECUTOR: LocalExecutor<'static> = const { LocalExecutor::new() }; + // Because we do not have thread-locals without std, we cannot use LocalExecutor here. + use crate::executor::Executor; - type ScopeResult = Arc>>; + static LOCAL_EXECUTOR: Executor<'static> = const { Executor::new() }; } } @@ -111,7 +110,7 @@ impl TaskPool { /// This is similar to `rayon::scope` and `crossbeam::scope` pub fn scope<'env, F, T>(&self, f: F) -> Vec where - F: for<'scope> FnOnce(&'env mut Scope<'scope, 'env, T>), + F: for<'scope> FnOnce(&'scope mut Scope<'scope, 'env, T>), T: Send + 'static, { self.scope_with_executor(false, None, f) @@ -130,7 +129,7 @@ impl TaskPool { f: F, ) -> Vec where - F: for<'scope> FnOnce(&'env mut Scope<'scope, 'env, T>), + F: for<'scope> FnOnce(&'scope mut Scope<'scope, 'env, T>), T: Send + 'static, { // SAFETY: This safety comment applies to all references transmuted to 'env. @@ -141,17 +140,22 @@ impl TaskPool { // Any usages of the references passed into `Scope` must be accessed through // the transmuted reference for the rest of this function. - let executor = &LocalExecutor::new(); + let executor = LocalExecutor::new(); + // SAFETY: As above, all futures must complete in this function so we can change the lifetime + let executor_ref: &'env LocalExecutor<'env> = unsafe { mem::transmute(&executor) }; + + let results: RefCell>> = RefCell::new(Vec::new()); // SAFETY: As above, all futures must complete in this function so we can change the lifetime - let executor: &'env LocalExecutor<'env> = unsafe { mem::transmute(executor) }; + let results_ref: &'env RefCell>> = unsafe { mem::transmute(&results) }; - let results: RefCell>> = RefCell::new(Vec::new()); + let pending_tasks: Cell = Cell::new(0); // SAFETY: As above, all futures must complete in this function so we can change the lifetime - let results: &'env RefCell>> = unsafe { mem::transmute(&results) }; + let pending_tasks: &'env Cell = unsafe { mem::transmute(&pending_tasks) }; let mut scope = Scope { - executor, - results, + executor_ref, + pending_tasks, + results_ref, scope: PhantomData, env: PhantomData, }; @@ -161,21 +165,17 @@ impl TaskPool { f(scope_ref); - // Loop until all tasks are done - while executor.try_tick() {} + // Wait until the scope is complete + block_on(executor.run(async { + while pending_tasks.get() != 0 { + futures_lite::future::yield_now().await; + } + })); - let results = scope.results.borrow(); results - .iter() - .map(|result| crate::cfg::switch! {{ - crate::cfg::std => { - result.borrow_mut().take().unwrap() - } - _ => { - let mut lock = result.lock().unwrap_or_else(PoisonError::into_inner); - lock.take().unwrap() - } - }}) + .take() + .into_iter() + .map(|result| result.unwrap()) .collect() } @@ -239,7 +239,7 @@ impl TaskPool { /// ``` pub fn with_local_executor(&self, f: F) -> R where - F: FnOnce(&LocalExecutor) -> R, + F: FnOnce(&Executor) -> R, { crate::cfg::switch! {{ crate::cfg::std => { @@ -257,9 +257,11 @@ impl TaskPool { /// For more information, see [`TaskPool::scope`]. #[derive(Debug)] pub struct Scope<'scope, 'env: 'scope, T> { - executor: &'scope LocalExecutor<'scope>, + executor_ref: &'scope LocalExecutor<'scope>, + // The number of pending tasks spawned on the scope + pending_tasks: &'scope Cell, // Vector to gather results of all futures spawned during scope run - results: &'env RefCell>>, + results_ref: &'env RefCell>>, // make `Scope` invariant over 'scope and 'env scope: PhantomData<&'scope mut &'scope ()>, @@ -295,21 +297,32 @@ impl<'scope, 'env, T: Send + 'env> Scope<'scope, 'env, T> { /// /// For more information, see [`TaskPool::scope`]. pub fn spawn_on_scope + 'scope + MaybeSend>(&self, f: Fut) { - let result = ScopeResult::::default(); - self.results.borrow_mut().push(result.clone()); + // increment the number of pending tasks + let pending_tasks = self.pending_tasks; + pending_tasks.update(|i| i + 1); + + // add a spot to keep the result, and record the index + let results_ref = self.results_ref; + let mut results = results_ref.borrow_mut(); + let task_number = results.len(); + results.push(None); + drop(results); + + // create the job closure let f = async move { - let temp_result = f.await; - - crate::cfg::std! { - if { - result.borrow_mut().replace(temp_result); - } else { - let mut lock = result.lock().unwrap_or_else(PoisonError::into_inner); - *lock = Some(temp_result); - } - } + let result = f.await; + + // store the result in the allocated slot + let mut results = results_ref.borrow_mut(); + results[task_number] = Some(result); + drop(results); + + // decrement the pending tasks count + pending_tasks.update(|i| i - 1); }; - self.executor.spawn(f).detach(); + + // spawn the job itself + self.executor_ref.spawn(f).detach(); } } @@ -328,3 +341,32 @@ crate::cfg::std! { impl MaybeSync for T {} } } + +#[cfg(test)] +mod test { + use std::{time, thread}; + + use super::*; + + /// This test creates a scope with a single task that goes to sleep for a + /// nontrivial amount of time. At one point, the scope would (incorrectly) + /// return early under these conditions, causing a crash. + /// + /// The correct behavior is for the scope to block until the receiver is + /// woken by the external thread. + #[test] + fn scoped_spawn() { + let (sender, recever) = async_channel::unbounded(); + let task_pool = TaskPool {}; + let thread = thread::spawn(move || { + let duration = time::Duration::from_millis(50); + thread::sleep(duration); + let _ = sender.send(0); + }); + task_pool.scope(|scope| { + scope.spawn(async { + recever.recv().await + }); + }); + } +} diff --git a/crates/bevy_tasks/src/task.rs b/crates/bevy_tasks/src/task.rs index d4afb775f2e01..dd649ba47dca3 100644 --- a/crates/bevy_tasks/src/task.rs +++ b/crates/bevy_tasks/src/task.rs @@ -1,9 +1,12 @@ +use alloc::fmt; use core::{ future::Future, pin::Pin, task::{Context, Poll}, }; +use crate::cfg; + /// Wraps `async_executor::Task`, a spawned future. /// /// Tasks are also futures themselves and yield the output of the spawned future. @@ -12,20 +15,59 @@ use core::{ /// more gracefully and wait until it stops running, use the [`Task::cancel()`] method. /// /// Tasks that panic get immediately canceled. Awaiting a canceled task also causes a panic. -#[derive(Debug)] #[must_use = "Tasks are canceled when dropped, use `.detach()` to run them in the background."] -pub struct Task(async_task::Task); +pub struct Task( + cfg::web! { + if { + async_channel::Receiver> + } else { + async_task::Task + } + }, +); -impl Task { - /// Creates a new task from a given `async_executor::Task` - pub fn new(task: async_task::Task) -> Self { - Self(task) +// Custom constructors for web and non-web platforms +cfg::web! { + if { + impl Task { + /// Creates a new task by passing the given future to the web + /// runtime as a promise. + pub(crate) fn wrap_future(future: impl Future + 'static) -> Self { + use bevy_platform::exports::wasm_bindgen_futures::spawn_local; + let (sender, receiver) = async_channel::bounded(1); + spawn_local(async move { + // Catch any panics that occur when polling the future so they can + // be propagated back to the task handle. + let value = CatchUnwind(AssertUnwindSafe(future)).await; + let _ = sender.send(value); + }); + Self(receiver) + } + } + } else { + impl Task { + /// Creates a new task from a given `async_executor::Task` + pub(crate) fn new(task: async_task::Task) -> Self { + Self(task) + } + } } +} - /// Detaches the task to let it keep running in the background. See - /// `async_executor::Task::detach` +impl Task { + /// Detaches the task to let it keep running in the background. + /// + /// # Platform-Specific Behavior + /// + /// When building for the web, this method has no effect. pub fn detach(self) { - self.0.detach(); + cfg::web! { + if { + // Tasks are already treated as detached on the web. + } else { + self.0.detach(); + } + } } /// Cancels the task and waits for it to stop running. @@ -36,25 +78,135 @@ impl Task { /// While it's possible to simply drop the [`Task`] to cancel it, this is a cleaner way of /// canceling because it also waits for the task to stop running. /// - /// See `async_executor::Task::cancel` + /// # Platform-Specific Behavior + /// + /// Canceling tasks is unsupported on the web, and this is the same as awaiting the task. pub async fn cancel(self) -> Option { - self.0.cancel().await + cfg::web! { + if { + // Await the task and handle any panics. + match self.0.recv().await { + Ok(Ok(value)) => Some(value), + Err(_) => None, + Ok(Err(panic)) => { + // drop this to prevent the panic payload from resuming the panic on drop. + // this also leaks the box but I'm not sure how to avoid that + core::mem::forget(panic); + None + } + } + } else { + // Wait for the task to become canceled + self.0.cancel().await + } + } } /// Returns `true` if the current task is finished. /// - /// /// Unlike poll, it doesn't resolve the final value, it just checks if the task has finished. /// Note that in a multithreaded environment, this task can be finished immediately after calling this function. pub fn is_finished(&self) -> bool { - self.0.is_finished() + cfg::web! { + if { + // We treat the task as unfinished until the result is sent over the channel. + !self.0.is_empty() + } else { + // Defer to the `async_task` implementation. + self.0.is_finished() + } + } } } impl Future for Task { type Output = T; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - Pin::new(&mut self.0).poll(cx) + cfg::web! { + if { + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // `recv()` returns a future, so we just poll that and hand the result. + let recv = core::pin::pin!(self.0.recv()); + match recv.poll(cx) { + Poll::Ready(Ok(Ok(value))) => Poll::Ready(value), + // NOTE: Propagating the panic here sorta has parity with the async_executor behavior. + // For those tasks, polling them after a panic returns a `None` which gets `unwrap`ed, so + // using `resume_unwind` here is essentially keeping the same behavior while adding more information. + Poll::Ready(Ok(Err(_panic))) => crate::cfg::switch! {{ + crate::cfg::std => { + std::panic::resume_unwind(_panic) + } + _ => { + unreachable!("catching a panic is only possible with std") + } + }}, + Poll::Ready(Err(_)) => panic!("Polled a task after it finished running"), + Poll::Pending => Poll::Pending, + } + } + } else { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // `async_task` has `Task` implement `Future`, so we just poll it. + Pin::new(&mut self.0).poll(cx) + } + } + } +} + +// All variants of Task are expected to implement Unpin +impl Unpin for Task {} + +// Derive doesn't work for macro types, so we have to implement this manually. +impl fmt::Debug for Task { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +// Utilities for catching unwinds on the web. +cfg::web! { + use alloc::boxed::Box; + use core::{ + panic::{AssertUnwindSafe, UnwindSafe}, + any::Any, + }; + + type Panic = Box; + + #[pin_project::pin_project] + struct CatchUnwind(#[pin] F); + + impl Future for CatchUnwind { + type Output = Result; + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let f = AssertUnwindSafe(|| self.project().0.poll(cx)); + + let result = cfg::std! { + if { + std::panic::catch_unwind(f)? + } else { + f() + } + }; + + result.map(Ok) + } + } +} + +#[cfg(test)] +mod tests { + use crate::Task; + + #[test] + fn task_is_sync() { + fn is_sync() {} + is_sync::>(); + } + + #[test] + fn task_is_send() { + fn is_send() {} + is_send::>(); } } diff --git a/crates/bevy_tasks/src/wasm_task.rs b/crates/bevy_tasks/src/wasm_task.rs deleted file mode 100644 index 91eac7304ddec..0000000000000 --- a/crates/bevy_tasks/src/wasm_task.rs +++ /dev/null @@ -1,99 +0,0 @@ -use alloc::boxed::Box; -use core::{ - any::Any, - future::Future, - panic::{AssertUnwindSafe, UnwindSafe}, - pin::Pin, - task::{Context, Poll}, -}; - -use futures_channel::oneshot; -use bevy_platform::exports::wasm_bindgen_futures; - -/// Wraps an asynchronous task, a spawned future. -/// -/// Tasks are also futures themselves and yield the output of the spawned future. -#[derive(Debug)] -pub struct Task(oneshot::Receiver>); - -impl Task { - pub(crate) fn wrap_future(future: impl Future + 'static) -> Self { - let (sender, receiver) = oneshot::channel(); - wasm_bindgen_futures::spawn_local(async move { - // Catch any panics that occur when polling the future so they can - // be propagated back to the task handle. - let value = CatchUnwind(AssertUnwindSafe(future)).await; - let _ = sender.send(value); - }); - Self(receiver) - } - - /// When building for Wasm, this method has no effect. - /// This is only included for feature parity with other platforms. - pub fn detach(self) {} - - /// Requests a task to be cancelled and returns a future that suspends until it completes. - /// Returns the output of the future if it has already completed. - /// - /// # Implementation - /// - /// When building for Wasm, it is not possible to cancel tasks, which means this is the same - /// as just awaiting the task. This method is only included for feature parity with other platforms. - pub async fn cancel(self) -> Option { - match self.0.await { - Ok(Ok(value)) => Some(value), - Err(_) => None, - Ok(Err(panic)) => { - // drop this to prevent the panic payload from resuming the panic on drop. - // this also leaks the box but I'm not sure how to avoid that - core::mem::forget(panic); - None - } - } - } -} - -impl Future for Task { - type Output = T; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match Pin::new(&mut self.0).poll(cx) { - Poll::Ready(Ok(Ok(value))) => Poll::Ready(value), - // NOTE: Propagating the panic here sorta has parity with the async_executor behavior. - // For those tasks, polling them after a panic returns a `None` which gets `unwrap`ed, so - // using `resume_unwind` here is essentially keeping the same behavior while adding more information. - Poll::Ready(Ok(Err(_panic))) => crate::cfg::switch! {{ - crate::cfg::std => { - std::panic::resume_unwind(_panic) - } - _ => { - unreachable!("catching a panic is only possible with std") - } - }}, - Poll::Ready(Err(_)) => panic!("Polled a task after it was cancelled"), - Poll::Pending => Poll::Pending, - } - } -} - -type Panic = Box; - -#[pin_project::pin_project] -struct CatchUnwind(#[pin] F); - -impl Future for CatchUnwind { - type Output = Result; - fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { - let f = AssertUnwindSafe(|| self.project().0.poll(cx)); - - let result = crate::cfg::switch! {{ - crate::cfg::std => { - std::panic::catch_unwind(f)? - } - _ => { - f() - } - }}; - - result.map(Ok) - } -}