diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs index 812dd74d09a9..8e1f04204349 100644 --- a/crates/mcp-client/src/client.rs +++ b/crates/mcp-client/src/client.rs @@ -1,19 +1,22 @@ use rmcp::{ model::{ - CallToolRequest, CallToolRequestParam, CallToolResult, ClientCapabilities, ClientInfo, + CallToolRequest, CallToolRequestParam, CallToolResult, CancelledNotification, + CancelledNotificationMethod, CancelledNotificationParam, ClientCapabilities, ClientInfo, ClientRequest, GetPromptRequest, GetPromptRequestParam, GetPromptResult, Implementation, InitializeResult, ListPromptsRequest, ListPromptsResult, ListResourcesRequest, ListResourcesResult, ListToolsRequest, ListToolsResult, LoggingMessageNotification, LoggingMessageNotificationMethod, PaginatedRequestParam, ProgressNotification, ProgressNotificationMethod, ProtocolVersion, ReadResourceRequest, ReadResourceRequestParam, - ReadResourceResult, ServerNotification, ServerResult, + ReadResourceResult, RequestId, ServerNotification, ServerResult, + }, + service::{ + ClientInitializeError, PeerRequestOptions, RequestHandle, RunningService, ServiceRole, }, - service::{ClientInitializeError, PeerRequestOptions, RunningService}, transport::IntoTransport, - ClientHandler, RoleClient, ServiceError, ServiceExt, + ClientHandler, Peer, RoleClient, ServiceError, ServiceExt, }; use serde_json::Value; -use std::sync::Arc; +use std::{sync::Arc, time::Duration}; use tokio::sync::{ mpsc::{self, Sender}, Mutex, @@ -176,27 +179,52 @@ impl McpClient { .client .lock() .await - .send_request_with_option( - request, - PeerRequestOptions { - timeout: Some(self.timeout), - meta: None, - }, - ) + .send_cancellable_request(request, PeerRequestOptions::no_options()) .await?; - let cancel_token = cancel_token.clone(); - tokio::select! { - res = handle.await_response() => { - Ok(res?) - } - _ = cancel_token.cancelled() => { - Err(Error::Cancelled{reason: None}) - } + await_response(handle, self.timeout, &cancel_token).await + } +} + +async fn await_response( + handle: RequestHandle, + timeout: Duration, + cancel_token: &CancellationToken, +) -> Result<::PeerResp, ServiceError> { + let receiver = handle.rx; + let peer = handle.peer; + let request_id = handle.id; + tokio::select! { + result = receiver => { + result.map_err(|_e| ServiceError::TransportClosed)? + } + _ = tokio::time::sleep(timeout) => { + send_cancel_message(&peer, request_id, Some("timed out".to_owned())).await?; + Err(ServiceError::Timeout{timeout}) + } + _ = cancel_token.cancelled() => { + send_cancel_message(&peer, request_id, Some("operation cancelled".to_owned())).await?; + Err(ServiceError::Cancelled { reason: None }) } } } +async fn send_cancel_message( + peer: &Peer, + request_id: RequestId, + reason: Option, +) -> Result<(), ServiceError> { + peer.send_notification( + CancelledNotification { + params: CancelledNotificationParam { request_id, reason }, + method: CancelledNotificationMethod, + extensions: Default::default(), + } + .into(), + ) + .await +} + #[async_trait::async_trait] impl McpClientTrait for McpClient { fn get_info(&self) -> Option<&InitializeResult> {