Skip to content

Commit

Permalink
Merge pull request #4107 from wasmerio/tokio-runtime-fix
Browse files Browse the repository at this point in the history
Added a forced shutdown on tokio runtimes as the STDIN blocks the shu…
  • Loading branch information
ptitSeb authored Jul 26, 2023
2 parents 58e7487 + 9e729e0 commit edc4a52
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 28 deletions.
13 changes: 5 additions & 8 deletions lib/c-api/src/wasm_c_api/wasi/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ unsafe fn wasi_env_with_filesystem_inner(
let module = &module.as_ref()?.inner;
let imports = imports?;

let (wasi_env, import_object, runtime) = prepare_webc_env(
let (wasi_env, import_object) = prepare_webc_env(
config,
&mut store.store_mut(),
module,
Expand All @@ -247,7 +247,6 @@ unsafe fn wasi_env_with_filesystem_inner(
Some(Box::new(wasi_env_t {
inner: wasi_env,
store: store.clone(),
_runtime: runtime,
}))
}

Expand All @@ -259,7 +258,7 @@ fn prepare_webc_env(
bytes: &'static u8,
len: usize,
package_name: &str,
) -> Option<(WasiFunctionEnv, Imports, tokio::runtime::Runtime)> {
) -> Option<(WasiFunctionEnv, Imports)> {
use virtual_fs::static_fs::StaticFileSystem;
use webc::v1::{FsEntryType, WebC};

Expand All @@ -275,7 +274,7 @@ fn prepare_webc_env(

let handle = runtime.handle().clone();
let _guard = handle.enter();
let mut rt = PluggableRuntime::new(Arc::new(TokioTaskManager::new(handle)));
let mut rt = PluggableRuntime::new(Arc::new(TokioTaskManager::new(runtime)));
rt.set_engine(Some(store_mut.engine().clone()));

let slice = unsafe { std::slice::from_raw_parts(bytes, len) };
Expand Down Expand Up @@ -316,15 +315,14 @@ fn prepare_webc_env(
let env = builder.finalize(store).ok()?;

let import_object = env.import_object(store, module).ok()?;
Some((env, import_object, runtime))
Some((env, import_object))
}

#[allow(non_camel_case_types)]
pub struct wasi_env_t {
/// cbindgen:ignore
pub(super) inner: WasiFunctionEnv,
pub(super) store: StoreRef,
pub(super) _runtime: tokio::runtime::Runtime,
}

/// Create a new WASI environment.
Expand All @@ -349,7 +347,7 @@ pub unsafe extern "C" fn wasi_env_new(

let handle = runtime.handle().clone();
let _guard = handle.enter();
let mut rt = PluggableRuntime::new(Arc::new(TokioTaskManager::new(handle)));
let mut rt = PluggableRuntime::new(Arc::new(TokioTaskManager::new(runtime)));
rt.set_engine(Some(store_mut.engine().clone()));

if !config.inherit_stdout {
Expand All @@ -370,7 +368,6 @@ pub unsafe extern "C" fn wasi_env_new(
Some(Box::new(wasi_env_t {
inner: env,
store: store.clone(),
_runtime: runtime,
}))
}

Expand Down
6 changes: 3 additions & 3 deletions lib/cli/src/commands/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ impl Run {

let _guard = handle.enter();
let (store, _) = self.store.get_store()?;
let runtime =
self.wasi
.prepare_runtime(store.engine().clone(), &self.env, handle.clone())?;
let runtime = self
.wasi
.prepare_runtime(store.engine().clone(), &self.env, runtime)?;

// This is a slow operation, so let's temporarily wrap the runtime with
// something that displays progress
Expand Down
16 changes: 11 additions & 5 deletions lib/cli/src/commands/run/wasi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ use wasmer_wasix::{
FileSystemSource, InMemorySource, MultiSource, PackageSpecifier, Source, WapmSource,
WebSource,
},
task_manager::{tokio::TokioTaskManager, VirtualTaskManagerExt},
task_manager::{
tokio::{RuntimeOrHandle, TokioTaskManager},
VirtualTaskManagerExt,
},
},
types::__WASI_STDIN_FILENO,
wasmer_wasix_types::wasi::Errno,
Expand Down Expand Up @@ -247,13 +250,16 @@ impl Wasi {
caps
}

pub fn prepare_runtime(
pub fn prepare_runtime<I>(
&self,
engine: Engine,
env: &WasmerEnv,
handle: Handle,
) -> Result<impl Runtime + Send + Sync> {
let mut rt = PluggableRuntime::new(Arc::new(TokioTaskManager::new(handle)));
rt_or_handle: I,
) -> Result<impl Runtime + Send + Sync>
where
I: Into<RuntimeOrHandle>,
{
let mut rt = PluggableRuntime::new(Arc::new(TokioTaskManager::new(rt_or_handle.into())));

if self.networking {
rt.set_networking_implementation(virtual_net::host::LocalNetworking::default());
Expand Down
7 changes: 4 additions & 3 deletions lib/wasix/src/os/console/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,9 +292,10 @@ mod tests {
#[cfg_attr(not(feature = "host-reqwest"), ignore = "Requires a HTTP client")]
fn test_console_dash_tty_with_args_and_env() {
let tokio_rt = tokio::runtime::Runtime::new().unwrap();
let _guard = tokio_rt.handle().enter();
let rt_handle = tokio_rt.handle().clone();
let _guard = rt_handle.enter();

let tm = TokioTaskManager::new(tokio_rt.handle().clone());
let tm = TokioTaskManager::new(tokio_rt);
let mut rt = PluggableRuntime::new(Arc::new(tm));
rt.set_engine(Some(wasmer::Engine::default()))
.set_package_loader(BuiltinPackageLoader::from_env().unwrap());
Expand All @@ -316,7 +317,7 @@ mod tests {
.run()
.unwrap();

let code = tokio_rt
let code = rt_handle
.block_on(async move {
virtual_fs::AsyncWriteExt::write_all(
&mut stdin_tx,
Expand Down
7 changes: 7 additions & 0 deletions lib/wasix/src/runtime/task_manager/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ pub struct ThreadTaskManager {
pool: Arc<rayon::ThreadPool>,
}

impl Drop
for ThreadTaskManager {
fn drop(&mut self) {
self.runtime.shutdown_timeout(Duration::from_secs(0));
}
}

impl Default for ThreadTaskManager {
#[cfg(feature = "sys-thread")]
fn default() -> Self {
Expand Down
55 changes: 47 additions & 8 deletions lib/wasix/src/runtime/task_manager/tokio.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,67 @@
use std::sync::Mutex;
use std::{num::NonZeroUsize, pin::Pin, sync::Arc, time::Duration};

use futures::{future::BoxFuture, Future};
use tokio::runtime::Handle;
use tokio::runtime::{Handle, Runtime};

use crate::{os::task::thread::WasiThreadError, WasiFunctionEnv};

use super::{TaskWasm, TaskWasmRunProperties, VirtualTaskManager};

#[derive(Debug, Clone)]
pub enum RuntimeOrHandle {
Handle(Handle),
Runtime(Handle, Arc<Mutex<Option<Runtime>>>),
}
impl From<Handle> for RuntimeOrHandle {
fn from(value: Handle) -> Self {
Self::Handle(value)
}
}
impl From<Runtime> for RuntimeOrHandle {
fn from(value: Runtime) -> Self {
Self::Runtime(value.handle().clone(), Arc::new(Mutex::new(Some(value))))
}
}

impl Drop for RuntimeOrHandle {
fn drop(&mut self) {
if let Self::Runtime(_, runtime) = self {
if let Some(h) = runtime.lock().unwrap().take() {
h.shutdown_timeout(Duration::from_secs(0))
}
}
}
}

impl RuntimeOrHandle {
pub fn handle(&self) -> &Handle {
match self {
Self::Handle(h) => h,
Self::Runtime(h, _) => h,
}
}
}

/// A task manager that uses tokio to spawn tasks.
#[derive(Clone, Debug)]
pub struct TokioTaskManager {
handle: Handle,
rt: RuntimeOrHandle,
pool: Arc<rayon::ThreadPool>,
}

impl TokioTaskManager {
pub fn new(rt: Handle) -> Self {
pub fn new<I>(rt: I) -> Self
where
I: Into<RuntimeOrHandle>,
{
let concurrency = std::thread::available_parallelism()
.unwrap_or(NonZeroUsize::new(1).unwrap())
.get();
let max_threads = 200usize.max(concurrency * 100);

Self {
handle: rt,
rt: rt.into(),
pool: Arc::new(
rayon::ThreadPoolBuilder::new()
.num_threads(max_threads)
Expand All @@ -33,7 +72,7 @@ impl TokioTaskManager {
}

pub fn runtime_handle(&self) -> tokio::runtime::Handle {
self.handle.clone()
self.rt.handle().clone()
}
}

Expand All @@ -55,7 +94,7 @@ impl VirtualTaskManager for TokioTaskManager {
/// See [`VirtualTaskManager::sleep_now`].
fn sleep_now(&self, time: Duration) -> Pin<Box<dyn Future<Output = ()> + Send + Sync>> {
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
self.handle.spawn(async move {
self.rt.handle().spawn(async move {
if time == Duration::ZERO {
tokio::task::yield_now().await;
} else {
Expand All @@ -73,7 +112,7 @@ impl VirtualTaskManager for TokioTaskManager {
&self,
task: Box<dyn FnOnce() -> BoxFuture<'static, ()> + Send + 'static>,
) -> Result<(), WasiThreadError> {
self.handle.spawn(async move {
self.rt.handle().spawn(async move {
let fut = task();
fut.await
});
Expand All @@ -99,7 +138,7 @@ impl VirtualTaskManager for TokioTaskManager {

let trigger = trigger();
let pool = self.pool.clone();
self.handle.spawn(async move {
self.rt.handle().spawn(async move {
let result = trigger.await;
// Build the task that will go on the callback
pool.spawn(move || {
Expand Down
2 changes: 1 addition & 1 deletion tests/lib/wast/src/wasi_wast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ impl<'a> WasiTest<'a> {
#[cfg(not(target_arch = "wasm32"))]
let _guard = handle.enter();
#[cfg(not(target_arch = "wasm32"))]
let mut rt = PluggableRuntime::new(Arc::new(TokioTaskManager::new(handle)));
let mut rt = PluggableRuntime::new(Arc::new(TokioTaskManager::new(runtime)));
#[cfg(target_arch = "wasm32")]
let mut rt = PluggableRuntime::new(Arc::new(TokioTaskManager::default()));
rt.set_engine(Some(store.engine().clone()));
Expand Down

0 comments on commit edc4a52

Please sign in to comment.