diff --git a/lib/wasi/src/runtime/task_manager/tokio.rs b/lib/wasi/src/runtime/task_manager/tokio.rs index f50e7fcf35a..b54f7314c5c 100644 --- a/lib/wasi/src/runtime/task_manager/tokio.rs +++ b/lib/wasi/src/runtime/task_manager/tokio.rs @@ -1,4 +1,8 @@ -use std::{pin::Pin, time::Duration}; +use std::{ + pin::Pin, + sync::{Arc, Mutex}, + time::Duration, +}; use futures::Future; use tokio::runtime::Handle; @@ -15,6 +19,10 @@ use super::{SpawnType, VirtualTaskManager}; #[derive(Clone, Debug)] pub struct TokioTaskManager(Handle); +/// This holds the currently set shared runtime which should be accessed via +/// TokioTaskManager::shared() and/or set via TokioTaskManager::set_shared() +static GLOBAL_RUNTIME: Mutex, Handle)>> = Mutex::new(None); + impl TokioTaskManager { pub fn new(rt: Handle) -> Self { Self(rt) @@ -24,35 +32,35 @@ impl TokioTaskManager { self.0.clone() } + /// Allows the caller to set the shared runtime that will be used by other + /// async processes within Wasmer + /// + /// The shared runtime must be set before it is used and can only be set once + /// otherwise this call will fail with an error. + pub fn set_shared(rt: Arc) -> Result<(), anyhow::Error> { + let mut guard = GLOBAL_RUNTIME.lock().unwrap(); + if guard.is_some() { + return Err(anyhow::format_err!("The shared runtime has already been set or lazy initialized - it can not be overridden")); + } + guard.replace((rt.clone(), rt.handle().clone())); + Ok(()) + } + /// Shared tokio [`Runtime`] that is used by default. /// /// This exists because a tokio runtime is heavy, and there should not be many /// independent ones in a process. pub fn shared() -> Self { - static GLOBAL_RUNTIME: once_cell::sync::Lazy<(Option, Handle)> = - once_cell::sync::Lazy::new(|| { - if let Ok(handle) = tokio::runtime::Handle::try_current() { - (None, handle) - } else { - #[cfg(feature = "sys")] - { - let rt = tokio::runtime::Runtime::new().unwrap(); - let handle = rt.handle().clone(); - (Some(rt), handle) - } - #[cfg(not(feature = "sys"))] - { - let rt = tokio::runtime::Runtime::new().unwrap(); - let handle = rt.handle().clone(); - (Some(rt), handle) - } - } - }); - if let Ok(handle) = tokio::runtime::Handle::try_current() { Self(handle) } else { - Self(GLOBAL_RUNTIME.1.clone()) + let mut guard = GLOBAL_RUNTIME.lock().unwrap(); + let rt = guard.get_or_insert_with(|| { + let rt = tokio::runtime::Runtime::new().unwrap(); + let handle = rt.handle().clone(); + (Arc::new(rt), handle) + }); + Self(rt.1.clone()) } } }