From f942684fadb397b2121b1809965d41cd0894417b Mon Sep 17 00:00:00 2001 From: Jack Amadeo Date: Wed, 20 Aug 2025 16:40:34 -0400 Subject: [PATCH 1/4] Unlock the extension manager --- crates/goose/src/agents/agent.rs | 79 +++-- crates/goose/src/agents/extension_manager.rs | 284 +++++++++++------- crates/goose/src/agents/reply_parts.rs | 7 +- crates/goose/src/agents/subagent.rs | 2 +- crates/goose/src/agents/tool_route_manager.rs | 10 +- .../src/agents/tool_router_index_manager.rs | 2 +- 6 files changed, 212 insertions(+), 172 deletions(-) diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 256779095d1a..50f02c70fa52 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -49,7 +49,7 @@ use rmcp::model::{ Content, ErrorCode, ErrorData, GetPromptResult, Prompt, ServerNotification, Tool, }; use serde_json::Value; -use tokio::sync::{mpsc, Mutex, RwLock}; +use tokio::sync::{mpsc, Mutex}; use tokio_util::sync::CancellationToken; use tracing::{debug, error, info, instrument}; @@ -88,7 +88,7 @@ pub struct ToolCategorizeResult { /// The main goose Agent pub struct Agent { pub(super) provider: Mutex>>, - pub extension_manager: Arc>, + pub extension_manager: ExtensionManager, pub(super) sub_recipe_manager: Mutex, pub(super) tasks_manager: TasksManager, pub(super) final_output_tool: Arc>>, @@ -174,7 +174,7 @@ impl Agent { Self { provider: Mutex::new(None), - extension_manager: Arc::new(RwLock::new(ExtensionManager::new())), + extension_manager: ExtensionManager::new(), sub_recipe_manager: Mutex::new(SubRecipeManager::new()), tasks_manager: TasksManager::new(), final_output_tool: Arc::new(Mutex::new(None)), @@ -439,7 +439,6 @@ impl Agent { }; } - let extension_manager = self.extension_manager.read().await; let sub_recipe_manager = self.sub_recipe_manager.lock().await; let result: ToolCallResult = if sub_recipe_manager.is_sub_recipe_tool(&tool_call.name) { sub_recipe_manager @@ -465,7 +464,7 @@ impl Agent { } else if tool_call.name == PLATFORM_READ_RESOURCE_TOOL_NAME { // Check if the tool is read_resource and handle it separately ToolCallResult::from( - extension_manager + self.extension_manager .read_resource( tool_call.arguments.clone(), cancellation_token.unwrap_or_default(), @@ -474,7 +473,7 @@ impl Agent { ) } else if tool_call.name == PLATFORM_LIST_RESOURCES_TOOL_NAME { ToolCallResult::from( - extension_manager + self.extension_manager .list_resources( tool_call.arguments.clone(), cancellation_token.unwrap_or_default(), @@ -482,7 +481,7 @@ impl Agent { .await, ) } else if tool_call.name == PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME { - ToolCallResult::from(extension_manager.search_available_extensions().await) + ToolCallResult::from(self.extension_manager.search_available_extensions().await) } else if self.is_frontend_tool(&tool_call.name).await { // For frontend tools, return an error indicating we need frontend execution ToolCallResult::from(Err(ErrorData::new( @@ -542,7 +541,8 @@ impl Agent { } } else { // Clone the result to ensure no references to extension_manager are returned - let result = extension_manager + let result = self + .extension_manager .dispatch_tool_call(tool_call.clone(), cancellation_token.unwrap_or_default()) .await; result.unwrap_or_else(|e| { @@ -578,11 +578,10 @@ impl Agent { let selector = self.tool_route_manager.get_router_tool_selector().await; if let Some(selector) = selector { let selector_action = if action == "disable" { "remove" } else { "add" }; - let extension_manager = self.extension_manager.read().await; let selector = Arc::new(selector); if let Err(e) = ToolRouterIndexManager::update_extension_tools( &selector, - &extension_manager, + &self.extension_manager, &extension_name, selector_action, ) @@ -599,9 +598,9 @@ impl Agent { } } } - let mut extension_manager = self.extension_manager.write().await; if action == "disable" { - let result = extension_manager + let result = self + .extension_manager .remove_extension(&extension_name) .await .map(|_| { @@ -640,7 +639,8 @@ impl Agent { ) } }; - let result = extension_manager + let result = self + .extension_manager .add_extension(config) .await .map(|_| { @@ -651,17 +651,15 @@ impl Agent { }) .map_err(|e| ErrorData::new(ErrorCode::INTERNAL_ERROR, e.to_string(), None)); - drop(extension_manager); // Update LLM index if operation was successful and LLM routing is functional if result.is_ok() && self.tool_route_manager.is_router_functional().await { let selector = self.tool_route_manager.get_router_tool_selector().await; if let Some(selector) = selector { let llm_action = if action == "disable" { "remove" } else { "add" }; - let extension_manager = self.extension_manager.read().await; let selector = Arc::new(selector); if let Err(e) = ToolRouterIndexManager::update_extension_tools( &selector, - &extension_manager, + &self.extension_manager, &extension_name, llm_action, ) @@ -711,8 +709,9 @@ impl Agent { } } _ => { - let mut extension_manager = self.extension_manager.write().await; - extension_manager.add_extension(extension.clone()).await?; + self.extension_manager + .add_extension(extension.clone()) + .await?; } } @@ -720,11 +719,10 @@ impl Agent { if self.tool_route_manager.is_router_functional().await { let selector = self.tool_route_manager.get_router_tool_selector().await; if let Some(selector) = selector { - let extension_manager = self.extension_manager.read().await; let selector = Arc::new(selector); if let Err(e) = ToolRouterIndexManager::update_extension_tools( &selector, - &extension_manager, + &self.extension_manager, &extension.name(), "add", ) @@ -743,8 +741,8 @@ impl Agent { } pub async fn list_tools(&self, extension_name: Option) -> Vec { - let extension_manager = self.extension_manager.read().await; - let mut prefixed_tools = extension_manager + let mut prefixed_tools = self + .extension_manager .get_prefixed_tools(extension_name.clone()) .await .unwrap_or_default(); @@ -765,7 +763,7 @@ impl Agent { prefixed_tools.push(create_dynamic_task_tool()); // Add resource tools if supported - if extension_manager.supports_resources() { + if self.extension_manager.supports_resources().await { prefixed_tools.extend([ platform_tools::read_resource_tool(), platform_tools::list_resources_tool(), @@ -793,18 +791,15 @@ impl Agent { } pub async fn remove_extension(&self, name: &str) -> Result<()> { - let mut extension_manager = self.extension_manager.write().await; - extension_manager.remove_extension(name).await?; - drop(extension_manager); + self.extension_manager.remove_extension(name).await?; // If LLM tool selection is functional, remove tools from the index if self.tool_route_manager.is_router_functional().await { let selector = self.tool_route_manager.get_router_tool_selector().await; if let Some(selector) = selector { - let extension_manager = self.extension_manager.read().await; ToolRouterIndexManager::update_extension_tools( &selector, - &extension_manager, + &self.extension_manager, name, "remove", ) @@ -816,8 +811,7 @@ impl Agent { } pub async fn list_extensions(&self) -> Vec { - let extension_manager = self.extension_manager.read().await; - extension_manager + self.extension_manager .list_extensions() .await .expect("Failed to list extensions") @@ -1272,18 +1266,16 @@ impl Agent { } pub async fn list_extension_prompts(&self) -> HashMap> { - let extension_manager = self.extension_manager.read().await; - extension_manager + self.extension_manager .list_prompts(CancellationToken::default()) .await .expect("Failed to list prompts") } pub async fn get_prompt(&self, name: &str, arguments: Value) -> Result { - let extension_manager = self.extension_manager.read().await; - // First find which extension has this prompt - let prompts = extension_manager + let prompts = self + .extension_manager .list_prompts(CancellationToken::default()) .await .map_err(|e| anyhow!("Failed to list prompts: {}", e))?; @@ -1293,7 +1285,8 @@ impl Agent { .find(|(_, prompt_list)| prompt_list.iter().any(|p| p.name == name)) .map(|(extension, _)| extension) { - return extension_manager + return self + .extension_manager .get_prompt(extension, name, arguments, CancellationToken::default()) .await .map_err(|e| anyhow!("Failed to get prompt: {}", e)); @@ -1303,8 +1296,7 @@ impl Agent { } pub async fn get_plan_prompt(&self) -> Result { - let extension_manager = self.extension_manager.read().await; - let tools = extension_manager.get_prefixed_tools(None).await?; + let tools = self.extension_manager.get_prefixed_tools(None).await?; let tools_info = tools .into_iter() .map(|tool| { @@ -1320,7 +1312,7 @@ impl Agent { }) .collect(); - let plan_prompt = extension_manager.get_planning_prompt(tools_info).await; + let plan_prompt = self.extension_manager.get_planning_prompt(tools_info).await; Ok(plan_prompt) } @@ -1332,8 +1324,7 @@ impl Agent { } pub async fn create_recipe(&self, mut messages: Conversation) -> Result { - let extension_manager = self.extension_manager.read().await; - let extensions_info = extension_manager.get_extensions_info().await; + let extensions_info = self.extension_manager.get_extensions_info().await; // Get model name from provider let provider = self.provider().await?; @@ -1344,13 +1335,15 @@ impl Agent { let system_prompt = prompt_manager.build_system_prompt( extensions_info, self.frontend_instructions.lock().await.clone(), - extension_manager.suggest_disable_extensions_prompt().await, + self.extension_manager + .suggest_disable_extensions_prompt() + .await, Some(model_name), false, ); let recipe_prompt = prompt_manager.get_recipe_prompt().await; - let tools = extension_manager.get_prefixed_tools(None).await?; + let tools = self.extension_manager.get_prefixed_tools(None).await?; messages.push(Message::user().with_text(recipe_prompt)); diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs index 470c0aa8bd97..0c3368400904 100644 --- a/crates/goose/src/agents/extension_manager.rs +++ b/crates/goose/src/agents/extension_manager.rs @@ -10,7 +10,7 @@ use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig use rmcp::transport::{ ConfigureCommandExt, SseClientTransport, StreamableHttpClientTransport, TokioChildProcess, }; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::process::Stdio; use std::sync::Arc; use std::time::Duration; @@ -30,19 +30,58 @@ use crate::config::{Config, ExtensionConfigManager}; use crate::oauth::oauth_flow; use crate::prompt_template; use mcp_client::client::{McpClient, McpClientTrait}; -use rmcp::model::{Content, ErrorCode, ErrorData, GetPromptResult, Prompt, ResourceContents, Tool}; +use rmcp::model::{ + Content, ErrorCode, ErrorData, GetPromptResult, Prompt, ResourceContents, ServerInfo, Tool, +}; use rmcp::transport::auth::AuthClient; use serde_json::Value; type McpClientBox = Arc>>; +struct Extension { + pub config: ExtensionConfig, + + client: McpClientBox, + server_info: Option, + _temp_dir: Option, +} + +impl Extension { + fn new( + config: ExtensionConfig, + client: McpClientBox, + server_info: Option, + temp_dir: Option, + ) -> Self { + Self { + client, + config, + server_info, + _temp_dir: temp_dir, + } + } + + fn supports_resources(&self) -> bool { + self.server_info + .as_ref() + .and_then(|info| info.capabilities.resources.as_ref()) + .is_some() + } + + fn get_instructions(&self) -> Option { + self.server_info + .as_ref() + .and_then(|info| info.instructions.clone()) + } + + fn get_client(&self) -> McpClientBox { + self.client.clone() + } +} + /// Manages Goose extensions / MCP clients and their interactions pub struct ExtensionManager { - clients: HashMap, - instructions: HashMap, - resource_capable_extensions: HashSet, - temp_dirs: HashMap, - extension_configs: HashMap, + extensions: Mutex>, } /// A flattened representation of a resource used by the agent to prepare inference @@ -149,26 +188,24 @@ async fn child_process_client( } impl ExtensionManager { - /// Create a new ExtensionManager instance pub fn new() -> Self { Self { - clients: HashMap::new(), - instructions: HashMap::new(), - resource_capable_extensions: HashSet::new(), - temp_dirs: HashMap::new(), - extension_configs: HashMap::new(), + extensions: Mutex::new(HashMap::new()), } } - pub fn supports_resources(&self) -> bool { - !self.resource_capable_extensions.is_empty() + pub async fn supports_resources(&self) -> bool { + self.extensions + .lock() + .await + .values() + .any(|ext| ext.supports_resources()) } - /// Add a new MCP extension based on the provided client type - // TODO IMPORTANT need to ensure this times out if the extension command is broken! - pub async fn add_extension(&mut self, config: ExtensionConfig) -> ExtensionResult<()> { + pub async fn add_extension(&self, config: ExtensionConfig) -> ExtensionResult<()> { let config_name = config.key().to_string(); let sanitized_name = normalize(config_name.clone()); + let mut temp_dir = None; /// Helper function to merge environment variables from direct envs and keychain-stored env_keys async fn merge_environments( @@ -355,8 +392,9 @@ impl ExtensionManager { dependencies, .. } => { - let temp_dir = tempdir()?; - let file_path = temp_dir.path().join(format!("{}.py", name)); + let dir = tempdir()?; + let file_path = dir.path().join(format!("{}.py", name)); + temp_dir = Some(dir); std::fs::write(&file_path, code)?; let command = Command::new("uvx").configure(|command| { @@ -370,61 +408,46 @@ impl ExtensionManager { }); let client = child_process_client(command, timeout).await?; - self.temp_dirs.insert(sanitized_name.clone(), temp_dir); Box::new(client) } _ => unreachable!(), }; - let info = client.get_info(); - if let Some(instructions) = info.and_then(|info| info.instructions.as_ref()) { - self.instructions - .insert(sanitized_name.clone(), instructions.clone()); - } - - if let Some(_resources) = info.and_then(|info| info.capabilities.resources.as_ref()) { - self.resource_capable_extensions - .insert(sanitized_name.clone()); - } + let server_info = client.get_info().cloned(); + self.extensions.lock().await.insert( + sanitized_name, + Extension::new(config, Arc::new(Mutex::new(client)), server_info, temp_dir), + ); - self.add_client(sanitized_name.clone(), client); - self.extension_configs.insert(sanitized_name, config); Ok(()) } - pub fn add_client(&mut self, client_name: String, client: Box) { - let sanitized_name = normalize(client_name); - self.clients - .insert(sanitized_name, Arc::new(Mutex::new(client))); - } - /// Get extensions info pub async fn get_extensions_info(&self) -> Vec { - self.clients - .keys() - .map(|name| { - let instructions = self.instructions.get(name).cloned().unwrap_or_default(); - let has_resources = self.resource_capable_extensions.contains(name); - ExtensionInfo::new(name, &instructions, has_resources) + self.extensions + .lock() + .await + .iter() + .map(|(name, ext)| { + ExtensionInfo::new( + name, + ext.get_instructions().unwrap_or_default().as_str(), + ext.supports_resources(), + ) }) .collect() } /// Get aggregated usage statistics - pub async fn remove_extension(&mut self, name: &str) -> ExtensionResult<()> { + pub async fn remove_extension(&self, name: &str) -> ExtensionResult<()> { let sanitized_name = normalize(name.to_string()); - - self.clients.remove(&sanitized_name); - self.instructions.remove(&sanitized_name); - self.resource_capable_extensions.remove(&sanitized_name); - self.temp_dirs.remove(&sanitized_name); - self.extension_configs.remove(&sanitized_name); + self.extensions.lock().await.remove(&sanitized_name); Ok(()) } pub async fn suggest_disable_extensions_prompt(&self) -> Value { - let enabled_extensions_count = self.clients.len(); + let enabled_extensions_count = self.extensions.lock().await.len(); let total_tools = self .get_prefixed_tools(None) @@ -456,7 +479,7 @@ impl ExtensionManager { } pub async fn list_extensions(&self) -> ExtensionResult> { - Ok(self.clients.keys().cloned().collect()) + Ok(self.extensions.lock().await.keys().cloned().collect()) } /// Get all tools from all clients with proper prefixing @@ -465,32 +488,32 @@ impl ExtensionManager { extension_name: Option, ) -> ExtensionResult> { // Filter clients based on the provided extension_name or include all if None - let filtered_clients = self.clients.iter().filter(|(name, _)| { - if let Some(ref name_filter) = extension_name { - *name == name_filter - } else { - true - } - }); - - let client_futures = filtered_clients.map(|(name, client)| { - let name = name.clone(); - let client = client.clone(); - let extension_config = self.extension_configs.get(&name).cloned(); + let filtered_clients: Vec<_> = self + .extensions + .lock() + .await + .iter() + .filter(|(name, _ext)| { + if let Some(ref name_filter) = extension_name { + *name == name_filter + } else { + true + } + }) + .map(|(name, ext)| (name.clone(), ext.config.clone(), ext.get_client())) + .collect(); + let cancel_token = CancellationToken::default(); + let client_futures = filtered_clients.into_iter().map(|(name, config, client)| { + let cancel_token = cancel_token.clone(); task::spawn(async move { let mut tools = Vec::new(); let client_guard = client.lock().await; - let mut client_tools = client_guard - .list_tools(None, CancellationToken::default()) - .await?; + let mut client_tools = client_guard.list_tools(None, cancel_token).await?; loop { for tool in client_tools.tools { - let is_available = extension_config - .as_ref() - .map(|config| config.is_tool_available(&tool.name)) - .unwrap_or(true); + let is_available = config.is_tool_available(&tool.name); if is_available { tools.push(Tool { @@ -542,11 +565,13 @@ impl ExtensionManager { } /// Find and return a reference to the appropriate client for a tool call - fn get_client_for_tool(&self, prefixed_name: &str) -> Option<(&str, McpClientBox)> { - self.clients + async fn get_client_for_tool(&self, prefixed_name: &str) -> Option<(String, McpClientBox)> { + self.extensions + .lock() + .await .iter() .find(|(key, _)| prefixed_name.starts_with(*key)) - .map(|(name, client)| (name.as_str(), Arc::clone(client))) + .map(|(name, extension)| (name.clone(), extension.get_client())) } // Function that gets executed for read_resource tool @@ -574,7 +599,7 @@ impl ExtensionManager { // Loop through each extension and try to read the resource, don't raise an error if the resource is not found // TODO: do we want to find if a provided uri is in multiple extensions? // currently it will return the first match and skip any others - for extension_name in self.resource_capable_extensions.iter() { + for extension_name in self.extensions.lock().await.keys() { let result = self .read_resource_from_extension(uri, extension_name, cancellation_token.clone()) .await; @@ -586,7 +611,9 @@ impl ExtensionManager { // None of the extensions had the resource so we raise an error let available_extensions = self - .clients + .extensions + .lock() + .await .keys() .map(|s| s.as_str()) .collect::>() @@ -610,7 +637,9 @@ impl ExtensionManager { cancellation_token: CancellationToken, ) -> Result, ErrorData> { let available_extensions = self - .clients + .extensions + .lock() + .await .keys() .map(|s| s.as_str()) .collect::>() @@ -620,11 +649,10 @@ impl ExtensionManager { extension_name, available_extensions ); - let client = self.clients.get(extension_name).ok_or(ErrorData::new( - ErrorCode::INVALID_PARAMS, - error_msg, - None, - ))?; + let client = self + .get_server_client(extension_name) + .await + .ok_or(ErrorData::new(ErrorCode::INVALID_PARAMS, error_msg, None))?; let client_guard = client.lock().await; let read_result = client_guard @@ -655,13 +683,16 @@ impl ExtensionManager { extension_name: &str, cancellation_token: CancellationToken, ) -> Result, ErrorData> { - let client = self.clients.get(extension_name).ok_or_else(|| { - ErrorData::new( - ErrorCode::INVALID_PARAMS, - format!("Extension {} is not valid", extension_name), - None, - ) - })?; + let client = self + .get_server_client(extension_name) + .await + .ok_or_else(|| { + ErrorData::new( + ErrorCode::INVALID_PARAMS, + format!("Extension {} is not valid", extension_name), + None, + ) + })?; let client_guard = client.lock().await; client_guard @@ -704,13 +735,19 @@ impl ExtensionManager { let mut futures = FuturesUnordered::new(); // Create futures for each resource_capable_extension - for extension_name in &self.resource_capable_extensions { - let token = cancellation_token.clone(); - futures.push(async move { - self.list_resources_from_extension(extension_name, token) - .await + self.extensions + .lock() + .await + .iter() + .filter(|(_name, ext)| ext.supports_resources()) + .map(|(name, _ext)| name.clone()) + .for_each(|name| { + let token = cancellation_token.clone(); + futures.push(async move { + self.list_resources_from_extension(&name.clone(), token) + .await + }); }); - } let mut all_resources = Vec::new(); let mut errors = Vec::new(); @@ -749,22 +786,25 @@ impl ExtensionManager { cancellation_token: CancellationToken, ) -> Result { // Dispatch tool call based on the prefix naming convention - let (client_name, client) = self.get_client_for_tool(&tool_call.name).ok_or_else(|| { - ErrorData::new(ErrorCode::RESOURCE_NOT_FOUND, tool_call.name.clone(), None) - })?; + let (client_name, client) = + self.get_client_for_tool(&tool_call.name) + .await + .ok_or_else(|| { + ErrorData::new(ErrorCode::RESOURCE_NOT_FOUND, tool_call.name.clone(), None) + })?; // rsplit returns the iterator in reverse, tool_name is then at 0 let tool_name = tool_call .name - .strip_prefix(client_name) + .strip_prefix(client_name.as_str()) .and_then(|s| s.strip_prefix("__")) .ok_or_else(|| { ErrorData::new(ErrorCode::RESOURCE_NOT_FOUND, tool_call.name.clone(), None) })? .to_string(); - if let Some(extension_config) = self.extension_configs.get(client_name) { - if !extension_config.is_tool_available(&tool_name) { + if let Some(extension) = self.extensions.lock().await.get(&client_name) { + if !extension.config.is_tool_available(&tool_name) { return Err(ErrorData::new( ErrorCode::RESOURCE_NOT_FOUND, format!( @@ -801,13 +841,16 @@ impl ExtensionManager { extension_name: &str, cancellation_token: CancellationToken, ) -> Result, ErrorData> { - let client = self.clients.get(extension_name).ok_or_else(|| { - ErrorData::new( - ErrorCode::INVALID_PARAMS, - format!("Extension {} is not valid", extension_name), - None, - ) - })?; + let client = self + .get_server_client(extension_name) + .await + .ok_or_else(|| { + ErrorData::new( + ErrorCode::INVALID_PARAMS, + format!("Extension {} is not valid", extension_name), + None, + ) + })?; let client_guard = client.lock().await; client_guard @@ -829,12 +872,12 @@ impl ExtensionManager { ) -> Result>, ErrorData> { let mut futures = FuturesUnordered::new(); - for extension_name in self.clients.keys() { + for extension_name in self.extensions.lock().await.keys().cloned() { let token = cancellation_token.clone(); futures.push(async move { ( - extension_name, - self.list_prompts_from_extension(extension_name, token) + extension_name.clone(), + self.list_prompts_from_extension(&extension_name.as_str(), token) .await, ) }); @@ -878,8 +921,8 @@ impl ExtensionManager { cancellation_token: CancellationToken, ) -> Result { let client = self - .clients - .get(extension_name) + .get_server_client(extension_name) + .await .ok_or_else(|| anyhow::anyhow!("Extension {} not found", extension_name))?; let client_guard = client.lock().await; @@ -934,7 +977,8 @@ impl ExtensionManager { } // Get currently enabled extensions that can be disabled - let enabled_extensions: Vec = self.clients.keys().cloned().collect(); + let enabled_extensions: Vec = + self.extensions.lock().await.keys().cloned().collect(); // Build output string if !disabled_extensions.is_empty() { @@ -961,6 +1005,14 @@ impl ExtensionManager { Ok(vec![Content::text(output_parts.join("\n"))]) } + + async fn get_server_client(&self, name: impl Into) -> Option { + self.extensions + .lock() + .await + .get(&name.into()) + .map(|ext| ext.get_client()) + } } #[cfg(test)] diff --git a/crates/goose/src/agents/reply_parts.rs b/crates/goose/src/agents/reply_parts.rs index 46178c8933de..14cbc7888e50 100644 --- a/crates/goose/src/agents/reply_parts.rs +++ b/crates/goose/src/agents/reply_parts.rs @@ -52,8 +52,7 @@ impl Agent { } // Prepare system prompt - let extension_manager = self.extension_manager.read().await; - let extensions_info = extension_manager.get_extensions_info().await; + let extensions_info = self.extension_manager.get_extensions_info().await; // Get model name from provider let provider = self.provider().await?; @@ -64,7 +63,9 @@ impl Agent { let mut system_prompt = prompt_manager.build_system_prompt( extensions_info, self.frontend_instructions.lock().await.clone(), - extension_manager.suggest_disable_extensions_prompt().await, + self.extension_manager + .suggest_disable_extensions_prompt() + .await, Some(model_name), router_enabled, ); diff --git a/crates/goose/src/agents/subagent.rs b/crates/goose/src/agents/subagent.rs index 5e4334338d0c..50eded2b30fd 100644 --- a/crates/goose/src/agents/subagent.rs +++ b/crates/goose/src/agents/subagent.rs @@ -57,7 +57,7 @@ impl SubAgent { debug!("Creating new subagent with id: {}", task_config.id); // Create a new extension manager for this subagent - let mut extension_manager = ExtensionManager::new(); + let extension_manager = ExtensionManager::new(); // Add extensions based on task_type: // 1. If executing dynamic task (task_type = 'text_instruction'), default to using all enabled extensions diff --git a/crates/goose/src/agents/tool_route_manager.rs b/crates/goose/src/agents/tool_route_manager.rs index 11c44f4b326e..697a4731b46c 100644 --- a/crates/goose/src/agents/tool_route_manager.rs +++ b/crates/goose/src/agents/tool_route_manager.rs @@ -11,7 +11,6 @@ use rmcp::model::{ErrorCode, ErrorData, Tool}; use serde_json::Value; use std::sync::Arc; use tokio::sync::Mutex; -use tokio::sync::RwLock; use tracing::error; pub struct ToolRouteManager { @@ -85,7 +84,7 @@ impl ToolRouteManager { &self, provider: Arc, reindex_all: Option, - extension_manager: &Arc>, + extension_manager: &ExtensionManager, ) -> Result<()> { let enabled = self.is_router_enabled().await; if !enabled { @@ -100,7 +99,6 @@ impl ToolRouteManager { let selector_arc = Arc::new(selector); // First index platform tools - let extension_manager = extension_manager.read().await; ToolRouterIndexManager::index_platform_tools(&selector_arc, &extension_manager).await?; if reindex_all.unwrap_or(false) { @@ -142,10 +140,7 @@ impl ToolRouteManager { self.router_tool_selector.lock().await.is_some() } - pub async fn list_tools_for_router( - &self, - extension_manager: &Arc>, - ) -> Vec { + pub async fn list_tools_for_router(&self, extension_manager: &ExtensionManager) -> Vec { // If router is disabled or overridden, return empty if *self.router_disabled_override.lock().await { return vec![]; @@ -163,7 +158,6 @@ impl ToolRouteManager { let selector = self.router_tool_selector.lock().await.clone(); if let Some(selector) = selector { if let Ok(recent_calls) = selector.get_recent_tool_calls(20).await { - let extension_manager = extension_manager.read().await; // Add recent tool calls to the list, avoiding duplicates for tool_name in recent_calls { // Find the tool in the extension manager's tools diff --git a/crates/goose/src/agents/tool_router_index_manager.rs b/crates/goose/src/agents/tool_router_index_manager.rs index c419de174870..320b10290338 100644 --- a/crates/goose/src/agents/tool_router_index_manager.rs +++ b/crates/goose/src/agents/tool_router_index_manager.rs @@ -87,7 +87,7 @@ impl ToolRouterIndexManager { tools.push(platform_tools::manage_extensions_tool()); // Add resource tools if supported - if extension_manager.supports_resources() { + if extension_manager.supports_resources().await { tools.push(platform_tools::read_resource_tool()); tools.push(platform_tools::list_resources_tool()); } From 1535c5d29d3e80ef92942a3117b80b2a1cdce388 Mon Sep 17 00:00:00 2001 From: Jack Amadeo Date: Wed, 20 Aug 2025 17:27:25 -0400 Subject: [PATCH 2/4] output --- crates/goose-cli/src/session/builder.rs | 78 +++++++++++++++---------- 1 file changed, 48 insertions(+), 30 deletions(-) diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index ea619dfa69a3..92c316abd1d7 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -7,8 +7,10 @@ use goose::recipe::{Response, SubRecipe}; use goose::session; use goose::session::Identifier; use rustyline::EditMode; +use std::collections::HashSet; use std::process; use std::sync::Arc; +use tokio::task::JoinSet; use super::output; use super::Session; @@ -355,38 +357,54 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session { .collect() }; - for extension in extensions_to_run { - if let Err(e) = agent.add_extension(extension.clone()).await { - let err = e.to_string(); - eprintln!( - "{}", - style(format!( - "Warning: Failed to start extension '{}': {}", - extension.name(), - err - )) - .yellow() - ); - eprintln!( - "{}", - style(format!( - "Continuing without extension '{}'", - extension.name() - )) - .yellow() - ); + let mut set = JoinSet::new(); + let agent_ptr = Arc::new(agent); - // Offer debugging help - if let Err(debug_err) = offer_extension_debugging_help( - &extension.name(), - &err, - Arc::clone(&provider_for_display), - session_config.interactive, + let mut waiting_on = HashSet::new(); + for extension in extensions_to_run { + waiting_on.insert(extension.name()); + let agent_ptr = agent_ptr.clone(); + set.spawn(async move { + ( + extension.name(), + agent_ptr.add_extension(extension.clone()).await, ) - .await - { - eprintln!("Note: Could not start debugging session: {}", debug_err); + }); + } + + let get_message = |waiting_on: &HashSet| { + let mut names: Vec<_> = waiting_on.iter().cloned().collect(); + names.sort(); + format!("starting {} extensions: {}", names.len(), names.join(", ")) + }; + + let spinner = cliclack::spinner(); + spinner.start(get_message(&waiting_on)); + + let mut offer_debug = Vec::new(); + while let Some(result) = set.join_next().await { + match result { + Ok((name, Ok(_))) => { + waiting_on.remove(&name); + spinner.set_message(get_message(&waiting_on)); } + Ok((name, Err(e))) => offer_debug.push((name, e)), + Err(e) => tracing::error!("failed to add extension: {}", e), + } + } + + spinner.clear(); + + for (name, err) in offer_debug { + if let Err(debug_err) = offer_extension_debugging_help( + &name, + &err.to_string(), + Arc::clone(&provider_for_display), + session_config.interactive, + ) + .await + { + eprintln!("Note: Could not start debugging session: {}", debug_err); } } @@ -405,7 +423,7 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session { // Create new session let mut session = Session::new( - agent, + Arc::try_unwrap(agent_ptr).unwrap_or_else(|_| panic!("There should be no more references")), session_file.clone(), session_config.debug, session_config.scheduled_job_id.clone(), From 1610597ceb3d97e5796c52fa8ffefdde11eac7ab Mon Sep 17 00:00:00 2001 From: Jack Amadeo Date: Wed, 20 Aug 2025 20:28:33 -0400 Subject: [PATCH 3/4] Expose add_client again --- .../src/scenario_tests/scenario_runner.rs | 24 +- crates/goose/src/agents/extension_manager.rs | 216 ++++++++++-------- crates/goose/tests/mcp_integration_test.rs | 2 +- 3 files changed, 145 insertions(+), 97 deletions(-) diff --git a/crates/goose-cli/src/scenario_tests/scenario_runner.rs b/crates/goose-cli/src/scenario_tests/scenario_runner.rs index b3aa3b67f9b9..130077c9fdc7 100644 --- a/crates/goose-cli/src/scenario_tests/scenario_runner.rs +++ b/crates/goose-cli/src/scenario_tests/scenario_runner.rs @@ -136,6 +136,9 @@ async fn run_provider_scenario_with_validation( where F: Fn(&ScenarioResult) -> Result<()>, { + use goose::config::ExtensionConfig; + use tokio::sync::Mutex; + if let Ok(path) = dotenv() { println!("Loaded environment from {:?}", path); } @@ -193,10 +196,23 @@ where let mock_client = weather_client(); let agent = Agent::new(); - { - let mut extension_manager = agent.extension_manager.write().await; - extension_manager.add_client("weather_extension".to_string(), Box::new(mock_client)); - } + agent + .extension_manager + .add_client( + "weather_extension".to_string(), + ExtensionConfig::Builtin { + name: "".to_string(), + display_name: None, + description: None, + timeout: None, + bundled: None, + available_tools: vec![], + }, + Arc::new(Mutex::new(Box::new(mock_client))), + None, + None, + ) + .await; agent .update_provider(provider_arc as Arc) diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs index 0c3368400904..e80868fc91a5 100644 --- a/crates/goose/src/agents/extension_manager.rs +++ b/crates/goose/src/agents/extension_manager.rs @@ -14,7 +14,7 @@ use std::collections::HashMap; use std::process::Stdio; use std::sync::Arc; use std::time::Duration; -use tempfile::tempdir; +use tempfile::{tempdir, TempDir}; use tokio::io::AsyncReadExt; use tokio::process::Command; use tokio::sync::Mutex; @@ -415,14 +415,32 @@ impl ExtensionManager { }; let server_info = client.get_info().cloned(); - self.extensions.lock().await.insert( + self.add_client( sanitized_name, - Extension::new(config, Arc::new(Mutex::new(client)), server_info, temp_dir), - ); + config, + Arc::new(Mutex::new(client)), + server_info, + temp_dir, + ) + .await; Ok(()) } + pub async fn add_client( + &self, + name: String, + config: ExtensionConfig, + client: McpClientBox, + info: Option, + temp_dir: Option, + ) { + self.extensions + .lock() + .await + .insert(name, Extension::new(config, client, info, temp_dir)); + } + /// Get extensions info pub async fn get_extensions_info(&self) -> Vec { self.extensions @@ -1031,6 +1049,35 @@ mod tests { use serde_json::json; use tokio::sync::mpsc; + impl ExtensionManager { + async fn add_mock_extension(&self, name: String, client: McpClientBox) { + self.add_mock_extension_with_tools(name, client, vec![]) + .await; + } + + async fn add_mock_extension_with_tools( + &self, + name: String, + client: McpClientBox, + available_tools: Vec, + ) { + let sanitized_name = normalize(name.clone()); + let config = ExtensionConfig::Builtin { + name: name.clone(), + display_name: Some(name.clone()), + description: None, + timeout: None, + bundled: None, + available_tools, + }; + let extension = Extension::new(config, client, None, None); + self.extensions + .lock() + .await + .insert(sanitized_name, extension); + } + } + struct MockClient {} #[async_trait::async_trait] @@ -1128,49 +1175,61 @@ mod tests { } } - #[test] - fn test_get_client_for_tool() { - let mut extension_manager = ExtensionManager::new(); + #[tokio::test] + async fn test_get_client_for_tool() { + let extension_manager = ExtensionManager::new(); - // Add some mock clients - extension_manager.clients.insert( - normalize("test_client".to_string()), - Arc::new(Mutex::new(Box::new(MockClient {}))), - ); + // Add some mock clients using the helper method + extension_manager + .add_mock_extension( + "test_client".to_string(), + Arc::new(Mutex::new(Box::new(MockClient {}))), + ) + .await; - extension_manager.clients.insert( - normalize("__client".to_string()), - Arc::new(Mutex::new(Box::new(MockClient {}))), - ); + extension_manager + .add_mock_extension( + "__client".to_string(), + Arc::new(Mutex::new(Box::new(MockClient {}))), + ) + .await; - extension_manager.clients.insert( - normalize("__cli__ent__".to_string()), - Arc::new(Mutex::new(Box::new(MockClient {}))), - ); + extension_manager + .add_mock_extension( + "__cli__ent__".to_string(), + Arc::new(Mutex::new(Box::new(MockClient {}))), + ) + .await; - extension_manager.clients.insert( - normalize("client 🚀".to_string()), - Arc::new(Mutex::new(Box::new(MockClient {}))), - ); + extension_manager + .add_mock_extension( + "client 🚀".to_string(), + Arc::new(Mutex::new(Box::new(MockClient {}))), + ) + .await; // Test basic case assert!(extension_manager .get_client_for_tool("test_client__tool") + .await .is_some()); // Test leading underscores assert!(extension_manager .get_client_for_tool("__client__tool") + .await .is_some()); // Test multiple underscores in client name, and ending with __ assert!(extension_manager .get_client_for_tool("__cli__ent____tool") + .await .is_some()); // Test unicode in tool name, "client 🚀" should become "client_" assert!(extension_manager .get_client_for_tool("client___tool") + .await .is_some()); } @@ -1178,23 +1237,29 @@ mod tests { async fn test_dispatch_tool_call() { // test that dispatch_tool_call parses out the sanitized name correctly, and extracts // tool_names - let mut extension_manager = ExtensionManager::new(); + let extension_manager = ExtensionManager::new(); - // Add some mock clients - extension_manager.clients.insert( - normalize("test_client".to_string()), - Arc::new(Mutex::new(Box::new(MockClient {}))), - ); + // Add some mock clients using the helper method + extension_manager + .add_mock_extension( + "test_client".to_string(), + Arc::new(Mutex::new(Box::new(MockClient {}))), + ) + .await; - extension_manager.clients.insert( - normalize("__cli__ent__".to_string()), - Arc::new(Mutex::new(Box::new(MockClient {}))), - ); + extension_manager + .add_mock_extension( + "__cli__ent__".to_string(), + Arc::new(Mutex::new(Box::new(MockClient {}))), + ) + .await; - extension_manager.clients.insert( - normalize("client 🚀".to_string()), - Arc::new(Mutex::new(Box::new(MockClient {}))), - ); + extension_manager + .add_mock_extension( + "client 🚀".to_string(), + Arc::new(Mutex::new(Box::new(MockClient {}))), + ) + .await; // verify a normal tool call let tool_call = ToolCall { @@ -1289,29 +1354,18 @@ mod tests { #[tokio::test] async fn test_tool_availability_filtering() { - let mut extension_manager = ExtensionManager::new(); + let extension_manager = ExtensionManager::new(); // Only "available_tool" should be available to the LLM let available_tools = vec!["available_tool".to_string()]; - let config = ExtensionConfig::Builtin { - name: "test_extension".to_string(), - display_name: Some("Test Extension".to_string()), - description: Some("Test extension for available tools".to_string()), - timeout: Some(300), - bundled: Some(true), - available_tools, - }; - - let sanitized_name = normalize("test_extension".to_string()); extension_manager - .extension_configs - .insert(sanitized_name.clone(), config); - - extension_manager.clients.insert( - sanitized_name, - Arc::new(Mutex::new(Box::new(MockClient {}))), - ); + .add_mock_extension_with_tools( + "test_extension".to_string(), + Arc::new(Mutex::new(Box::new(MockClient {}))), + available_tools, + ) + .await; let tools = extension_manager.get_prefixed_tools(None).await.unwrap(); @@ -1328,26 +1382,15 @@ mod tests { #[tokio::test] async fn test_tool_availability_defaults_to_available() { - let mut extension_manager = ExtensionManager::new(); - - let config = ExtensionConfig::Builtin { - name: "test_extension".to_string(), - display_name: Some("Test Extension".to_string()), - description: Some("Test extension for available tools".to_string()), - timeout: Some(300), - bundled: Some(true), - available_tools: vec![], - }; + let extension_manager = ExtensionManager::new(); - let sanitized_name = normalize("test_extension".to_string()); extension_manager - .extension_configs - .insert(sanitized_name.clone(), config); - - extension_manager.clients.insert( - sanitized_name, - Arc::new(Mutex::new(Box::new(MockClient {}))), - ); + .add_mock_extension_with_tools( + "test_extension".to_string(), + Arc::new(Mutex::new(Box::new(MockClient {}))), + vec![], // Empty available_tools means all tools are available by default + ) + .await; let tools = extension_manager.get_prefixed_tools(None).await.unwrap(); @@ -1364,28 +1407,17 @@ mod tests { #[tokio::test] async fn test_dispatch_unavailable_tool_returns_error() { - let mut extension_manager = ExtensionManager::new(); + let extension_manager = ExtensionManager::new(); let available_tools = vec!["available_tool".to_string()]; - let config = ExtensionConfig::Builtin { - name: "test_extension".to_string(), - display_name: Some("Test Extension".to_string()), - description: Some("Test extension for tool dispatch".to_string()), - timeout: Some(300), - bundled: Some(true), - available_tools, - }; - - let sanitized_name = normalize("test_extension".to_string()); extension_manager - .extension_configs - .insert(sanitized_name.clone(), config); - - extension_manager.clients.insert( - sanitized_name, - Arc::new(Mutex::new(Box::new(MockClient {}))), - ); + .add_mock_extension_with_tools( + "test_extension".to_string(), + Arc::new(Mutex::new(Box::new(MockClient {}))), + available_tools, + ) + .await; // Try to call an unavailable tool let unavailable_tool_call = ToolCall { diff --git a/crates/goose/tests/mcp_integration_test.rs b/crates/goose/tests/mcp_integration_test.rs index bba17fc6844b..f5a69c8d01ac 100644 --- a/crates/goose/tests/mcp_integration_test.rs +++ b/crates/goose/tests/mcp_integration_test.rs @@ -126,7 +126,7 @@ async fn test_replayed_session( available_tools: vec![], }; - let mut extension_manager = ExtensionManager::new(); + let extension_manager = ExtensionManager::new(); let result = extension_manager.add_extension(extension_config).await; assert!(result.is_ok(), "Failed to add extension: {:?}", result); From 56c5eaf3b7cf6f1dee81058c52c869e60e6fc101 Mon Sep 17 00:00:00 2001 From: Jack Amadeo Date: Wed, 20 Aug 2025 20:38:40 -0400 Subject: [PATCH 4/4] Silly way to satisfy clippy --- crates/goose/src/agents/extension_manager.rs | 5 +++-- crates/goose/src/agents/tool_route_manager.rs | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs index e80868fc91a5..9b4ffd92d188 100644 --- a/crates/goose/src/agents/extension_manager.rs +++ b/crates/goose/src/agents/extension_manager.rs @@ -890,12 +890,13 @@ impl ExtensionManager { ) -> Result>, ErrorData> { let mut futures = FuturesUnordered::new(); - for extension_name in self.extensions.lock().await.keys().cloned() { + let names: Vec<_> = self.extensions.lock().await.keys().cloned().collect(); + for extension_name in names { let token = cancellation_token.clone(); futures.push(async move { ( extension_name.clone(), - self.list_prompts_from_extension(&extension_name.as_str(), token) + self.list_prompts_from_extension(extension_name.as_str(), token) .await, ) }); diff --git a/crates/goose/src/agents/tool_route_manager.rs b/crates/goose/src/agents/tool_route_manager.rs index 697a4731b46c..7e277d96f764 100644 --- a/crates/goose/src/agents/tool_route_manager.rs +++ b/crates/goose/src/agents/tool_route_manager.rs @@ -99,14 +99,14 @@ impl ToolRouteManager { let selector_arc = Arc::new(selector); // First index platform tools - ToolRouterIndexManager::index_platform_tools(&selector_arc, &extension_manager).await?; + ToolRouterIndexManager::index_platform_tools(&selector_arc, extension_manager).await?; if reindex_all.unwrap_or(false) { let enabled_extensions = extension_manager.list_extensions().await?; for extension_name in enabled_extensions { if let Err(e) = ToolRouterIndexManager::update_extension_tools( &selector_arc, - &extension_manager, + extension_manager, &extension_name, "add", )