From 5f3159fa4fa59722891c3ba77ca8ea210c06609b Mon Sep 17 00:00:00 2001 From: jokemanfire Date: Mon, 11 Aug 2025 11:28:35 +0800 Subject: [PATCH] fix: make stdio shutdown more graceful According to the protocol specifications Signed-off-by: jokemanfire --- crates/rmcp/src/transport/child_process.rs | 179 +++++++++++++++++---- 1 file changed, 150 insertions(+), 29 deletions(-) diff --git a/crates/rmcp/src/transport/child_process.rs b/crates/rmcp/src/transport/child_process.rs index 2e7c034f..e384ad55 100644 --- a/crates/rmcp/src/transport/child_process.rs +++ b/crates/rmcp/src/transport/child_process.rs @@ -1,14 +1,16 @@ use std::process::Stdio; +use futures::future::Future; use process_wrap::tokio::{TokioChildWrapper, TokioCommandWrap}; use tokio::{ io::AsyncRead, process::{ChildStderr, ChildStdin, ChildStdout}, }; -use super::{IntoTransport, Transport}; -use crate::service::ServiceRole; +use super::{RxJsonRpcMessage, Transport, TxJsonRpcMessage, async_rw::AsyncRwTransport}; +use crate::RoleClient; +const MAX_WAIT_ON_DROP_SECS: u64 = 3; /// The parts of a child process. type ChildProcessParts = ( Box, @@ -36,18 +38,23 @@ fn child_process(mut child: Box) -> std::io::Result, } pub struct ChildWithCleanup { - inner: Box, + inner: Option>, } impl Drop for ChildWithCleanup { fn drop(&mut self) { - if let Err(e) = self.inner.start_kill() { - tracing::warn!("Failed to kill child process: {e}"); + // We should not use start_kill(), instead we should use kill() to avoid zombies + if let Some(mut inner) = self.inner.take() { + // We don't care about the result, just try to kill it + tokio::spawn(async move { + if let Err(e) = Box::into_pin(inner.kill()).await { + tracing::warn!("Error killing child process: {}", e); + } + }); } } } @@ -64,7 +71,7 @@ pin_project_lite::pin_project! { impl TokioChildProcessOut { /// Get the process ID of the child process. pub fn id(&self) -> Option { - self.child.inner.id() + self.child.inner.as_ref()?.id() } } @@ -92,23 +99,51 @@ impl TokioChildProcess { /// Get the process ID of the child process. pub fn id(&self) -> Option { - self.child.inner.id() + self.child.inner.as_ref()?.id() + } + + /// Gracefully shutdown the child process + /// + /// This will first wait for the child process to exit normally with a timeout. + /// If the child process doesn't exit within the timeout, it will be killed. + pub async fn graceful_shutdown(&mut self) -> std::io::Result<()> { + if let Some(mut child) = self.child.inner.take() { + let wait_fut = Box::into_pin(child.wait()); + tokio::select! { + _ = tokio::time::sleep(std::time::Duration::from_secs(MAX_WAIT_ON_DROP_SECS)) => { + if let Err(e) = Box::into_pin(child.kill()).await { + tracing::warn!("Error killing child: {e}"); + return Err(e); + } + }, + res = wait_fut => { + match res { + Ok(status) => { + tracing::info!("Child exited gracefully {}", status); + } + Err(e) => { + tracing::warn!("Error waiting for child: {e}"); + return Err(e); + } + } + } + } + } + Ok(()) + } + + /// Take ownership of the inner child process + pub fn into_inner(mut self) -> Option> { + self.child.inner.take() } /// Split this helper into a reader (stdout) and writer (stdin). + #[deprecated( + since = "0.5.0", + note = "use the Transport trait implementation instead" + )] pub fn split(self) -> (TokioChildProcessOut, ChildStdin) { - let TokioChildProcess { - child, - child_stdin, - child_stdout, - } = self; - ( - TokioChildProcessOut { - child, - child_stdout, - }, - child_stdin, - ) + unimplemented!("This method is deprecated, use the Transport trait implementation instead"); } } @@ -156,20 +191,31 @@ impl TokioChildProcessBuilder { let (child, stdout, stdin, stderr_opt) = child_process(self.cmd.spawn()?)?; + let transport = AsyncRwTransport::new(stdout, stdin); let proc = TokioChildProcess { - child: ChildWithCleanup { inner: child }, - child_stdin: stdin, - child_stdout: stdout, + child: ChildWithCleanup { inner: Some(child) }, + transport, }; Ok((proc, stderr_opt)) } } -impl IntoTransport for TokioChildProcess { - fn into_transport(self) -> impl Transport + 'static { - IntoTransport::::into_transport( - self.split(), - ) +impl Transport for TokioChildProcess { + type Error = std::io::Error; + + fn send( + &mut self, + item: TxJsonRpcMessage, + ) -> impl Future> + Send + 'static { + self.transport.send(item) + } + + fn receive(&mut self) -> impl Future>> + Send { + self.transport.receive() + } + + fn close(&mut self) -> impl Future> + Send { + self.graceful_shutdown() } } @@ -183,3 +229,78 @@ impl ConfigureCommandExt for tokio::process::Command { self } } + +#[cfg(unix)] +#[cfg(test)] +mod tests { + use tokio::process::Command; + + use super::*; + + #[tokio::test] + async fn test_tokio_child_process_drop() { + let r = TokioChildProcess::new(Command::new("sleep").configure(|cmd| { + cmd.arg("30"); + })); + assert!(r.is_ok()); + let child_process = r.unwrap(); + let id = child_process.id(); + assert!(id.is_some()); + let id = id.unwrap(); + // Drop the child process + drop(child_process); + // Wait a moment to allow the cleanup task to run + tokio::time::sleep(std::time::Duration::from_secs(MAX_WAIT_ON_DROP_SECS + 1)).await; + // Check if the process is still running + let status = Command::new("ps") + .arg("-p") + .arg(id.to_string()) + .status() + .await; + match status { + Ok(status) => { + assert!( + !status.success(), + "Process with PID {} is still running", + id + ); + } + Err(e) => { + panic!("Failed to check process status: {}", e); + } + } + } + + #[tokio::test] + async fn test_tokio_child_process_graceful_shutdown() { + let r = TokioChildProcess::new(Command::new("sleep").configure(|cmd| { + cmd.arg("30"); + })); + assert!(r.is_ok()); + let mut child_process = r.unwrap(); + let id = child_process.id(); + assert!(id.is_some()); + let id = id.unwrap(); + child_process.graceful_shutdown().await.unwrap(); + // Wait a moment to allow the cleanup task to run + tokio::time::sleep(std::time::Duration::from_secs(MAX_WAIT_ON_DROP_SECS + 1)).await; + // Check if the process is still running + let status = Command::new("ps") + .arg("-p") + .arg(id.to_string()) + .status() + .await; + match status { + Ok(status) => { + assert!( + !status.success(), + "Process with PID {} is still running", + id + ); + } + Err(e) => { + panic!("Failed to check process status: {}", e); + } + } + } +}