diff --git a/lib/wasi/src/os/console/mod.rs b/lib/wasi/src/os/console/mod.rs index 10ad7d700aa..02bfeae4a6e 100644 --- a/lib/wasi/src/os/console/mod.rs +++ b/lib/wasi/src/os/console/mod.rs @@ -18,9 +18,9 @@ use tokio::sync::{mpsc, RwLock}; use tracing::{debug, error, info, trace, warn}; #[cfg(feature = "sys")] use wasmer::Engine; -use wasmer_vbus::{BusSpawnedProcess, SpawnOptionsConfig}; +use wasmer_vbus::{BusSpawnedProcess, SpawnOptionsConfig, VirtualBusError}; use wasmer_vfs::{FileSystem, RootFileSystemBuilder, SpecialFile, WasiPipe}; -use wasmer_wasi_types::types::__WASI_STDIN_FILENO; +use wasmer_wasi_types::{types::__WASI_STDIN_FILENO, wasi::BusErrno}; use super::{cconst::ConsoleConst, common::*}; use crate::{ @@ -161,8 +161,12 @@ impl Console { // Create the control plane, process and thread let control_plane = WasiControlPlane::default(); - let wasi_process = control_plane.new_process(); - let wasi_thread = wasi_process.new_thread(); + let wasi_process = control_plane + .new_process() + .expect("creating processes on new control planes should always work"); + let wasi_thread = wasi_process + .new_thread() + .expect("creating the main thread should always work"); // Create the state let mut state = WasiState::builder(prog); diff --git a/lib/wasi/src/os/task/control_plane.rs b/lib/wasi/src/os/task/control_plane.rs index aaf3ce96dd4..8c58a15f2bb 100644 --- a/lib/wasi/src/os/task/control_plane.rs +++ b/lib/wasi/src/os/task/control_plane.rs @@ -1,99 +1,206 @@ use std::{ - collections::{HashMap, HashSet}, + collections::HashMap, sync::{ - atomic::{AtomicU32, Ordering}, - Arc, Mutex, RwLock, + atomic::{AtomicUsize, Ordering}, + Arc, RwLock, }, }; -use crate::{os::task::process::WasiProcessInner, WasiProcess, WasiProcessId}; +use crate::{WasiProcess, WasiProcessId}; #[derive(Debug, Clone)] pub struct WasiControlPlane { - /// The processes running on this machine - pub(crate) processes: Arc>>, - /// Seed used to generate process ID's - pub(crate) process_seed: Arc, - /// Allows for a PID to be reserved - pub(crate) reserved: Arc>>, + state: Arc, } -impl Default for WasiControlPlane { - fn default() -> Self { +#[derive(Debug, Clone)] +pub struct ControlPlaneConfig { + /// Total number of tasks (processes + threads) that can be spawned. + pub max_task_count: Option, +} + +impl ControlPlaneConfig { + pub fn new() -> Self { Self { - processes: Default::default(), - process_seed: Arc::new(AtomicU32::new(0)), - reserved: Default::default(), + max_task_count: None, } } } +impl Default for ControlPlaneConfig { + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug)] +struct State { + config: ControlPlaneConfig, + + /// Total number of active tasks (threads) across all processes. + task_count: Arc, + + /// Mutable state. + mutable: RwLock, +} + +#[derive(Debug)] +struct MutableState { + /// Seed used to generate process ID's + process_seed: u32, + /// The processes running on this machine + processes: HashMap, + // TODO: keep a queue of terminated process ids for id reuse. +} + impl WasiControlPlane { - /// Reserves a PID and returns it - pub fn reserve_pid(&self) -> WasiProcessId { - let mut pid: WasiProcessId; - loop { - pid = self.process_seed.fetch_add(1, Ordering::AcqRel).into(); - - { - let mut guard = self.reserved.lock().unwrap(); - if guard.contains(&pid) { - continue; - } - guard.insert(pid); - } + pub fn new(config: ControlPlaneConfig) -> Self { + Self { + state: Arc::new(State { + config, + task_count: Arc::new(AtomicUsize::new(0)), + mutable: RwLock::new(MutableState { + process_seed: 0, + processes: Default::default(), + }), + }), + } + } - { - let guard = self.processes.read().unwrap(); - if guard.contains_key(&pid) == false { - break; - } - } + /// Get the current count of active tasks (threads). + fn active_task_count(&self) -> usize { + self.state.task_count.load(Ordering::SeqCst) + } - { - let mut guard = self.reserved.lock().unwrap(); - guard.remove(&pid); + /// Register a new task. + /// + // Currently just increments the task counter. + pub(super) fn register_task(&self) -> Result { + let count = self.state.task_count.fetch_add(1, Ordering::SeqCst); + if let Some(max) = self.state.config.max_task_count { + if count > max { + self.state.task_count.fetch_sub(1, Ordering::SeqCst); + return Err(ControlPlaneError::TaskLimitReached { max: count }); } } - pid + Ok(TaskCountGuard(self.state.task_count.clone())) } /// Creates a new process - pub fn new_process(&self) -> WasiProcess { - let pid = self.reserve_pid(); - let ret = WasiProcess { - pid, - ppid: 0u32.into(), - compute: self.clone(), - inner: Arc::new(RwLock::new(WasiProcessInner { - threads: Default::default(), - thread_count: Default::default(), - thread_seed: Default::default(), - thread_local: Default::default(), - thread_local_user_data: Default::default(), - thread_local_seed: Default::default(), - signal_intervals: Default::default(), - bus_processes: Default::default(), - bus_process_reuse: Default::default(), - })), - children: Arc::new(RwLock::new(Default::default())), - finished: Arc::new(Mutex::new((None, tokio::sync::broadcast::channel(1).0))), - waiting: Arc::new(AtomicU32::new(0)), - }; - { - let mut guard = self.processes.write().unwrap(); - guard.insert(pid, ret.clone()); - } - { - let mut guard = self.reserved.lock().unwrap(); - guard.remove(&pid); + // FIXME: De-register terminated processes! + // Currently they just accumulate. + pub fn new_process(&self) -> Result { + if let Some(max) = self.state.config.max_task_count { + if self.active_task_count() >= max { + // NOTE: task count is not incremented here, only when new threads are spawned. + // A process will always have a main thread. + return Err(ControlPlaneError::TaskLimitReached { max }); + } } - ret + + // Create the process first to do all the allocations before locking. + let mut proc = WasiProcess::new(WasiProcessId::from(0), self.clone()); + + let mut mutable = self.state.mutable.write().unwrap(); + + let pid = mutable.next_process_id()?; + proc.set_pid(pid); + mutable.processes.insert(pid, proc.clone()); + Ok(proc) } /// Gets a reference to a running process pub fn get_process(&self, pid: WasiProcessId) -> Option { - let guard = self.processes.read().unwrap(); - guard.get(&pid).map(|a| a.clone()) + self.state + .mutable + .read() + .unwrap() + .processes + .get(&pid) + .cloned() + } +} + +impl MutableState { + fn next_process_id(&mut self) -> Result { + // TODO: reuse terminated ids, handle wrap-around, ... + let id = self.process_seed.checked_add(1).ok_or_else(|| { + ControlPlaneError::TaskLimitReached { + max: u32::MAX as usize, + } + })?; + self.process_seed = id; + Ok(WasiProcessId::from(id)) + } +} + +impl Default for WasiControlPlane { + fn default() -> Self { + let config = ControlPlaneConfig::default(); + Self::new(config) + } +} + +/// Guard that ensures the [`WasiControlPlane`] task counter is decremented when dropped. +#[derive(Debug)] +pub struct TaskCountGuard(Arc); + +impl Drop for TaskCountGuard { + fn drop(&mut self) { + self.0.fetch_sub(1, Ordering::SeqCst); + } +} + +#[derive(thiserror::Error, PartialEq, Eq, Clone, Debug)] +pub enum ControlPlaneError { + /// The maximum number of execution tasks has been reached. + #[error("The maximum number of execution tasks has been reached ({max})")] + TaskLimitReached { + /// The maximum number of tasks. + max: usize, + }, +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Simple test to ensure task limits are respected. + #[test] + fn test_control_plane_task_limits() { + let p = WasiControlPlane::new(ControlPlaneConfig { + max_task_count: Some(2), + }); + + let p1 = p.new_process().unwrap(); + let _t1 = p1.new_thread().unwrap(); + let _t2 = p1.new_thread().unwrap(); + + assert_eq!( + p.new_process().unwrap_err(), + ControlPlaneError::TaskLimitReached { max: 2 } + ); + } + + /// Simple test to ensure task limits are respected and that thread drop guards work. + #[test] + fn test_control_plane_task_limits_with_dropped_threads() { + let p = WasiControlPlane::new(ControlPlaneConfig { + max_task_count: Some(2), + }); + + let p1 = p.new_process().unwrap(); + + for _ in 0..10 { + let _thread = p1.new_thread().unwrap(); + } + + let _t1 = p1.new_thread().unwrap(); + let _t2 = p1.new_thread().unwrap(); + + assert_eq!( + p.new_process().unwrap_err(), + ControlPlaneError::TaskLimitReached { max: 2 } + ); } } diff --git a/lib/wasi/src/os/task/mod.rs b/lib/wasi/src/os/task/mod.rs index fabd65d9d82..5d17c9871cf 100644 --- a/lib/wasi/src/os/task/mod.rs +++ b/lib/wasi/src/os/task/mod.rs @@ -3,4 +3,5 @@ pub mod control_plane; pub mod process; pub mod signal; +pub mod task_join_handle; pub mod thread; diff --git a/lib/wasi/src/os/task/process.rs b/lib/wasi/src/os/task/process.rs index 03f18c415ad..e38b583799f 100644 --- a/lib/wasi/src/os/task/process.rs +++ b/lib/wasi/src/os/task/process.rs @@ -4,7 +4,7 @@ use std::{ convert::TryInto, sync::{ atomic::{AtomicU32, Ordering}, - Arc, Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard, + Arc, RwLock, RwLockReadGuard, RwLockWriteGuard, }, time::Duration, }; @@ -17,11 +17,13 @@ use wasmer_wasi_types::{ }; use crate::{ - os::task::{control_plane::WasiControlPlane, signal::WasiSignalInterval, thread::ThreadStack}, + os::task::{control_plane::WasiControlPlane, signal::WasiSignalInterval}, syscalls::platform_clock_time_get, WasiThread, WasiThreadHandle, WasiThreadId, }; +use super::{control_plane::ControlPlaneError, task_join_handle::TaskJoinHandle}; + /// Represents the ID of a sub-process #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct WasiProcessId(u32); @@ -62,6 +64,29 @@ impl std::fmt::Display for WasiProcessId { } } +/// Represents a process running within the compute state +// TODO: fields should be private and only accessed via methods. +#[derive(Debug, Clone)] +pub struct WasiProcess { + /// Unique ID of this process + pub(crate) pid: WasiProcessId, + /// ID of the parent process + pub(crate) ppid: WasiProcessId, + /// The inner protected region of the process + pub(crate) inner: Arc>, + /// Reference back to the compute engine + // TODO: remove this reference, access should happen via separate state instead + // (we don't want cyclical references) + pub(crate) compute: WasiControlPlane, + /// Reference to the exit code for the main thread + pub(crate) finished: Arc, + /// List of all the children spawned from this thread + pub(crate) children: Arc>>, + /// Number of threads waiting for children to exit + pub(crate) waiting: Arc, +} + +// TODO: fields should be private and only accessed via methods. #[derive(Debug)] pub struct WasiProcessInner { /// The threads that make up this process @@ -84,25 +109,6 @@ pub struct WasiProcessInner { pub bus_process_reuse: HashMap, WasiProcessId>, } -/// Represents a process running within the compute state -#[derive(Debug, Clone)] -pub struct WasiProcess { - /// Unique ID of this process - pub(crate) pid: WasiProcessId, - /// ID of the parent process - pub(crate) ppid: WasiProcessId, - /// The inner protected region of the process - pub(crate) inner: Arc>, - /// Reference back to the compute engine - pub(crate) compute: WasiControlPlane, - /// Reference to the exit code for the main thread - pub(crate) finished: Arc, tokio::sync::broadcast::Sender<()>)>>, - /// List of all the children spawned from this thread - pub(crate) children: Arc>>, - /// Number of threads waiting for children to exit - pub(crate) waiting: Arc, -} - pub(crate) struct WasiProcessWait { waiting: Arc, } @@ -123,6 +129,32 @@ impl Drop for WasiProcessWait { } impl WasiProcess { + pub fn new(pid: WasiProcessId, compute: WasiControlPlane) -> Self { + WasiProcess { + pid, + ppid: 0u32.into(), + compute, + inner: Arc::new(RwLock::new(WasiProcessInner { + threads: Default::default(), + thread_count: Default::default(), + thread_seed: Default::default(), + thread_local: Default::default(), + thread_local_user_data: Default::default(), + thread_local_seed: Default::default(), + signal_intervals: Default::default(), + bus_processes: Default::default(), + bus_process_reuse: Default::default(), + })), + children: Arc::new(RwLock::new(Default::default())), + finished: Arc::new(TaskJoinHandle::new()), + waiting: Arc::new(AtomicU32::new(0)), + } + } + + pub(super) fn set_pid(&mut self, pid: WasiProcessId) { + self.pid = pid; + } + /// Gets the process ID of this process pub fn pid(&self) -> WasiProcessId { self.pid @@ -134,17 +166,21 @@ impl WasiProcess { } /// Gains write access to the process internals + // TODO: Make this private, all inner access should be exposed with methods. pub fn write(&self) -> RwLockWriteGuard { self.inner.write().unwrap() } /// Gains read access to the process internals + // TODO: Make this private, all inner access should be exposed with methods. pub fn read(&self) -> RwLockReadGuard { self.inner.read().unwrap() } /// Creates a a thread and returns it - pub fn new_thread(&self) -> WasiThreadHandle { + pub fn new_thread(&self) -> Result { + let task_count_guard = self.compute.register_task()?; + let mut inner = self.inner.write().unwrap(); let id = inner.thread_seed.inc(); @@ -153,28 +189,18 @@ impl WasiProcess { is_main = true; self.finished.clone() } else { - Arc::new(Mutex::new((None, tokio::sync::broadcast::channel(1).0))) + Arc::new(TaskJoinHandle::new()) }; - let ctrl = WasiThread { - pid: self.pid(), - id, - is_main, - finished, - signals: Arc::new(Mutex::new(( - Vec::new(), - tokio::sync::broadcast::channel(1).0, - ))), - stack: Arc::new(Mutex::new(ThreadStack::default())), - }; + let ctrl = WasiThread::new(self.pid(), id, is_main, finished, task_count_guard); inner.threads.insert(id, ctrl.clone()); inner.thread_count += 1; - WasiThreadHandle { + Ok(WasiThreadHandle { id: Arc::new(id), thread: ctrl, inner: self.inner.clone(), - } + }) } /// Gets a reference to a particular thread @@ -254,24 +280,12 @@ impl WasiProcess { /// Waits until the process is finished or the timeout is reached pub async fn join(&self) -> Option { let _guard = WasiProcessWait::new(self); - loop { - let mut rx = { - let finished = self.finished.lock().unwrap(); - if finished.0.is_some() { - return finished.0.clone(); - } - finished.1.subscribe() - }; - if rx.recv().await.is_err() { - return None; - } - } + self.finished.await_termination().await } /// Attempts to join on the process pub fn try_join(&self) -> Option { - let guard = self.finished.lock().unwrap(); - guard.0.clone() + self.finished.get_exit_code() } /// Waits for all the children to be finished @@ -299,8 +313,7 @@ impl WasiProcess { futures::future::join_all(waits.into_iter()) .await .into_iter() - .filter_map(|a| a) - .next() + .find_map(|a| a) } /// Waits for any of the children to finished diff --git a/lib/wasi/src/os/task/task_join_handle.rs b/lib/wasi/src/os/task/task_join_handle.rs new file mode 100644 index 00000000000..115d0b9dc21 --- /dev/null +++ b/lib/wasi/src/os/task/task_join_handle.rs @@ -0,0 +1,52 @@ +use std::sync::Mutex; + +use wasmer_wasi_types::wasi::ExitCode; + +/// A handle that allows awaiting the termination of a task, and retrieving its exit code. +#[derive(Debug)] +pub struct TaskJoinHandle { + exit_code: Mutex>, + sender: tokio::sync::broadcast::Sender<()>, +} + +impl TaskJoinHandle { + pub fn new() -> Self { + let (sender, _) = tokio::sync::broadcast::channel(1); + Self { + exit_code: Mutex::new(None), + sender, + } + } + + /// Marks the task as finished. + pub(super) fn terminate(&self, exit_code: u32) { + let mut lock = self.exit_code.lock().unwrap(); + if lock.is_none() { + *lock = Some(exit_code); + std::mem::drop(lock); + self.sender.send(()).ok(); + } + } + + pub async fn await_termination(&self) -> Option { + // FIXME: why is this a loop? should not be necessary, + // Should be redundant since the subscriber is created while holding the lock. + loop { + let mut rx = { + let code_opt = self.exit_code.lock().unwrap(); + if code_opt.is_some() { + return code_opt.clone(); + } + self.sender.subscribe() + }; + if rx.recv().await.is_err() { + return None; + } + } + } + + /// Returns the exit code if the task has finished, and None otherwise. + pub fn get_exit_code(&self) -> Option { + self.exit_code.lock().unwrap().clone() + } +} diff --git a/lib/wasi/src/os/task/thread.rs b/lib/wasi/src/os/task/thread.rs index 34e3c700a39..a7d83cd06e9 100644 --- a/lib/wasi/src/os/task/thread.rs +++ b/lib/wasi/src/os/task/thread.rs @@ -9,6 +9,8 @@ use wasmer_wasi_types::{types::Signal, wasi::ExitCode}; use crate::os::task::process::{WasiProcessId, WasiProcessInner}; +use super::{control_plane::TaskCountGuard, task_join_handle::TaskJoinHandle}; + /// Represents the ID of a WASI thread #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct WasiThreadId(u32); @@ -73,69 +75,87 @@ pub struct ThreadStack { /// Represents a running thread which allows a joiner to /// wait for the thread to exit -#[derive(Debug, Clone)] +#[derive(Clone, Debug)] pub struct WasiThread { - pub(crate) is_main: bool, - pub(crate) pid: WasiProcessId, - pub(crate) id: WasiThreadId, - pub(super) finished: Arc, tokio::sync::broadcast::Sender<()>)>>, - pub(crate) signals: Arc, tokio::sync::broadcast::Sender<()>)>>, - pub(super) stack: Arc>, + state: Arc, +} + +#[derive(Debug)] +struct WasiThreadState { + is_main: bool, + pid: WasiProcessId, + id: WasiThreadId, + signals: Mutex<(Vec, tokio::sync::broadcast::Sender<()>)>, + stack: Mutex, + finished: Arc, + + // Registers the task termination with the ControlPlane on drop. + // Never accessed, since it's a drop guard. + _task_count_guard: TaskCountGuard, } static NO_MORE_BYTES: [u8; 0] = [0u8; 0]; impl WasiThread { + pub fn new( + pid: WasiProcessId, + id: WasiThreadId, + is_main: bool, + finished: Arc, + guard: TaskCountGuard, + ) -> Self { + Self { + state: Arc::new(WasiThreadState { + is_main, + pid, + id, + finished, + signals: Mutex::new((Vec::new(), tokio::sync::broadcast::channel(1).0)), + stack: Mutex::new(ThreadStack::default()), + _task_count_guard: guard, + }), + } + } + /// Returns the process ID pub fn pid(&self) -> WasiProcessId { - self.pid + self.state.pid } /// Returns the thread ID pub fn tid(&self) -> WasiThreadId { - self.id + self.state.id } /// Returns true if this thread is the main thread pub fn is_main(&self) -> bool { - self.is_main + self.state.is_main + } + + // TODO: this should be private, access should go through utility methods. + pub fn signals(&self) -> &Mutex<(Vec, tokio::sync::broadcast::Sender<()>)> { + &self.state.signals } /// Marks the thread as finished (which will cause anyone that /// joined on it to wake up) pub fn terminate(&self, exit_code: u32) { - let mut guard = self.finished.lock().unwrap(); - if guard.0.is_none() { - guard.0 = Some(exit_code); - } - let _ = guard.1.send(()); + self.state.finished.terminate(exit_code); } /// Waits until the thread is finished or the timeout is reached pub async fn join(&self) -> Option { - loop { - let mut rx = { - let finished = self.finished.lock().unwrap(); - if finished.0.is_some() { - return finished.0.clone(); - } - finished.1.subscribe() - }; - if rx.recv().await.is_err() { - return None; - } - } + self.state.finished.await_termination().await } /// Attempts to join on the thread pub fn try_join(&self) -> Option { - let guard = self.finished.lock().unwrap(); - guard.0.clone() + self.state.finished.get_exit_code() } /// Adds a signal for this thread to process pub fn signal(&self, signal: Signal) { - let mut guard = self.signals.lock().unwrap(); + let mut guard = self.state.signals.lock().unwrap(); if guard.0.contains(&signal) == false { guard.0.push(signal); } @@ -144,7 +164,7 @@ impl WasiThread { /// Returns all the signals that are waiting to be processed pub fn has_signal(&self, signals: &[Signal]) -> bool { - let guard = self.signals.lock().unwrap(); + let guard = self.state.signals.lock().unwrap(); for s in guard.0.iter() { if signals.contains(s) { return true; @@ -157,7 +177,7 @@ impl WasiThread { pub fn pop_signals_or_subscribe( &self, ) -> Result, tokio::sync::broadcast::Receiver<()>> { - let mut guard = self.signals.lock().unwrap(); + let mut guard = self.state.signals.lock().unwrap(); let mut ret = Vec::new(); std::mem::swap(&mut ret, &mut guard.0); match ret.is_empty() { @@ -176,7 +196,7 @@ impl WasiThread { store_data: &[u8], ) { // Lock the stack - let mut stack = self.stack.lock().unwrap(); + let mut stack = self.state.stack.lock().unwrap(); let mut pstack = stack.deref_mut(); loop { // First we validate if the stack is no longer valid @@ -208,14 +228,19 @@ impl WasiThread { let mut disown = Some(Box::new(new_stack)); if let Some(disown) = disown.as_ref() { if disown.snapshots.is_empty() == false { - tracing::trace!("wasi[{}]::stacks forgotten (memory_stack_before={}, memory_stack_after={})", self.pid, memory_stack_before, memory_stack_after); + tracing::trace!( + "wasi[{}]::stacks forgotten (memory_stack_before={}, memory_stack_after={})", + self.pid(), + memory_stack_before, + memory_stack_after + ); } } while let Some(disowned) = disown { for hash in disowned.snapshots.keys() { tracing::trace!( "wasi[{}]::stack has been forgotten (hash={})", - self.pid, + self.pid(), hash ); } @@ -256,7 +281,7 @@ impl WasiThread { pub fn get_snapshot(&self, hash: u128) -> Option<(BytesMut, Bytes, Bytes)> { let mut memory_stack = BytesMut::new(); - let stack = self.stack.lock().unwrap(); + let stack = self.state.stack.lock().unwrap(); let mut pstack = stack.deref(); loop { memory_stack.extend(pstack.memory_stack_corrected.iter()); @@ -278,11 +303,11 @@ impl WasiThread { // Copy the stacks from another thread pub fn copy_stack_from(&self, other: &WasiThread) { let mut stack = { - let stack_guard = other.stack.lock().unwrap(); + let stack_guard = other.state.stack.lock().unwrap(); stack_guard.clone() }; - let mut stack_guard = self.stack.lock().unwrap(); + let mut stack_guard = self.state.stack.lock().unwrap(); std::mem::swap(stack_guard.deref_mut(), &mut stack); } } diff --git a/lib/wasi/src/state/builder.rs b/lib/wasi/src/state/builder.rs index cd595af7ee8..b0879fa9a3e 100644 --- a/lib/wasi/src/state/builder.rs +++ b/lib/wasi/src/state/builder.rs @@ -16,7 +16,7 @@ use wasmer_vfs::{ArcFile, FsError, TmpFileSystem, VirtualFile}; use crate::{ bin_factory::ModuleCache, fs::{WasiFs, WasiFsRoot, WasiInodes}, - os::task::control_plane::WasiControlPlane, + os::task::control_plane::{ControlPlaneError, WasiControlPlane}, state::WasiState, syscalls::types::{__WASI_STDERR_FILENO, __WASI_STDIN_FILENO, __WASI_STDOUT_FILENO}, PluggableRuntimeImplementation, WasiEnv, WasiFunctionEnv, @@ -95,6 +95,8 @@ pub enum WasiStateCreationError { FileSystemError(FsError), #[error("wasi inherit error: `{0}`")] WasiInheritError(String), + #[error("control plain error")] + ControlPlane(#[from] ControlPlaneError), } fn validate_mapped_dir_alias(alias: &str) -> Result<(), WasiStateCreationError> { @@ -612,8 +614,8 @@ impl WasiStateBuilder { let state = Arc::new(self.build()?); let runtime = state.runtime.clone(); - let process = control_plane.new_process(); - let thread = process.new_thread(); + let process = control_plane.new_process()?; + let thread = process.new_thread()?; let env = WasiEnv::new_ext( state, diff --git a/lib/wasi/src/state/env.rs b/lib/wasi/src/state/env.rs index a03f2b685c6..fa5cec1ce86 100644 --- a/lib/wasi/src/state/env.rs +++ b/lib/wasi/src/state/env.rs @@ -22,6 +22,7 @@ use crate::{ os::{ command::builtins::cmd_wasmer::CmdWasmer, task::{ + control_plane::ControlPlaneError, process::{WasiProcess, WasiProcessId}, thread::{WasiThread, WasiThreadHandle, WasiThreadId}, }, @@ -225,9 +226,9 @@ unsafe impl Sync for WasiEnv {} impl WasiEnv { /// Forking the WasiState is used when either fork or vfork is called - pub fn fork(&self) -> (Self, WasiThreadHandle) { - let process = self.process.compute.new_process(); - let handle = process.new_thread(); + pub fn fork(&self) -> Result<(Self, WasiThreadHandle), ControlPlaneError> { + let process = self.process.compute.new_process()?; + let handle = process.new_thread()?; let thread = handle.as_thread(); thread.copy_stack_from(&self.thread); @@ -240,23 +241,21 @@ impl WasiEnv { bin_factory }; - ( - Self { - process, - thread, - vfork: None, - stack_base: self.stack_base, - stack_start: self.stack_start, - bin_factory, - state, - inner: None, - owned_handles: Vec::new(), - runtime: self.runtime.clone(), - tasks: self.tasks.clone(), - capabilities: self.capabilities.clone(), - }, - handle, - ) + let new_env = Self { + process, + thread, + vfork: None, + stack_base: self.stack_base, + stack_start: self.stack_start, + bin_factory, + state, + inner: None, + owned_handles: Vec::new(), + runtime: self.runtime.clone(), + tasks: self.tasks.clone(), + capabilities: self.capabilities.clone(), + }; + Ok((new_env, handle)) } pub fn pid(&self) -> WasiProcessId { diff --git a/lib/wasi/src/syscalls/mod.rs b/lib/wasi/src/syscalls/mod.rs index 290af16b365..be4a9f3ca03 100644 --- a/lib/wasi/src/syscalls/mod.rs +++ b/lib/wasi/src/syscalls/mod.rs @@ -259,7 +259,7 @@ where }; let mut signaler = { - let signals = env.thread.signals.lock().unwrap(); + let signals = env.thread.signals().lock().unwrap(); let signaler = signals.1.subscribe(); if signals.0.is_empty() == false { drop(signals); diff --git a/lib/wasi/src/syscalls/wasix/proc_fork.rs b/lib/wasi/src/syscalls/wasix/proc_fork.rs index ad8dca27706..d8c25db84f3 100644 --- a/lib/wasi/src/syscalls/wasix/proc_fork.rs +++ b/lib/wasi/src/syscalls/wasix/proc_fork.rs @@ -51,7 +51,18 @@ pub fn proc_fork( // and associate a new context but otherwise shares things like the // file system interface. The handle to the forked process is stored // in the parent process context - let (mut child_env, mut child_handle) = ctx.data().fork(); + let (mut child_env, mut child_handle) = match ctx.data().fork() { + Ok(p) => p, + Err(err) => { + debug!( + pid=%ctx.data().pid(), + tid=%ctx.data().tid(), + "could not fork process: {err}" + ); + // TODO: evaluate the appropriate error code, document it in the spec. + return Ok(Errno::Perm); + } + }; let child_pid = child_env.process.pid(); // We write a zero to the PID before we capture the stack diff --git a/lib/wasi/src/syscalls/wasix/proc_spawn.rs b/lib/wasi/src/syscalls/wasix/proc_spawn.rs index 935a87d87be..792d44fc34b 100644 --- a/lib/wasi/src/syscalls/wasix/proc_spawn.rs +++ b/lib/wasi/src/syscalls/wasix/proc_spawn.rs @@ -109,7 +109,13 @@ pub fn proc_spawn_internal( let new_store = ctx.data().runtime.new_store(); // Fork the current environment and set the new arguments - let (mut child_env, handle) = ctx.data().fork(); + let (mut child_env, handle) = match ctx.data().fork() { + Ok(x) => x, + Err(err) => { + // TODO: evaluate the appropriate error code, document it in the spec. + return Ok(Err(BusErrno::Denied)); + } + }; if let Some(args) = args { let mut child_state = env.state.fork(true); child_state.args = args; diff --git a/lib/wasi/src/syscalls/wasix/thread_spawn.rs b/lib/wasi/src/syscalls/wasix/thread_spawn.rs index 2b911681236..a151fb89f01 100644 --- a/lib/wasi/src/syscalls/wasix/thread_spawn.rs +++ b/lib/wasi/src/syscalls/wasix/thread_spawn.rs @@ -44,7 +44,23 @@ pub fn thread_spawn( let tasks = env.tasks.clone(); // Create the handle that represents this thread - let mut thread_handle = env.process.new_thread(); + let mut thread_handle = match env.process.new_thread() { + Ok(h) => h, + Err(err) => { + error!( + "wasi[{}:{}]::thread_spawn (reactor={:?}, thread_id={}, stack_base={}, caller_id={}) - failed to create thread handle: {}", + ctx.data().pid(), + ctx.data().tid(), + reactor, + ctx.data().thread.tid().raw(), + stack_base, + current_caller_id().raw(), + err + ); + // TODO: evaluate the appropriate error code, document it in the spec. + return Errno::Access; + } + }; let thread_id: Tid = thread_handle.id().into(); // We need a copy of the process memory and a packaged store in order to