diff --git a/crates/rmcp/src/transport/async_rw.rs b/crates/rmcp/src/transport/async_rw.rs index 4c4cb4a3..3d561fe7 100644 --- a/crates/rmcp/src/transport/async_rw.rs +++ b/crates/rmcp/src/transport/async_rw.rs @@ -42,9 +42,11 @@ where } } +pub type TransportWriter = FramedWrite>>; + pub struct AsyncRwTransport { read: FramedRead>>, - write: Arc>>>>, + write: Arc>>>, } impl AsyncRwTransport @@ -57,10 +59,10 @@ where read, JsonRpcMessageCodec::>::default(), ); - let write = Arc::new(Mutex::new(FramedWrite::new( + let write = Arc::new(Mutex::new(Some(FramedWrite::new( write, JsonRpcMessageCodec::>::default(), - ))); + )))); Self { read, write } } } @@ -103,7 +105,14 @@ where let lock = self.write.clone(); async move { let mut write = lock.lock().await; - write.send(item).await.map_err(Into::into) + if let Some(ref mut write) = *write { + write.send(item).await.map_err(Into::into) + } else { + Err(std::io::Error::new( + std::io::ErrorKind::NotConnected, + "Transport is closed", + )) + } } } @@ -120,6 +129,8 @@ where } async fn close(&mut self) -> Result<(), Self::Error> { + let mut write = self.write.lock().await; + drop(write.take()); Ok(()) } } diff --git a/crates/rmcp/src/transport/child_process.rs b/crates/rmcp/src/transport/child_process.rs index e384ad55..d117c09d 100644 --- a/crates/rmcp/src/transport/child_process.rs +++ b/crates/rmcp/src/transport/child_process.rs @@ -104,10 +104,13 @@ impl TokioChildProcess { /// Gracefully shutdown the child process /// - /// This will first wait for the child process to exit normally with a timeout. + /// This will first close the transport to the child process (the server), + /// and wait for the child process to exit normally with a timeout. /// If the child process doesn't exit within the timeout, it will be killed. pub async fn graceful_shutdown(&mut self) -> std::io::Result<()> { if let Some(mut child) = self.child.inner.take() { + self.transport.close().await?; + let wait_fut = Box::into_pin(child.wait()); tokio::select! { _ = tokio::time::sleep(std::time::Duration::from_secs(MAX_WAIT_ON_DROP_SECS)) => {