diff --git a/lib/wasix/src/fs/mod.rs b/lib/wasix/src/fs/mod.rs index 671bcfd65b7..8838eafac63 100644 --- a/lib/wasix/src/fs/mod.rs +++ b/lib/wasix/src/fs/mod.rs @@ -460,31 +460,26 @@ impl WasiFs { /// Closes all the file handles. #[allow(clippy::await_holding_lock)] pub async fn close_all(&self) { - // TODO: this should close all uniquely owned files instead of just flushing. - - if let Ok(mut map) = self.fd_map.write() { - map.clear(); - } - - let to_close = { + let mut to_close = { if let Ok(map) = self.fd_map.read() { - map.keys().copied().collect::>() + map.keys().copied().collect::>() } else { - Vec::new() + HashSet::new() } }; + to_close.insert(__WASI_STDOUT_FILENO); + to_close.insert(__WASI_STDERR_FILENO); - let _ = tokio::join!( - // Make sure the STDOUT and STDERR are explicitely flushed - self.flush(__WASI_STDOUT_FILENO), - self.flush(__WASI_STDERR_FILENO), - async { - for fd in to_close { - self.flush(fd).await.ok(); - self.close_fd(fd).ok(); - } + let _ = tokio::join!(async { + for fd in to_close { + self.flush(fd).await.ok(); + self.close_fd(fd).ok(); } - ); + }); + + if let Ok(mut map) = self.fd_map.write() { + map.clear(); + } } /// Will conditionally union the binary file system with this one @@ -1566,36 +1561,33 @@ impl WasiFs { return Err(Errno::Access); } - let work = { + let file = { let guard = fd.inode.read(); match guard.deref() { Kind::File { handle: Some(file), .. - } => { - struct FlushPoller { - file: Arc>>, - } - impl Future for FlushPoller { - type Output = Result<(), Errno>; - fn poll( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll { - let mut file = self.file.write().unwrap(); - Pin::new(file.as_mut()) - .poll_flush(cx) - .map_err(|_| Errno::Io) - } - } - FlushPoller { file: file.clone() } - } + } => file.clone(), // TODO: verify this behavior Kind::Dir { .. } => return Err(Errno::Isdir), Kind::Buffer { .. } => return Ok(()), _ => return Err(Errno::Io), } }; - work.await? + drop(fd); + + struct FlushPoller { + file: Arc>>, + } + impl Future for FlushPoller { + type Output = Result<(), Errno>; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut file = self.file.write().unwrap(); + Pin::new(file.as_mut()) + .poll_flush(cx) + .map_err(|_| Errno::Io) + } + } + FlushPoller { file }.await?; } } Ok(())