diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 8dc45289c39a..461843f4447f 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -28,7 +28,7 @@ use crate::agents::subagent_execution_tool::tasks_manager::TasksManager; use crate::agents::tool_route_manager::ToolRouteManager; use crate::agents::tool_router_index_manager::ToolRouterIndexManager; use crate::agents::types::SessionConfig; -use crate::agents::types::{FrontendTool, ToolResultReceiver}; +use crate::agents::types::{FrontendTool, SharedProvider, ToolResultReceiver}; use crate::config::{get_enabled_extensions, Config}; use crate::context_mgmt::DEFAULT_COMPACTION_THRESHOLD; use crate::conversation::{debug_conversation_fix, fix_conversation, Conversation}; @@ -86,7 +86,8 @@ pub struct ToolCategorizeResult { /// The main goose Agent pub struct Agent { - pub(super) provider: Mutex>>, + pub(super) provider: SharedProvider, + pub extension_manager: Arc, pub(super) sub_recipe_manager: Mutex, pub(super) tasks_manager: TasksManager, @@ -159,10 +160,11 @@ impl Agent { // Create channels with buffer size 32 (adjust if needed) let (confirm_tx, confirm_rx) = mpsc::channel(32); let (tool_tx, tool_rx) = mpsc::channel(32); + let provider = Arc::new(Mutex::new(None)); Self { - provider: Mutex::new(None), - extension_manager: Arc::new(ExtensionManager::new()), + provider: provider.clone(), + extension_manager: Arc::new(ExtensionManager::new(provider.clone())), sub_recipe_manager: Mutex::new(SubRecipeManager::new()), tasks_manager: TasksManager::new(), final_output_tool: Arc::new(Mutex::new(None)), diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs index 7eaa2ea8bde2..fcea45250c88 100644 --- a/crates/goose/src/agents/extension_manager.rs +++ b/crates/goose/src/agents/extension_manager.rs @@ -12,6 +12,7 @@ use rmcp::transport::{ TokioChildProcess, }; use std::collections::HashMap; +use std::option::Option; use std::process::Stdio; use std::sync::Arc; use std::time::Duration; @@ -29,6 +30,7 @@ use super::extension::{ ToolInfo, PLATFORM_EXTENSIONS, }; use super::tool_execution::ToolCallResult; +use super::types::SharedProvider; use crate::agents::extension::{Envs, ProcessExit}; use crate::agents::extension_malware_check; use crate::agents::mcp_client::{McpClient, McpClientTrait}; @@ -91,6 +93,7 @@ impl Extension { pub struct ExtensionManager { extensions: Mutex>, context: Mutex, + provider: SharedProvider, } /// A flattened representation of a resource used by the agent to prepare inference @@ -171,13 +174,14 @@ pub fn get_parameter_names(tool: &Tool) -> Vec { impl Default for ExtensionManager { fn default() -> Self { - Self::new() + Self::new(Arc::new(Mutex::new(None))) } } async fn child_process_client( mut command: Command, timeout: &Option, + provider: SharedProvider, ) -> ExtensionResult { #[cfg(unix)] command.process_group(0); @@ -205,6 +209,7 @@ async fn child_process_client( let client_result = McpClient::connect( transport, Duration::from_secs(timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT)), + provider, ) .await; @@ -243,7 +248,7 @@ fn extract_auth_error( } impl ExtensionManager { - pub fn new() -> Self { + pub fn new(provider: SharedProvider) -> Self { Self { extensions: Mutex::new(HashMap::new()), context: Mutex::new(PlatformExtensionContext { @@ -251,9 +256,15 @@ impl ExtensionManager { extension_manager: None, tool_route_manager: None, }), + provider, } } + /// Create a new ExtensionManager with no provider (useful for tests) + pub fn new_without_provider() -> Self { + Self::new(Arc::new(Mutex::new(None))) + } + pub async fn set_context(&self, context: PlatformExtensionContext) { *self.context.lock().await = context; } @@ -348,6 +359,7 @@ impl ExtensionManager { Duration::from_secs( timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), ), + self.provider.clone(), ) .await?, ) @@ -388,6 +400,7 @@ impl ExtensionManager { Duration::from_secs( timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), ), + self.provider.clone(), ) .await; let client = if let Some(_auth_error) = extract_auth_error(&client_res) { @@ -407,6 +420,7 @@ impl ExtensionManager { Duration::from_secs( timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), ), + self.provider.clone(), ) .await? } else { @@ -430,7 +444,7 @@ impl ExtensionManager { // Check for malicious packages before launching the process extension_malware_check::deny_if_malicious_cmd_args(cmd, args).await?; - let client = child_process_client(command, timeout).await?; + let client = child_process_client(command, timeout, self.provider.clone()).await?; Box::new(client) } ExtensionConfig::Builtin { @@ -459,7 +473,7 @@ impl ExtensionManager { let command = Command::new(cmd).configure(|command| { command.arg("mcp").arg(name); }); - let client = child_process_client(command, timeout).await?; + let client = child_process_client(command, timeout, self.provider.clone()).await?; Box::new(client) } ExtensionConfig::Platform { name, .. } => { @@ -495,7 +509,7 @@ impl ExtensionManager { command.arg("python").arg(file_path.to_str().unwrap()); }); - let client = child_process_client(command, timeout).await?; + let client = child_process_client(command, timeout, self.provider.clone()).await?; Box::new(client) } @@ -1252,7 +1266,7 @@ mod tests { #[tokio::test] async fn test_get_client_for_tool() { - let extension_manager = ExtensionManager::new(); + let extension_manager = ExtensionManager::new_without_provider(); // Add some mock clients using the helper method extension_manager @@ -1312,7 +1326,7 @@ mod tests { async fn test_dispatch_tool_call() { // test that dispatch_tool_call parses out the sanitized name correctly, and extracts // tool_names - let extension_manager = ExtensionManager::new(); + let extension_manager = ExtensionManager::new_without_provider(); // Add some mock clients using the helper method extension_manager @@ -1429,7 +1443,7 @@ mod tests { #[tokio::test] async fn test_tool_availability_filtering() { - let extension_manager = ExtensionManager::new(); + let extension_manager = ExtensionManager::new_without_provider(); // Only "available_tool" should be available to the LLM let available_tools = vec!["available_tool".to_string()]; @@ -1457,7 +1471,7 @@ mod tests { #[tokio::test] async fn test_tool_availability_defaults_to_available() { - let extension_manager = ExtensionManager::new(); + let extension_manager = ExtensionManager::new_without_provider(); extension_manager .add_mock_extension_with_tools( @@ -1482,7 +1496,7 @@ mod tests { #[tokio::test] async fn test_dispatch_unavailable_tool_returns_error() { - let extension_manager = ExtensionManager::new(); + let extension_manager = ExtensionManager::new_without_provider(); let available_tools = vec!["available_tool".to_string()]; diff --git a/crates/goose/src/agents/mcp_client.rs b/crates/goose/src/agents/mcp_client.rs index a166b8f7bec2..88c017a2e402 100644 --- a/crates/goose/src/agents/mcp_client.rs +++ b/crates/goose/src/agents/mcp_client.rs @@ -1,21 +1,24 @@ -use rmcp::model::JsonObject; +use crate::agents::types::SharedProvider; +use rmcp::model::{Content, ErrorCode, JsonObject}; /// MCP client implementation for Goose use rmcp::{ model::{ CallToolRequest, CallToolRequestParam, CallToolResult, CancelledNotification, CancelledNotificationMethod, CancelledNotificationParam, ClientCapabilities, ClientInfo, - ClientRequest, GetPromptRequest, GetPromptRequestParam, GetPromptResult, Implementation, - InitializeResult, ListPromptsRequest, ListPromptsResult, ListResourcesRequest, - ListResourcesResult, ListToolsRequest, ListToolsResult, LoggingMessageNotification, + ClientRequest, CreateMessageRequestParam, CreateMessageResult, GetPromptRequest, + GetPromptRequestParam, GetPromptResult, Implementation, InitializeResult, + ListPromptsRequest, ListPromptsResult, ListResourcesRequest, ListResourcesResult, + ListToolsRequest, ListToolsResult, LoggingMessageNotification, LoggingMessageNotificationMethod, PaginatedRequestParam, ProgressNotification, ProgressNotificationMethod, ProtocolVersion, ReadResourceRequest, ReadResourceRequestParam, - ReadResourceResult, RequestId, ServerNotification, ServerResult, + ReadResourceResult, RequestId, Role, SamplingMessage, ServerNotification, ServerResult, }, service::{ - ClientInitializeError, PeerRequestOptions, RequestHandle, RunningService, ServiceRole, + ClientInitializeError, PeerRequestOptions, RequestContext, RequestHandle, RunningService, + ServiceRole, }, transport::IntoTransport, - ClientHandler, Peer, RoleClient, ServiceError, ServiceExt, + ClientHandler, ErrorData, Peer, RoleClient, ServiceError, ServiceExt, }; use serde_json::Value; use std::{sync::Arc, time::Duration}; @@ -76,12 +79,17 @@ pub trait McpClientTrait: Send + Sync { pub struct GooseClient { notification_handlers: Arc>>>, + provider: SharedProvider, } impl GooseClient { - pub fn new(handlers: Arc>>>) -> Self { + pub fn new( + handlers: Arc>>>, + provider: SharedProvider, + ) -> Self { GooseClient { notification_handlers: handlers, + provider, } } } @@ -127,10 +135,88 @@ impl ClientHandler for GooseClient { }); } + async fn create_message( + &self, + params: CreateMessageRequestParam, + _context: RequestContext, + ) -> Result { + let provider = self + .provider + .lock() + .await + .as_ref() + .ok_or(ErrorData::new( + ErrorCode::INTERNAL_ERROR, + "Could not use provider", + None, + ))? + .clone(); + + let provider_ready_messages: Vec = params + .messages + .iter() + .map(|msg| { + let base = match msg.role { + Role::User => crate::conversation::message::Message::user(), + Role::Assistant => crate::conversation::message::Message::assistant(), + }; + + match msg.content.as_text() { + Some(text) => base.with_text(&text.text), + None => base.with_content(msg.content.clone().into()), + } + }) + .collect(); + + let system_prompt = params + .system_prompt + .as_deref() + .unwrap_or("You are a general-purpose AI agent called goose"); + + let (response, usage) = provider + .complete(system_prompt, &provider_ready_messages, &[]) + .await + .map_err(|e| { + ErrorData::new( + ErrorCode::INTERNAL_ERROR, + "Unexpected error while completing the prompt", + Some(Value::from(e.to_string())), + ) + })?; + + Ok(CreateMessageResult { + model: usage.model, + stop_reason: Some(CreateMessageResult::STOP_REASON_END_TURN.to_string()), + message: SamplingMessage { + role: Role::Assistant, + // TODO(alexhancock): MCP sampling currently only supports one content on each SamplingMessage + // https://modelcontextprotocol.io/specification/draft/client/sampling#messages + // This doesn't mesh well with goose's approach which has Vec + // There is a proposal to MCP which is agreed to go in the next version to have SamplingMessages support multiple content parts + // https://github.com/modelcontextprotocol/modelcontextprotocol/pull/198 + // Until that is formalized, we can take the first message content from the provider and use it + content: if let Some(content) = response.content.first() { + match content { + crate::conversation::message::MessageContent::Text(text) => { + Content::text(&text.text) + } + crate::conversation::message::MessageContent::Image(img) => { + Content::image(&img.data, &img.mime_type) + } + // TODO(alexhancock) - Content::Audio? goose's messages don't currently have it + _ => Content::text(""), + } + } else { + Content::text("") + }, + }, + }) + } + fn get_info(&self) -> ClientInfo { ClientInfo { protocol_version: ProtocolVersion::V_2025_03_26, - capabilities: ClientCapabilities::builder().build(), + capabilities: ClientCapabilities::builder().enable_sampling().build(), client_info: Implementation { name: "goose".to_string(), version: std::env::var("GOOSE_MCP_CLIENT_VERSION") @@ -155,6 +241,7 @@ impl McpClient { pub async fn connect( transport: T, timeout: std::time::Duration, + provider: SharedProvider, ) -> Result where T: IntoTransport, @@ -163,7 +250,7 @@ impl McpClient { let notification_subscribers = Arc::new(Mutex::new(Vec::>::new())); - let client = GooseClient::new(notification_subscribers.clone()); + let client = GooseClient::new(notification_subscribers.clone(), provider); let client: rmcp::service::RunningService = client.serve(transport).await?; let server_info = client.peer_info().cloned(); diff --git a/crates/goose/src/agents/types.rs b/crates/goose/src/agents/types.rs index 0518c65789b3..027560179b2a 100644 --- a/crates/goose/src/agents/types.rs +++ b/crates/goose/src/agents/types.rs @@ -1,4 +1,5 @@ use crate::mcp_utils::ToolResult; +use crate::providers::base::Provider; use rmcp::model::{Content, Tool}; use serde::{Deserialize, Serialize}; use std::path::PathBuf; @@ -9,6 +10,9 @@ use utoipa::ToSchema; /// Type alias for the tool result channel receiver pub type ToolResultReceiver = Arc>)>>>; +// We use double Arc here to allow easy provider swaps while sharing concurrent access +pub type SharedProvider = Arc>>>; + /// Default timeout for retry operations (5 minutes) pub const DEFAULT_RETRY_TIMEOUT_SECONDS: u64 = 300; diff --git a/crates/goose/tests/mcp_integration_test.rs b/crates/goose/tests/mcp_integration_test.rs index 9182735a5863..6704119e638b 100644 --- a/crates/goose/tests/mcp_integration_test.rs +++ b/crates/goose/tests/mcp_integration_test.rs @@ -205,8 +205,7 @@ async fn test_replayed_session( bundled: Some(false), available_tools: vec![], }; - - let extension_manager = ExtensionManager::new(); + let extension_manager = ExtensionManager::new_without_provider(); #[allow(clippy::redundant_closure_call)] let result = (async || -> Result<(), Box> { diff --git a/crates/goose/tests/mcp_replays/cargorun--quiet-pgoose-server--bingoosed--mcpdeveloper b/crates/goose/tests/mcp_replays/cargorun--quiet-pgoose-server--bingoosed--mcpdeveloper index b8b5c12a217b..c5b490b065f4 100644 --- a/crates/goose/tests/mcp_replays/cargorun--quiet-pgoose-server--bingoosed--mcpdeveloper +++ b/crates/goose/tests/mcp_replays/cargorun--quiet-pgoose-server--bingoosed--mcpdeveloper @@ -1,4 +1,4 @@ -STDIN: {"jsonrpc":"2.0","id":0,"method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{},"clientInfo":{"name":"goose","version":"0.0.0"}}} +STDIN: {"jsonrpc":"2.0","id":0,"method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{"sampling":{}},"clientInfo":{"name":"goose","version":"0.0.0"}}} STDERR: 2025-09-27T04:13:30.409389Z  INFO goose_mcp::mcp_server_runner: Starting MCP server STDERR: at crates/goose-mcp/src/mcp_server_runner.rs:18 STDERR: diff --git a/crates/goose/tests/mcp_replays/github-mcp-serverstdio b/crates/goose/tests/mcp_replays/github-mcp-serverstdio index e7f70639287c..f38a2473b14a 100644 --- a/crates/goose/tests/mcp_replays/github-mcp-serverstdio +++ b/crates/goose/tests/mcp_replays/github-mcp-serverstdio @@ -1,4 +1,4 @@ -STDIN: {"jsonrpc":"2.0","id":0,"method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{},"clientInfo":{"name":"goose","version":"0.0.0"}}} +STDIN: {"jsonrpc":"2.0","id":0,"method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{"sampling":{}},"clientInfo":{"name":"goose","version":"0.0.0"}}} STDERR: GitHub MCP Server running on stdio STDOUT: {"jsonrpc":"2.0","id":0,"result":{"protocolVersion":"2025-03-26","capabilities":{"logging":{},"prompts":{},"resources":{"subscribe":true,"listChanged":true},"tools":{"listChanged":true}},"serverInfo":{"name":"github-mcp-server","version":"version"}}} STDIN: {"jsonrpc":"2.0","method":"notifications/initialized"} diff --git a/crates/goose/tests/mcp_replays/npx-y@modelcontextprotocol_server-everything b/crates/goose/tests/mcp_replays/npx-y@modelcontextprotocol_server-everything index 4fbb74482a5d..8ac4f0843488 100644 --- a/crates/goose/tests/mcp_replays/npx-y@modelcontextprotocol_server-everything +++ b/crates/goose/tests/mcp_replays/npx-y@modelcontextprotocol_server-everything @@ -1,4 +1,4 @@ -STDIN: {"jsonrpc":"2.0","id":0,"method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{},"clientInfo":{"name":"goose","version":"0.0.0"}}} +STDIN: {"jsonrpc":"2.0","id":0,"method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{"sampling":{}},"clientInfo":{"name":"goose","version":"0.0.0"}}} STDERR: 2025-09-26 23:13:04 - Starting npx setup script. STDERR: 2025-09-26 23:13:04 - Creating directory ~/.config/goose/mcp-hermit/bin if it does not exist. STDERR: 2025-09-26 23:13:04 - Changing to directory ~/.config/goose/mcp-hermit. diff --git a/crates/goose/tests/mcp_replays/uvxmcp-server-fetch b/crates/goose/tests/mcp_replays/uvxmcp-server-fetch index 7362e657a902..a473f29fa8eb 100644 --- a/crates/goose/tests/mcp_replays/uvxmcp-server-fetch +++ b/crates/goose/tests/mcp_replays/uvxmcp-server-fetch @@ -1,4 +1,4 @@ -STDIN: {"jsonrpc":"2.0","id":0,"method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{},"clientInfo":{"name":"goose","version":"0.0.0"}}} +STDIN: {"jsonrpc":"2.0","id":0,"method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{"sampling":{}},"clientInfo":{"name":"goose","version":"0.0.0"}}} STDERR: 2025-09-26 23:13:04 - Starting uvx setup script. STDERR: 2025-09-26 23:13:04 - Creating directory ~/.config/goose/mcp-hermit/bin if it does not exist. STDERR: 2025-09-26 23:13:04 - Changing to directory ~/.config/goose/mcp-hermit.