diff --git a/crates/goose-cli/src/scenario_tests/mock_client.rs b/crates/goose-cli/src/scenario_tests/mock_client.rs index 006795ce9d94..38e74ca44b01 100644 --- a/crates/goose-cli/src/scenario_tests/mock_client.rs +++ b/crates/goose-cli/src/scenario_tests/mock_client.rs @@ -13,6 +13,7 @@ use rmcp::{ use serde_json::Value; use std::collections::HashMap; use tokio::sync::mpsc::{self, Receiver}; +use tokio_util::sync::CancellationToken; pub struct MockClient { tools: HashMap, @@ -43,6 +44,7 @@ impl McpClientTrait for MockClient { async fn list_resources( &self, _next_cursor: Option, + _cancel_token: CancellationToken, ) -> Result { Ok(ListResourcesResult { resources: vec![], @@ -54,11 +56,19 @@ impl McpClientTrait for MockClient { todo!() } - async fn read_resource(&self, _uri: &str) -> Result { + async fn read_resource( + &self, + _uri: &str, + _cancel_token: CancellationToken, + ) -> Result { Err(Error::UnexpectedResponse) } - async fn list_tools(&self, _: Option) -> Result { + async fn list_tools( + &self, + _: Option, + _cancel_token: CancellationToken, + ) -> Result { let rmcp_tools: Vec = self .tools .values() @@ -77,7 +87,12 @@ impl McpClientTrait for MockClient { }) } - async fn call_tool(&self, name: &str, arguments: Value) -> Result { + async fn call_tool( + &self, + name: &str, + arguments: Value, + _cancel_token: CancellationToken, + ) -> Result { if let Some(handler) = self.handlers.get(name) { match handler(&arguments) { Ok(content) => Ok(CallToolResult { @@ -91,14 +106,23 @@ impl McpClientTrait for MockClient { } } - async fn list_prompts(&self, _next_cursor: Option) -> Result { + async fn list_prompts( + &self, + _next_cursor: Option, + _cancel_token: CancellationToken, + ) -> Result { Ok(ListPromptsResult { prompts: vec![], next_cursor: None, }) } - async fn get_prompt(&self, _name: &str, _arguments: Value) -> Result { + async fn get_prompt( + &self, + _name: &str, + _arguments: Value, + _cancel_token: CancellationToken, + ) -> Result { Err(Error::UnexpectedResponse) } diff --git a/crates/goose-cli/src/scenario_tests/scenario_runner.rs b/crates/goose-cli/src/scenario_tests/scenario_runner.rs index 687d6e413a80..5ab1cd97a4df 100644 --- a/crates/goose-cli/src/scenario_tests/scenario_runner.rs +++ b/crates/goose-cli/src/scenario_tests/scenario_runner.rs @@ -12,6 +12,7 @@ use goose::providers::{create, testprovider::TestProvider}; use std::collections::{HashMap, HashSet}; use std::path::Path; use std::sync::Arc; +use tokio_util::sync::CancellationToken; pub const SCENARIO_TESTS_DIR: &str = "src/scenario_tests"; @@ -205,7 +206,10 @@ where let mut error = None; for message in &messages { - if let Err(e) = session.process_message(message.clone()).await { + if let Err(e) = session + .process_message(message.clone(), CancellationToken::default()) + .await + { error = Some(e.to_string()); break; } diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 106bc0a0744d..0cdda04bcf31 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -364,7 +364,12 @@ impl Session { } /// Process a single message and get the response - pub(crate) async fn process_message(&mut self, message: Message) -> Result<()> { + pub(crate) async fn process_message( + &mut self, + message: Message, + cancel_token: CancellationToken, + ) -> Result<()> { + let cancel_token = cancel_token.clone(); let message_text = message.as_concat_text(); self.push_message(message); @@ -405,7 +410,7 @@ impl Session { ); } - self.process_agent_response(false).await?; + self.process_agent_response(false, cancel_token).await?; Ok(()) } @@ -414,7 +419,8 @@ impl Session { // Process initial message if provided if let Some(prompt) = prompt { let msg = Message::user().with_text(&prompt); - self.process_message(msg).await?; + self.process_message(msg, CancellationToken::default()) + .await?; } // Initialize the completion cache @@ -514,7 +520,8 @@ impl Session { } output::show_thinking(); - self.process_agent_response(true).await?; + self.process_agent_response(true, CancellationToken::default()) + .await?; output::hide_thinking(); } RunMode::Plan => { @@ -814,7 +821,8 @@ impl Session { self.push_message(plan_message); // act on the plan output::show_thinking(); - self.process_agent_response(true).await?; + self.process_agent_response(true, CancellationToken::default()) + .await?; output::hide_thinking(); // Reset run & goose mode @@ -842,12 +850,15 @@ impl Session { /// Process a single message and exit pub async fn headless(&mut self, prompt: String) -> Result<()> { let message = Message::user().with_text(&prompt); - self.process_message(message).await + self.process_message(message, CancellationToken::default()) + .await } - async fn process_agent_response(&mut self, interactive: bool) -> Result<()> { - // Messages will be auto-compacted in agent.reply() if needed - let cancel_token = CancellationToken::new(); + async fn process_agent_response( + &mut self, + interactive: bool, + cancel_token: CancellationToken, + ) -> Result<()> { let cancel_token_clone = cancel_token.clone(); let session_config = self.session_file.as_ref().map(|s| { @@ -1511,7 +1522,8 @@ impl Session { if valid { output::show_thinking(); - self.process_agent_response(true).await?; + self.process_agent_response(true, CancellationToken::default()) + .await?; output::hide_thinking(); } } diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index 05f5eaa8af15..6a3b91ca8c42 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -99,19 +99,24 @@ enum MessageEvent { request_id: String, message: ServerNotification, }, + Ping, } async fn stream_event( event: MessageEvent, tx: &mpsc::Sender, -) -> Result<(), mpsc::error::SendError> { + cancel_token: &CancellationToken, +) { let json = serde_json::to_string(&event).unwrap_or_else(|e| { format!( r#"{{"type":"Error","error":"Failed to serialize event: {}"}}"#, e ) }); - tx.send(format!("data: {}\n\n", json)).await + if tx.send(format!("data: {}\n\n", json)).await.is_err() { + tracing::info!("client hung up"); + cancel_token.cancel(); + } } async fn reply_handler( @@ -144,6 +149,7 @@ async fn reply_handler( error: "No agent configured".to_string(), }, &task_tx, + &cancel_token, ) .await; return; @@ -173,11 +179,12 @@ async fn reply_handler( Ok(stream) => stream, Err(e) => { tracing::error!("Failed to start reply stream: {:?}", e); - let _ = stream_event( + stream_event( MessageEvent::Error { error: e.to_string(), }, &task_tx, + &cancel_token, ) .await; return; @@ -194,6 +201,7 @@ async fn reply_handler( error: format!("Failed to get session path: {}", e), }, &task_tx, + &cancel_token, ) .await; return; @@ -201,81 +209,61 @@ async fn reply_handler( }; let saved_message_count = all_messages.len(); + let mut heartbeat_interval = tokio::time::interval(Duration::from_millis(500)); loop { tokio::select! { - _ = task_cancel.cancelled() => { - tracing::info!("Agent task cancelled"); + _ = task_cancel.cancelled() => { + tracing::info!("Agent task cancelled"); + break; + } + _ = heartbeat_interval.tick() => { + stream_event(MessageEvent::Ping, &tx, &cancel_token).await; + } + response = timeout(Duration::from_millis(500), stream.next()) => { + match response { + Ok(Some(Ok(AgentEvent::Message(message)))) => { + push_message(&mut all_messages, message.clone()); + stream_event(MessageEvent::Message { message }, &tx, &cancel_token).await; + } + Ok(Some(Ok(AgentEvent::HistoryReplaced(new_messages)))) => { + // Replace the message history with the compacted messages + all_messages = new_messages; + // Note: We don't send this as a stream event since it's an internal operation + // The client will see the compaction notification message that was sent before this event + } + Ok(Some(Ok(AgentEvent::ModelChange { model, mode }))) => { + stream_event(MessageEvent::ModelChange { model, mode }, &tx, &cancel_token).await; + } + Ok(Some(Ok(AgentEvent::McpNotification((request_id, n))))) => { + stream_event(MessageEvent::Notification{ + request_id: request_id.clone(), + message: n, + }, &tx, &cancel_token).await; + } + + Ok(Some(Err(e))) => { + tracing::error!("Error processing message: {}", e); + stream_event( + MessageEvent::Error { + error: e.to_string(), + }, + &tx, + &cancel_token, + ).await; + break; + } + Ok(None) => { + break; + } + Err(_) => { + if tx.is_closed() { break; } - response = timeout(Duration::from_millis(500), stream.next()) => { - match response { - Ok(Some(Ok(AgentEvent::Message(message)))) => { - push_message(&mut all_messages, message.clone()); - if let Err(e) = stream_event(MessageEvent::Message { message }, &tx).await { - tracing::error!("Error sending message through channel: {}", e); - let _ = stream_event( - MessageEvent::Error { - error: e.to_string(), - }, - &tx, - ).await; - break; - } - } - Ok(Some(Ok(AgentEvent::HistoryReplaced(new_messages)))) => { - // Replace the message history with the compacted messages - all_messages = new_messages; - // Note: We don't send this as a stream event since it's an internal operation - // The client will see the compaction notification message that was sent before this event - } - Ok(Some(Ok(AgentEvent::ModelChange { model, mode }))) => { - if let Err(e) = stream_event(MessageEvent::ModelChange { model, mode }, &tx).await { - tracing::error!("Error sending model change through channel: {}", e); - let _ = stream_event( - MessageEvent::Error { - error: e.to_string(), - }, - &tx, - ).await; - } - } - Ok(Some(Ok(AgentEvent::McpNotification((request_id, n))))) => { - if let Err(e) = stream_event(MessageEvent::Notification{ - request_id: request_id.clone(), - message: n, - }, &tx).await { - tracing::error!("Error sending message through channel: {}", e); - let _ = stream_event( - MessageEvent::Error { - error: e.to_string(), - }, - &tx, - ).await; - } - } - - Ok(Some(Err(e))) => { - tracing::error!("Error processing message: {}", e); - let _ = stream_event( - MessageEvent::Error { - error: e.to_string(), - }, - &tx, - ).await; - break; - } - Ok(None) => { - break; - } - Err(_) => { - if tx.is_closed() { - break; - } - continue; - } - } - } + continue; } + } + } + } } if all_messages.len() > saved_message_count { @@ -301,6 +289,7 @@ async fn reply_handler( reason: "stop".to_string(), }, &task_tx, + &cancel_token, ) .await; })); diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index c106f43f6be1..ba9212a3b80c 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -439,13 +439,19 @@ impl Agent { // Check if the tool is read_resource and handle it separately ToolCallResult::from( extension_manager - .read_resource(tool_call.arguments.clone()) + .read_resource( + tool_call.arguments.clone(), + cancellation_token.unwrap_or_default(), + ) .await, ) } else if tool_call.name == PLATFORM_LIST_RESOURCES_TOOL_NAME { ToolCallResult::from( extension_manager - .list_resources(tool_call.arguments.clone()) + .list_resources( + tool_call.arguments.clone(), + cancellation_token.unwrap_or_default(), + ) .await, ) } else if tool_call.name == PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME { @@ -469,7 +475,7 @@ impl Agent { } else { // Clone the result to ensure no references to extension_manager are returned let result = extension_manager - .dispatch_tool_call(tool_call.clone()) + .dispatch_tool_call(tool_call.clone(), cancellation_token.unwrap_or_default()) .await; result.unwrap_or_else(|e| { ToolCallResult::from(Err(ToolError::ExecutionError(e.to_string()))) @@ -1161,7 +1167,7 @@ impl Agent { pub async fn list_extension_prompts(&self) -> HashMap> { let extension_manager = self.extension_manager.read().await; extension_manager - .list_prompts() + .list_prompts(CancellationToken::default()) .await .expect("Failed to list prompts") } @@ -1171,7 +1177,7 @@ impl Agent { // First find which extension has this prompt let prompts = extension_manager - .list_prompts() + .list_prompts(CancellationToken::default()) .await .map_err(|e| anyhow!("Failed to list prompts: {}", e))?; @@ -1181,7 +1187,7 @@ impl Agent { .map(|(extension, _)| extension) { return extension_manager - .get_prompt(extension, name, arguments) + .get_prompt(extension, name, arguments, CancellationToken::default()) .await .map_err(|e| anyhow!("Failed to get prompt: {}", e)); } diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs index b3769fef4572..cd0a76d7ccdf 100644 --- a/crates/goose/src/agents/extension_manager.rs +++ b/crates/goose/src/agents/extension_manager.rs @@ -1,6 +1,6 @@ use anyhow::Result; use axum::http::{HeaderMap, HeaderName}; -use chrono::{DateTime, TimeZone, Utc}; +use chrono::{DateTime, Utc}; use futures::stream::{FuturesUnordered, StreamExt}; use futures::{future, FutureExt}; use mcp_core::handler::require_str_parameter; @@ -13,7 +13,6 @@ use rmcp::transport::{ use std::collections::{HashMap, HashSet}; use std::process::Stdio; use std::sync::Arc; -use std::sync::LazyLock; use std::time::Duration; use tempfile::tempdir; use tokio::io::AsyncReadExt; @@ -21,6 +20,7 @@ use tokio::process::Command; use tokio::sync::Mutex; use tokio::task; use tokio_stream::wrappers::ReceiverStream; +use tokio_util::sync::CancellationToken; use tracing::{error, warn}; use super::extension::{ExtensionConfig, ExtensionError, ExtensionInfo, ExtensionResult, ToolInfo}; @@ -29,14 +29,9 @@ use crate::agents::extension::{Envs, ProcessExit}; use crate::config::{Config, ExtensionConfigManager}; use crate::prompt_template; use mcp_client::client::{McpClient, McpClientTrait}; -use rmcp::model::{Content, GetPromptResult, Prompt, Resource, ResourceContents, Tool}; +use rmcp::model::{Content, GetPromptResult, Prompt, ResourceContents, Tool}; use serde_json::Value; -// By default, we set it to Jan 1, 2020 if the resource does not have a timestamp -// This is to ensure that the resource is considered less important than resources with a more recent timestamp -static DEFAULT_TIMESTAMP: LazyLock> = - LazyLock::new(|| Utc.with_ymd_and_hms(2020, 1, 1, 0, 0, 0).unwrap()); - type McpClientBox = Arc>>; /// Manages Goose extensions / MCP clients and their interactions @@ -457,7 +452,9 @@ impl ExtensionManager { task::spawn(async move { let mut tools = Vec::new(); let client_guard = client.lock().await; - let mut client_tools = client_guard.list_tools(None).await?; + let mut client_tools = client_guard + .list_tools(None, CancellationToken::default()) + .await?; loop { for tool in client_tools.tools { @@ -474,7 +471,9 @@ impl ExtensionManager { break; } - client_tools = client_guard.list_tools(client_tools.next_cursor).await?; + client_tools = client_guard + .list_tools(client_tools.next_cursor, CancellationToken::default()) + .await?; } Ok::, ExtensionError>(tools) @@ -497,43 +496,6 @@ impl ExtensionManager { Ok(tools) } - /// Get client resources and their contents - pub async fn get_resources(&self) -> ExtensionResult> { - let mut result: Vec = Vec::new(); - - for (name, client) in &self.clients { - let client_guard = client.lock().await; - let resources = client_guard.list_resources(None).await?; - - for resource in resources.resources { - // Skip reading the resource if it's not marked active - // This avoids blowing up the context with inactive resources - if !resource_is_active(&resource) { - continue; - } - - if let Ok(contents) = client_guard.read_resource(&resource.uri).await { - for content in contents.contents { - let (uri, content_str) = match content { - ResourceContents::TextResourceContents { uri, text, .. } => (uri, text), - ResourceContents::BlobResourceContents { uri, blob, .. } => (uri, blob), - }; - - result.push(ResourceItem::new( - name.clone(), - uri, - resource.name.clone(), - content_str, - resource.timestamp().unwrap_or(*DEFAULT_TIMESTAMP), - resource.priority().unwrap_or(0.0), - )); - } - } - } - } - Ok(result) - } - /// Get the extension prompt including client instructions pub async fn get_planning_prompt(&self, tools_info: Vec) -> String { let mut context: HashMap<&str, Value> = HashMap::new(); @@ -551,14 +513,22 @@ impl ExtensionManager { } // Function that gets executed for read_resource tool - pub async fn read_resource(&self, params: Value) -> Result, ToolError> { + pub async fn read_resource( + &self, + params: Value, + cancellation_token: CancellationToken, + ) -> Result, ToolError> { let uri = require_str_parameter(¶ms, "uri")?; let extension_name = params.get("extension_name").and_then(|v| v.as_str()); // If extension name is provided, we can just look it up if extension_name.is_some() { let result = self - .read_resource_from_extension(uri, extension_name.unwrap()) + .read_resource_from_extension( + uri, + extension_name.unwrap(), + cancellation_token.clone(), + ) .await?; return Ok(result); } @@ -568,7 +538,9 @@ impl ExtensionManager { // TODO: do we want to find if a provided uri is in multiple extensions? // currently it will return the first match and skip any others for extension_name in self.resource_capable_extensions.iter() { - let result = self.read_resource_from_extension(uri, extension_name).await; + let result = self + .read_resource_from_extension(uri, extension_name, cancellation_token.clone()) + .await; match result { Ok(result) => return Ok(result), Err(_) => continue, @@ -594,6 +566,7 @@ impl ExtensionManager { &self, uri: &str, extension_name: &str, + cancellation_token: CancellationToken, ) -> Result, ToolError> { let available_extensions = self .clients @@ -612,9 +585,12 @@ impl ExtensionManager { .ok_or(ToolError::InvalidParameters(error_msg))?; let client_guard = client.lock().await; - let read_result = client_guard.read_resource(uri).await.map_err(|_| { - ToolError::ExecutionError(format!("Could not read resource with uri: {}", uri)) - })?; + let read_result = client_guard + .read_resource(uri, cancellation_token) + .await + .map_err(|_| { + ToolError::ExecutionError(format!("Could not read resource with uri: {}", uri)) + })?; let mut result = Vec::new(); for content in read_result.contents { @@ -631,6 +607,7 @@ impl ExtensionManager { async fn list_resources_from_extension( &self, extension_name: &str, + cancellation_token: CancellationToken, ) -> Result, ToolError> { let client = self.clients.get(extension_name).ok_or_else(|| { ToolError::InvalidParameters(format!("Extension {} is not valid", extension_name)) @@ -638,7 +615,7 @@ impl ExtensionManager { let client_guard = client.lock().await; client_guard - .list_resources(None) + .list_resources(None, cancellation_token) .await .map_err(|e| { ToolError::ExecutionError(format!( @@ -658,13 +635,18 @@ impl ExtensionManager { }) } - pub async fn list_resources(&self, params: Value) -> Result, ToolError> { + pub async fn list_resources( + &self, + params: Value, + cancellation_token: CancellationToken, + ) -> Result, ToolError> { let extension = params.get("extension").and_then(|v| v.as_str()); match extension { Some(extension_name) => { // Handle single extension case - self.list_resources_from_extension(extension_name).await + self.list_resources_from_extension(extension_name, cancellation_token) + .await } None => { // Handle all extensions case using FuturesUnordered @@ -672,8 +654,10 @@ impl ExtensionManager { // Create futures for each resource_capable_extension for extension_name in &self.resource_capable_extensions { + let token = cancellation_token.clone(); futures.push(async move { - self.list_resources_from_extension(extension_name).await + self.list_resources_from_extension(extension_name, token) + .await }); } @@ -708,7 +692,11 @@ impl ExtensionManager { } } - pub async fn dispatch_tool_call(&self, tool_call: ToolCall) -> Result { + pub async fn dispatch_tool_call( + &self, + tool_call: ToolCall, + cancellation_token: CancellationToken, + ) -> Result { // Dispatch tool call based on the prefix naming convention let (client_name, client) = self .get_client_for_tool(&tool_call.name) @@ -729,7 +717,7 @@ impl ExtensionManager { let fut = async move { let client_guard = client.lock().await; client_guard - .call_tool(&tool_name, arguments) + .call_tool(&tool_name, arguments, cancellation_token) .await .map(|call| call.content) .map_err(|e| ToolError::ExecutionError(e.to_string())) @@ -744,6 +732,7 @@ impl ExtensionManager { pub async fn list_prompts_from_extension( &self, extension_name: &str, + cancellation_token: CancellationToken, ) -> Result, ToolError> { let client = self.clients.get(extension_name).ok_or_else(|| { ToolError::InvalidParameters(format!("Extension {} is not valid", extension_name)) @@ -751,7 +740,7 @@ impl ExtensionManager { let client_guard = client.lock().await; client_guard - .list_prompts(None) + .list_prompts(None, cancellation_token) .await .map_err(|e| { ToolError::ExecutionError(format!( @@ -762,14 +751,19 @@ impl ExtensionManager { .map(|lp| lp.prompts) } - pub async fn list_prompts(&self) -> Result>, ToolError> { + pub async fn list_prompts( + &self, + cancellation_token: CancellationToken, + ) -> Result>, ToolError> { let mut futures = FuturesUnordered::new(); for extension_name in self.clients.keys() { + let token = cancellation_token.clone(); futures.push(async move { ( extension_name, - self.list_prompts_from_extension(extension_name).await, + self.list_prompts_from_extension(extension_name, token) + .await, ) }); } @@ -809,6 +803,7 @@ impl ExtensionManager { extension_name: &str, name: &str, arguments: Value, + cancellation_token: CancellationToken, ) -> Result { let client = self .clients @@ -817,7 +812,7 @@ impl ExtensionManager { let client_guard = client.lock().await; client_guard - .get_prompt(name, arguments) + .get_prompt(name, arguments, cancellation_token) .await .map_err(|e| anyhow::anyhow!("Failed to get prompt: {}", e)) } @@ -896,10 +891,6 @@ impl ExtensionManager { } } -fn resource_is_active(resource: &Resource) -> bool { - resource.priority().is_some_and(|p| (p - 1.0).abs() < 1e-6) -} - #[cfg(test)] mod tests { use super::*; @@ -927,19 +918,33 @@ mod tests { async fn list_resources( &self, _next_cursor: Option, + _cancellation_token: CancellationToken, ) -> Result { Err(Error::TransportClosed) } - async fn read_resource(&self, _uri: &str) -> Result { + async fn read_resource( + &self, + _uri: &str, + _cancellation_token: CancellationToken, + ) -> Result { Err(Error::TransportClosed) } - async fn list_tools(&self, _next_cursor: Option) -> Result { + async fn list_tools( + &self, + _next_cursor: Option, + _cancellation_token: CancellationToken, + ) -> Result { Err(Error::TransportClosed) } - async fn call_tool(&self, name: &str, _arguments: Value) -> Result { + async fn call_tool( + &self, + name: &str, + _arguments: Value, + _cancellation_token: CancellationToken, + ) -> Result { match name { "tool" | "test__tool" => Ok(CallToolResult { content: vec![], @@ -952,6 +957,7 @@ mod tests { async fn list_prompts( &self, _next_cursor: Option, + _cancellation_token: CancellationToken, ) -> Result { Err(Error::TransportClosed) } @@ -960,6 +966,7 @@ mod tests { &self, _name: &str, _arguments: Value, + _cancellation_token: CancellationToken, ) -> Result { Err(Error::TransportClosed) } @@ -1043,7 +1050,9 @@ mod tests { arguments: json!({}), }; - let result = extension_manager.dispatch_tool_call(tool_call).await; + let result = extension_manager + .dispatch_tool_call(tool_call, CancellationToken::default()) + .await; assert!(result.is_ok()); let tool_call = ToolCall { @@ -1051,7 +1060,9 @@ mod tests { arguments: json!({}), }; - let result = extension_manager.dispatch_tool_call(tool_call).await; + let result = extension_manager + .dispatch_tool_call(tool_call, CancellationToken::default()) + .await; assert!(result.is_ok()); // verify a multiple underscores dispatch @@ -1060,7 +1071,9 @@ mod tests { arguments: json!({}), }; - let result = extension_manager.dispatch_tool_call(tool_call).await; + let result = extension_manager + .dispatch_tool_call(tool_call, CancellationToken::default()) + .await; assert!(result.is_ok()); // Test unicode in tool name, "client 🚀" should become "client_" @@ -1069,7 +1082,9 @@ mod tests { arguments: json!({}), }; - let result = extension_manager.dispatch_tool_call(tool_call).await; + let result = extension_manager + .dispatch_tool_call(tool_call, CancellationToken::default()) + .await; assert!(result.is_ok()); let tool_call = ToolCall { @@ -1077,7 +1092,9 @@ mod tests { arguments: json!({}), }; - let result = extension_manager.dispatch_tool_call(tool_call).await; + let result = extension_manager + .dispatch_tool_call(tool_call, CancellationToken::default()) + .await; assert!(result.is_ok()); // this should error out, specifically for an ToolError::ExecutionError @@ -1087,7 +1104,7 @@ mod tests { }; let result = extension_manager - .dispatch_tool_call(invalid_tool_call) + .dispatch_tool_call(invalid_tool_call, CancellationToken::default()) .await .unwrap() .result @@ -1105,7 +1122,7 @@ mod tests { }; let result = extension_manager - .dispatch_tool_call(invalid_tool_call) + .dispatch_tool_call(invalid_tool_call, CancellationToken::default()) .await; if let Err(err) = result { let tool_err = err.downcast_ref::().expect("Expected ToolError"); diff --git a/crates/goose/src/agents/subagent.rs b/crates/goose/src/agents/subagent.rs index 939916e0afeb..701d4aac60c0 100644 --- a/crates/goose/src/agents/subagent.rs +++ b/crates/goose/src/agents/subagent.rs @@ -15,6 +15,7 @@ use serde::{Deserialize, Serialize}; // use serde_json::{self}; use std::{collections::HashMap, sync::Arc}; use tokio::sync::{Mutex, RwLock}; +use tokio_util::sync::CancellationToken; use tracing::{debug, error, instrument}; /// Status of a subagent @@ -197,7 +198,7 @@ impl SubAgent { .extension_manager .read() .await - .dispatch_tool_call(tool_call.clone()) + .dispatch_tool_call(tool_call.clone(), CancellationToken::default()) .await { Ok(result) => result.result.await, diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs index 78ff767ff5ba..812dd74d09a9 100644 --- a/crates/mcp-client/src/client.rs +++ b/crates/mcp-client/src/client.rs @@ -18,6 +18,7 @@ use tokio::sync::{ mpsc::{self, Sender}, Mutex, }; +use tokio_util::sync::CancellationToken; pub type BoxError = Box; @@ -28,17 +29,40 @@ pub trait McpClientTrait: Send + Sync { async fn list_resources( &self, next_cursor: Option, + cancel_token: CancellationToken, ) -> Result; - async fn read_resource(&self, uri: &str) -> Result; + async fn read_resource( + &self, + uri: &str, + cancel_token: CancellationToken, + ) -> Result; - async fn list_tools(&self, next_cursor: Option) -> Result; + async fn list_tools( + &self, + next_cursor: Option, + cancel_token: CancellationToken, + ) -> Result; - async fn call_tool(&self, name: &str, arguments: Value) -> Result; + async fn call_tool( + &self, + name: &str, + arguments: Value, + cancel_token: CancellationToken, + ) -> Result; - async fn list_prompts(&self, next_cursor: Option) -> Result; + async fn list_prompts( + &self, + next_cursor: Option, + cancel_token: CancellationToken, + ) -> Result; - async fn get_prompt(&self, name: &str, arguments: Value) -> Result; + async fn get_prompt( + &self, + name: &str, + arguments: Value, + cancel_token: CancellationToken, + ) -> Result; async fn subscribe(&self) -> mpsc::Receiver; @@ -143,10 +167,32 @@ impl McpClient { }) } - fn get_request_options(&self) -> PeerRequestOptions { - PeerRequestOptions { - timeout: Some(self.timeout), - meta: None, + async fn send_request( + &self, + request: ClientRequest, + cancel_token: CancellationToken, + ) -> Result { + let handle = self + .client + .lock() + .await + .send_request_with_option( + request, + PeerRequestOptions { + timeout: Some(self.timeout), + meta: None, + }, + ) + .await?; + + let cancel_token = cancel_token.clone(); + tokio::select! { + res = handle.await_response() => { + Ok(res?) + } + _ = cancel_token.cancelled() => { + Err(Error::Cancelled{reason: None}) + } } } } @@ -157,34 +203,35 @@ impl McpClientTrait for McpClient { self.server_info.as_ref() } - async fn list_resources(&self, cursor: Option) -> Result { + async fn list_resources( + &self, + cursor: Option, + cancel_token: CancellationToken, + ) -> Result { let res = self - .client - .lock() - .await - .send_request_with_option( + .send_request( ClientRequest::ListResourcesRequest(ListResourcesRequest { params: Some(PaginatedRequestParam { cursor }), method: Default::default(), extensions: Default::default(), }), - self.get_request_options(), + cancel_token, ) - .await? - .await_response() .await?; + match res { ServerResult::ListResourcesResult(result) => Ok(result), _ => Err(ServiceError::UnexpectedResponse), } } - async fn read_resource(&self, uri: &str) -> Result { + async fn read_resource( + &self, + uri: &str, + cancel_token: CancellationToken, + ) -> Result { let res = self - .client - .lock() - .await - .send_request_with_option( + .send_request( ClientRequest::ReadResourceRequest(ReadResourceRequest { params: ReadResourceRequestParam { uri: uri.to_string(), @@ -192,49 +239,50 @@ impl McpClientTrait for McpClient { method: Default::default(), extensions: Default::default(), }), - self.get_request_options(), + cancel_token, ) - .await? - .await_response() .await?; + match res { ServerResult::ReadResourceResult(result) => Ok(result), _ => Err(ServiceError::UnexpectedResponse), } } - async fn list_tools(&self, cursor: Option) -> Result { + async fn list_tools( + &self, + cursor: Option, + cancel_token: CancellationToken, + ) -> Result { let res = self - .client - .lock() - .await - .send_request_with_option( + .send_request( ClientRequest::ListToolsRequest(ListToolsRequest { params: Some(PaginatedRequestParam { cursor }), method: Default::default(), extensions: Default::default(), }), - self.get_request_options(), + cancel_token, ) - .await? - .await_response() .await?; + match res { ServerResult::ListToolsResult(result) => Ok(result), _ => Err(ServiceError::UnexpectedResponse), } } - async fn call_tool(&self, name: &str, arguments: Value) -> Result { + async fn call_tool( + &self, + name: &str, + arguments: Value, + cancel_token: CancellationToken, + ) -> Result { let arguments = match arguments { Value::Object(map) => Some(map), _ => None, }; let res = self - .client - .lock() - .await - .send_request_with_option( + .send_request( ClientRequest::CallToolRequest(CallToolRequest { params: CallToolRequestParam { name: name.to_string().into(), @@ -243,49 +291,50 @@ impl McpClientTrait for McpClient { method: Default::default(), extensions: Default::default(), }), - self.get_request_options(), + cancel_token, ) - .await? - .await_response() .await?; + match res { ServerResult::CallToolResult(result) => Ok(result), _ => Err(ServiceError::UnexpectedResponse), } } - async fn list_prompts(&self, cursor: Option) -> Result { + async fn list_prompts( + &self, + cursor: Option, + cancel_token: CancellationToken, + ) -> Result { let res = self - .client - .lock() - .await - .send_request_with_option( + .send_request( ClientRequest::ListPromptsRequest(ListPromptsRequest { params: Some(PaginatedRequestParam { cursor }), method: Default::default(), extensions: Default::default(), }), - self.get_request_options(), + cancel_token, ) - .await? - .await_response() .await?; + match res { ServerResult::ListPromptsResult(result) => Ok(result), _ => Err(ServiceError::UnexpectedResponse), } } - async fn get_prompt(&self, name: &str, arguments: Value) -> Result { + async fn get_prompt( + &self, + name: &str, + arguments: Value, + cancel_token: CancellationToken, + ) -> Result { let arguments = match arguments { Value::Object(map) => Some(map), _ => None, }; let res = self - .client - .lock() - .await - .send_request_with_option( + .send_request( ClientRequest::GetPromptRequest(GetPromptRequest { params: GetPromptRequestParam { name: name.to_string(), @@ -294,11 +343,10 @@ impl McpClientTrait for McpClient { method: Default::default(), extensions: Default::default(), }), - self.get_request_options(), + cancel_token, ) - .await? - .await_response() .await?; + match res { ServerResult::GetPromptResult(result) => Ok(result), _ => Err(ServiceError::UnexpectedResponse),