diff --git a/lib/wasi/Cargo.toml b/lib/wasi/Cargo.toml index f5fa6011c57..bcead472aba 100644 --- a/lib/wasi/Cargo.toml +++ b/lib/wasi/Cargo.toml @@ -37,7 +37,7 @@ sha2 = { version = "0.10" } waker-fn = { version = "1.1" } cooked-waker = "^5" rand = "0.8" -tokio = { version = "1", features = ["sync", "macros", "time"], default_features = false } +tokio = { version = "1", features = ["sync", "macros", "time", "rt"], default_features = false } futures = { version = "0.3" } # used by feature='os' async-trait = { version = "^0.1" } diff --git a/lib/wasi/src/os/task/process.rs b/lib/wasi/src/os/task/process.rs index 829d5cd5ed8..38fb54d2907 100644 --- a/lib/wasi/src/os/task/process.rs +++ b/lib/wasi/src/os/task/process.rs @@ -201,11 +201,7 @@ impl WasiProcess { inner.threads.insert(id, ctrl.clone()); inner.thread_count += 1; - Ok(WasiThreadHandle { - id: Arc::new(id), - thread: ctrl, - inner: self.inner.clone(), - }) + Ok(WasiThreadHandle::new(ctrl, &self.inner)) } /// Gets a reference to a particular thread diff --git a/lib/wasi/src/os/task/thread.rs b/lib/wasi/src/os/task/thread.rs index a57e2af6867..7be61bfcaf7 100644 --- a/lib/wasi/src/os/task/thread.rs +++ b/lib/wasi/src/os/task/thread.rs @@ -1,7 +1,7 @@ use std::{ collections::HashMap, ops::{Deref, DerefMut}, - sync::{Arc, Mutex, RwLock}, + sync::{Arc, Mutex, RwLock, Weak}, task::Waker, }; @@ -379,29 +379,45 @@ impl WasiThread { } } +#[derive(Debug)] +pub struct WasiThreadHandleProtected { + thread: WasiThread, + inner: Weak>, +} + #[derive(Debug, Clone)] pub struct WasiThreadHandle { - pub(super) id: Arc, - pub(super) thread: WasiThread, - pub(super) inner: Arc>, + protected: Arc, } impl WasiThreadHandle { + pub(crate) fn new( + thread: WasiThread, + inner: &Arc>, + ) -> WasiThreadHandle { + Self { + protected: Arc::new(WasiThreadHandleProtected { + thread, + inner: Arc::downgrade(inner), + }), + } + } + pub fn id(&self) -> WasiThreadId { - self.id.0.into() + self.protected.thread.tid() } pub fn as_thread(&self) -> WasiThread { - self.thread.clone() + self.protected.thread.clone() } } -impl Drop for WasiThreadHandle { +impl Drop for WasiThreadHandleProtected { fn drop(&mut self) { - // We do this so we track when the last handle goes out of scope - if let Some(id) = Arc::get_mut(&mut self.id) { - let mut inner = self.inner.write().unwrap(); - if let Some(ctrl) = inner.threads.remove(id) { + let id = self.thread.tid(); + if let Some(inner) = Weak::upgrade(&self.inner) { + let mut inner = inner.write().unwrap(); + if let Some(ctrl) = inner.threads.remove(&id) { ctrl.set_status_finished(Ok(0)); } inner.thread_count -= 1; @@ -413,13 +429,7 @@ impl std::ops::Deref for WasiThreadHandle { type Target = WasiThread; fn deref(&self) -> &Self::Target { - &self.thread - } -} - -impl std::ops::DerefMut for WasiThreadHandle { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.thread + &self.protected.thread } } diff --git a/lib/wasi/src/syscalls/wasix/futex_wait.rs b/lib/wasi/src/syscalls/wasix/futex_wait.rs index 14f8dc89a68..41dfd83a74d 100644 --- a/lib/wasi/src/syscalls/wasix/futex_wait.rs +++ b/lib/wasi/src/syscalls/wasix/futex_wait.rs @@ -73,21 +73,10 @@ pub fn futex_wait( timeout: WasmPtr, ret_woken: WasmPtr, ) -> Result { - trace!( - "wasi[{}:{}]::futex_wait(offset={})", - ctx.data().pid(), - ctx.data().tid(), - futex_ptr.offset() - ); - wasi_try_ok!(WasiEnv::process_signals_and_exit(&mut ctx)?); - let mut env = ctx.data(); - let state = env.state.clone(); - - let futex_idx: u64 = wasi_try_ok!(futex_ptr.offset().try_into().map_err(|_| Errno::Overflow)); - // Determine the timeout + let mut env = ctx.data(); let timeout = { let memory = env.memory_view(&ctx); wasi_try_mem_ok!(timeout.read(&memory)) @@ -97,6 +86,17 @@ pub fn futex_wait( _ => None, }; + trace!( + "wasi[{}:{}]::futex_wait(offset={}, timeout={:?})", + ctx.data().pid(), + ctx.data().tid(), + futex_ptr.offset(), + timeout + ); + + let state = env.state.clone(); + let futex_idx: u64 = wasi_try_ok!(futex_ptr.offset().try_into().map_err(|_| Errno::Overflow)); + // Create a poller which will register ourselves against // this futex event and check when it has changed let view = env.memory_view(&ctx); @@ -123,6 +123,6 @@ pub fn futex_wait( }; let memory = env.memory_view(&ctx); let mut env = ctx.data(); - wasi_try_mem_ok!(ret_woken.write(&memory, Bool::False)); + wasi_try_mem_ok!(ret_woken.write(&memory, woken)); Ok(ret) } diff --git a/lib/wasi/src/syscalls/wasix/thread_spawn.rs b/lib/wasi/src/syscalls/wasix/thread_spawn.rs index db866164986..dd12dcdcec0 100644 --- a/lib/wasi/src/syscalls/wasix/thread_spawn.rs +++ b/lib/wasi/src/syscalls/wasix/thread_spawn.rs @@ -32,16 +32,6 @@ pub fn thread_spawn( reactor: Bool, ret_tid: WasmPtr, ) -> Errno { - debug!( - "wasi[{}:{}]::thread_spawn (reactor={:?}, thread_id={}, stack_base={}, caller_id={})", - ctx.data().pid(), - ctx.data().tid(), - reactor, - ctx.data().thread.tid().raw(), - stack_base, - current_caller_id().raw() - ); - // Now we use the environment and memory references let env = ctx.data(); let memory = env.memory_view(&ctx); @@ -53,11 +43,10 @@ pub fn thread_spawn( Ok(h) => h, Err(err) => { error!( - "wasi[{}:{}]::thread_spawn (reactor={:?}, thread_id={}, stack_base={}, caller_id={}) - failed to create thread handle: {}", + "wasi[{}:{}]::thread_spawn (reactor={:?}, stack_base={}, caller_id={}) - failed to create thread handle: {}", ctx.data().pid(), ctx.data().tid(), reactor, - ctx.data().thread.tid().raw(), stack_base, current_caller_id().raw(), err @@ -68,6 +57,16 @@ pub fn thread_spawn( }; let thread_id: Tid = thread_handle.id().into(); + debug!( + %thread_id, + "wasi[{}:{}]::thread_spawn (reactor={:?}, stack_base={}, caller_id={})", + ctx.data().pid(), + ctx.data().tid(), + reactor, + stack_base, + current_caller_id().raw() + ); + // We need a copy of the process memory and a packaged store in order to // launch threads and reactors let thread_memory = wasi_try!(ctx.data().memory().try_clone(&ctx).ok_or_else(|| { @@ -145,10 +144,16 @@ pub fn thread_spawn( let user_data_low: u32 = (user_data & 0xFFFFFFFF) as u32; let user_data_high: u32 = (user_data >> 32) as u32; + trace!( + %user_data, + "wasi[{}:{}]::thread_spawn spawn.call()", + ctx.data(&store).pid(), + ctx.data(&store).tid(), + ); + let mut ret = Errno::Success; if let Err(err) = spawn.call(store, user_data_low as i32, user_data_high as i32) { match err.downcast::() { - Ok(WasiError::Exit(0)) => ret = Errno::Success, Ok(WasiError::Exit(code)) => { debug!( %code, @@ -156,7 +161,11 @@ pub fn thread_spawn( ctx.data(&store).pid(), ctx.data(&store).tid(), ); - ret = Errno::Noexec; + ret = if code == 0 { + Errno::Success + } else { + Errno::Noexec + }; } Ok(WasiError::UnknownWasiVersion) => { debug!( @@ -219,6 +228,12 @@ pub fn thread_spawn( if let Some(thread) = thread { let mut store = thread.store.borrow_mut(); let ret = call_module(&thread.ctx, store.deref_mut()); + + { + let mut guard = state.threading.write().unwrap(); + guard.thread_ctx.remove(&caller_id); + } + return ret; }