Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a forced shutdown on tokio runtimes as the STDIN blocks the shu… #4107

Merged
merged 5 commits into from
Jul 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions lib/c-api/src/wasm_c_api/wasi/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ unsafe fn wasi_env_with_filesystem_inner(
let module = &module.as_ref()?.inner;
let imports = imports?;

let (wasi_env, import_object, runtime) = prepare_webc_env(
let (wasi_env, import_object) = prepare_webc_env(
config,
&mut store.store_mut(),
module,
Expand All @@ -247,7 +247,6 @@ unsafe fn wasi_env_with_filesystem_inner(
Some(Box::new(wasi_env_t {
inner: wasi_env,
store: store.clone(),
_runtime: runtime,
}))
}

Expand All @@ -259,7 +258,7 @@ fn prepare_webc_env(
bytes: &'static u8,
len: usize,
package_name: &str,
) -> Option<(WasiFunctionEnv, Imports, tokio::runtime::Runtime)> {
) -> Option<(WasiFunctionEnv, Imports)> {
use virtual_fs::static_fs::StaticFileSystem;
use webc::v1::{FsEntryType, WebC};

Expand All @@ -275,7 +274,7 @@ fn prepare_webc_env(

let handle = runtime.handle().clone();
let _guard = handle.enter();
let mut rt = PluggableRuntime::new(Arc::new(TokioTaskManager::new(handle)));
let mut rt = PluggableRuntime::new(Arc::new(TokioTaskManager::new(runtime)));
rt.set_engine(Some(store_mut.engine().clone()));

let slice = unsafe { std::slice::from_raw_parts(bytes, len) };
Expand Down Expand Up @@ -316,15 +315,14 @@ fn prepare_webc_env(
let env = builder.finalize(store).ok()?;

let import_object = env.import_object(store, module).ok()?;
Some((env, import_object, runtime))
Some((env, import_object))
}

#[allow(non_camel_case_types)]
pub struct wasi_env_t {
/// cbindgen:ignore
pub(super) inner: WasiFunctionEnv,
pub(super) store: StoreRef,
pub(super) _runtime: tokio::runtime::Runtime,
}

/// Create a new WASI environment.
Expand All @@ -349,7 +347,7 @@ pub unsafe extern "C" fn wasi_env_new(

let handle = runtime.handle().clone();
let _guard = handle.enter();
let mut rt = PluggableRuntime::new(Arc::new(TokioTaskManager::new(handle)));
let mut rt = PluggableRuntime::new(Arc::new(TokioTaskManager::new(runtime)));
rt.set_engine(Some(store_mut.engine().clone()));

if !config.inherit_stdout {
Expand All @@ -370,7 +368,6 @@ pub unsafe extern "C" fn wasi_env_new(
Some(Box::new(wasi_env_t {
inner: env,
store: store.clone(),
_runtime: runtime,
}))
}

Expand Down
6 changes: 3 additions & 3 deletions lib/cli/src/commands/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ impl Run {

let _guard = handle.enter();
let (store, _) = self.store.get_store()?;
let runtime =
self.wasi
.prepare_runtime(store.engine().clone(), &self.env, handle.clone())?;
let runtime = self
.wasi
.prepare_runtime(store.engine().clone(), &self.env, runtime)?;

// This is a slow operation, so let's temporarily wrap the runtime with
// something that displays progress
Expand Down
16 changes: 11 additions & 5 deletions lib/cli/src/commands/run/wasi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ use wasmer_wasix::{
FileSystemSource, InMemorySource, MultiSource, PackageSpecifier, Source, WapmSource,
WebSource,
},
task_manager::{tokio::TokioTaskManager, VirtualTaskManagerExt},
task_manager::{
tokio::{RuntimeOrHandle, TokioTaskManager},
VirtualTaskManagerExt,
},
},
types::__WASI_STDIN_FILENO,
wasmer_wasix_types::wasi::Errno,
Expand Down Expand Up @@ -247,13 +250,16 @@ impl Wasi {
caps
}

pub fn prepare_runtime(
pub fn prepare_runtime<I>(
&self,
engine: Engine,
env: &WasmerEnv,
handle: Handle,
) -> Result<impl Runtime + Send + Sync> {
let mut rt = PluggableRuntime::new(Arc::new(TokioTaskManager::new(handle)));
rt_or_handle: I,
) -> Result<impl Runtime + Send + Sync>
where
I: Into<RuntimeOrHandle>,
{
let mut rt = PluggableRuntime::new(Arc::new(TokioTaskManager::new(rt_or_handle.into())));

if self.networking {
rt.set_networking_implementation(virtual_net::host::LocalNetworking::default());
Expand Down
7 changes: 4 additions & 3 deletions lib/wasix/src/os/console/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,9 +292,10 @@ mod tests {
#[cfg_attr(not(feature = "host-reqwest"), ignore = "Requires a HTTP client")]
fn test_console_dash_tty_with_args_and_env() {
let tokio_rt = tokio::runtime::Runtime::new().unwrap();
let _guard = tokio_rt.handle().enter();
let rt_handle = tokio_rt.handle().clone();
let _guard = rt_handle.enter();

let tm = TokioTaskManager::new(tokio_rt.handle().clone());
let tm = TokioTaskManager::new(tokio_rt);
let mut rt = PluggableRuntime::new(Arc::new(tm));
rt.set_engine(Some(wasmer::Engine::default()))
.set_package_loader(BuiltinPackageLoader::from_env().unwrap());
Expand All @@ -316,7 +317,7 @@ mod tests {
.run()
.unwrap();

let code = tokio_rt
let code = rt_handle
.block_on(async move {
virtual_fs::AsyncWriteExt::write_all(
&mut stdin_tx,
Expand Down
7 changes: 7 additions & 0 deletions lib/wasix/src/runtime/task_manager/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ pub struct ThreadTaskManager {
pool: Arc<rayon::ThreadPool>,
}

impl Drop
for ThreadTaskManager {
fn drop(&mut self) {
self.runtime.shutdown_timeout(Duration::from_secs(0));
}
}

impl Default for ThreadTaskManager {
#[cfg(feature = "sys-thread")]
fn default() -> Self {
Expand Down
55 changes: 47 additions & 8 deletions lib/wasix/src/runtime/task_manager/tokio.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,67 @@
use std::sync::Mutex;
use std::{num::NonZeroUsize, pin::Pin, sync::Arc, time::Duration};

use futures::{future::BoxFuture, Future};
use tokio::runtime::Handle;
use tokio::runtime::{Handle, Runtime};

use crate::{os::task::thread::WasiThreadError, WasiFunctionEnv};

use super::{TaskWasm, TaskWasmRunProperties, VirtualTaskManager};

#[derive(Debug, Clone)]
pub enum RuntimeOrHandle {
Handle(Handle),
Runtime(Handle, Arc<Mutex<Option<Runtime>>>),
}
impl From<Handle> for RuntimeOrHandle {
fn from(value: Handle) -> Self {
Self::Handle(value)
}
}
impl From<Runtime> for RuntimeOrHandle {
fn from(value: Runtime) -> Self {
Self::Runtime(value.handle().clone(), Arc::new(Mutex::new(Some(value))))
}
}

impl Drop for RuntimeOrHandle {
fn drop(&mut self) {
if let Self::Runtime(_, runtime) = self {
if let Some(h) = runtime.lock().unwrap().take() {
h.shutdown_timeout(Duration::from_secs(0))
}
}
}
}

impl RuntimeOrHandle {
pub fn handle(&self) -> &Handle {
match self {
Self::Handle(h) => h,
Self::Runtime(h, _) => h,
}
}
}

/// A task manager that uses tokio to spawn tasks.
#[derive(Clone, Debug)]
pub struct TokioTaskManager {
handle: Handle,
rt: RuntimeOrHandle,
john-sharratt marked this conversation as resolved.
Show resolved Hide resolved
pool: Arc<rayon::ThreadPool>,
}

impl TokioTaskManager {
pub fn new(rt: Handle) -> Self {
pub fn new<I>(rt: I) -> Self
where
I: Into<RuntimeOrHandle>,
{
let concurrency = std::thread::available_parallelism()
.unwrap_or(NonZeroUsize::new(1).unwrap())
.get();
let max_threads = 200usize.max(concurrency * 100);

Self {
handle: rt,
rt: rt.into(),
pool: Arc::new(
rayon::ThreadPoolBuilder::new()
.num_threads(max_threads)
Expand All @@ -33,7 +72,7 @@ impl TokioTaskManager {
}

pub fn runtime_handle(&self) -> tokio::runtime::Handle {
self.handle.clone()
self.rt.handle().clone()
}
}

Expand All @@ -55,7 +94,7 @@ impl VirtualTaskManager for TokioTaskManager {
/// See [`VirtualTaskManager::sleep_now`].
fn sleep_now(&self, time: Duration) -> Pin<Box<dyn Future<Output = ()> + Send + Sync>> {
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
self.handle.spawn(async move {
self.rt.handle().spawn(async move {
if time == Duration::ZERO {
tokio::task::yield_now().await;
} else {
Expand All @@ -73,7 +112,7 @@ impl VirtualTaskManager for TokioTaskManager {
&self,
task: Box<dyn FnOnce() -> BoxFuture<'static, ()> + Send + 'static>,
) -> Result<(), WasiThreadError> {
self.handle.spawn(async move {
self.rt.handle().spawn(async move {
let fut = task();
fut.await
});
Expand All @@ -99,7 +138,7 @@ impl VirtualTaskManager for TokioTaskManager {

let trigger = trigger();
let pool = self.pool.clone();
self.handle.spawn(async move {
self.rt.handle().spawn(async move {
let result = trigger.await;
// Build the task that will go on the callback
pool.spawn(move || {
Expand Down
2 changes: 1 addition & 1 deletion tests/lib/wast/src/wasi_wast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ impl<'a> WasiTest<'a> {
#[cfg(not(target_arch = "wasm32"))]
let _guard = handle.enter();
#[cfg(not(target_arch = "wasm32"))]
let mut rt = PluggableRuntime::new(Arc::new(TokioTaskManager::new(handle)));
let mut rt = PluggableRuntime::new(Arc::new(TokioTaskManager::new(runtime)));
#[cfg(target_arch = "wasm32")]
let mut rt = PluggableRuntime::new(Arc::new(TokioTaskManager::default()));
rt.set_engine(Some(store.engine().clone()));
Expand Down