diff --git a/crates/goose-cli/src/session/builder.rs b/crates/goose-cli/src/session/builder.rs index ae2adc9bf6b8..f18b886c183c 100644 --- a/crates/goose-cli/src/session/builder.rs +++ b/crates/goose-cli/src/session/builder.rs @@ -346,6 +346,7 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session { // Extensions need to be added after the session is created because we change directory when resuming a session // If we get extensions_override, only run those extensions and none other let extensions_to_run: Vec<_> = if let Some(extensions) = session_config.extensions_override { + agent.disable_router_for_recipe().await; extensions.into_iter().collect() } else { ExtensionConfigManager::get_all() diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index d412517ecfdd..5a9e7e4a7d39 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -21,17 +21,15 @@ use crate::agents::recipe_tools::dynamic_task_tools::{ create_dynamic_task, create_dynamic_task_tool, DYNAMIC_TASK_TOOL_NAME_PREFIX, }; use crate::agents::retry::{RetryManager, RetryResult}; -use crate::agents::router_tool_selector::{ - create_tool_selector, RouterToolSelectionStrategy, RouterToolSelector, -}; +use crate::agents::router_tool_selector::RouterToolSelectionStrategy; use crate::agents::router_tools::{ROUTER_LLM_SEARCH_TOOL_NAME, ROUTER_VECTOR_SEARCH_TOOL_NAME}; use crate::agents::sub_recipe_manager::SubRecipeManager; use crate::agents::subagent_execution_tool::subagent_execute_task_tool::{ self, SUBAGENT_EXECUTE_TASK_TOOL_NAME, }; use crate::agents::subagent_execution_tool::tasks_manager::TasksManager; +use crate::agents::tool_route_manager::ToolRouteManager; use crate::agents::tool_router_index_manager::ToolRouterIndexManager; -use crate::agents::tool_vectordb::generate_table_id; use crate::agents::types::SessionConfig; use crate::agents::types::{FrontendTool, ToolResultReceiver}; use crate::config::{Config, ExtensionConfigManager, PermissionManager}; @@ -54,7 +52,6 @@ use tracing::{debug, error, info, instrument}; use super::final_output_tool::FinalOutputTool; use super::platform_tools; -use super::router_tools; use super::tool_execution::{ToolCallResult, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE}; use crate::agents::subagent_task_config::TaskConfig; use crate::conversation_fixer::{debug_conversation_fix, ConversationFixer}; @@ -72,8 +69,7 @@ pub struct ReplyContext { pub config: &'static Config, } -/// Result of processing tool requests -pub struct ToolProcessingResult { +pub struct ToolCategorizeResult { pub frontend_requests: Vec, pub remaining_requests: Vec, pub filtered_response: Message, @@ -96,7 +92,7 @@ pub struct Agent { pub(super) tool_result_tx: mpsc::Sender<(String, ToolResult>)>, pub(super) tool_result_rx: ToolResultReceiver, pub(super) tool_monitor: Arc>>, - pub(super) router_tool_selector: Mutex>>>, + pub(super) tool_route_manager: ToolRouteManager, pub(super) scheduler_service: Mutex>>, pub(super) retry_manager: RetryManager, } @@ -171,7 +167,7 @@ impl Agent { tool_result_tx: tool_tx, tool_result_rx: Arc::new(Mutex::new(tool_rx)), tool_monitor, - router_tool_selector: Mutex::new(None), + tool_route_manager: ToolRouteManager::new(), scheduler_service: Mutex::new(None), retry_manager, } @@ -246,38 +242,18 @@ impl Agent { }) } - /// Process tool requests by categorizing them and recording them in the router selector - async fn process_tool_requests( + async fn categorize_tools( &self, response: &Message, tools: &[rmcp::model::Tool], - ) -> ToolProcessingResult { + ) -> ToolCategorizeResult { let (readonly_tools, regular_tools) = Self::categorize_tools_by_annotation(tools); // Categorize tool requests let (frontend_requests, remaining_requests, filtered_response) = self.categorize_tool_requests(response).await; - // Record tool calls in the router selector - let selector = self.router_tool_selector.lock().await.clone(); - if let Some(selector) = selector { - for request in &frontend_requests { - if let Ok(tool_call) = &request.tool_call { - if let Err(e) = selector.record_tool_call(&tool_call.name).await { - error!("Failed to record frontend tool call: {}", e); - } - } - } - for request in &remaining_requests { - if let Ok(tool_call) = &request.tool_call { - if let Err(e) = selector.record_tool_call(&tool_call.name).await { - error!("Failed to record tool call: {}", e); - } - } - } - } - - ToolProcessingResult { + ToolCategorizeResult { frontend_requests, remaining_requests, filtered_response, @@ -336,6 +312,10 @@ impl Agent { *scheduler_service = Some(scheduler); } + pub async fn disable_router_for_recipe(&self) { + self.tool_route_manager.disable_router_for_recipe().await; + } + /// Get a reference count clone to the provider pub async fn provider(&self) -> Result, anyhow::Error> { match &*self.provider.lock().await { @@ -476,43 +456,14 @@ impl Agent { } else if tool_call.name == ROUTER_VECTOR_SEARCH_TOOL_NAME || tool_call.name == ROUTER_LLM_SEARCH_TOOL_NAME { - let selector = self.router_tool_selector.lock().await.clone(); - let mut selected_tools = match selector.as_ref() { - Some(selector) => match selector.select_tools(tool_call.arguments.clone()).await { - Ok(tools) => tools, - Err(e) => { - return ( - request_id, - Err(ToolError::ExecutionError(format!( - "Failed to select tools: {}", - e - ))), - ) - } - }, - None => { - return ( - request_id, - Err(ToolError::ExecutionError( - "No tool selector available".to_string(), - )), - ) - } - }; - - // Append final_output tool if present (for structured output recipes, [Issue #3700](https://github.com/block/goose/issues/3700) - if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() { - let tool = final_output_tool.tool(); - let tool_content = Content::text(format!( - "Tool: {}\nDescription: {}\nSchema: {}", - tool.name, - tool.description.unwrap_or_default(), - serde_json::to_string_pretty(&tool.input_schema).unwrap_or_default() - )); - selected_tools.push(tool_content); + match self + .tool_route_manager + .dispatch_route_search_tool(tool_call.arguments) + .await + { + Ok(tool_result) => tool_result, + Err(e) => return (request_id, Err(e)), } - - ToolCallResult::from(Ok(selected_tools)) } else { // Clone the result to ensure no references to extension_manager are returned let result = extension_manager @@ -542,9 +493,7 @@ impl Agent { extension_name: String, request_id: String, ) -> (String, Result, ToolError>) { - let mut extension_manager = self.extension_manager.write().await; - - let selector = self.router_tool_selector.lock().await.clone(); + let selector = self.tool_route_manager.get_router_tool_selector().await; if ToolRouterIndexManager::is_tool_router_enabled(&selector) { if let Some(selector) = selector { let selector_action = if action == "disable" { "remove" } else { "add" }; @@ -568,7 +517,7 @@ impl Agent { } } } - + let mut extension_manager = self.extension_manager.write().await; if action == "disable" { let result = extension_manager .remove_extension(&extension_name) @@ -604,7 +553,6 @@ impl Agent { ) } }; - let result = extension_manager .add_extension(config) .await @@ -616,9 +564,10 @@ impl Agent { }) .map_err(|e| ToolError::ExecutionError(e.to_string())); + drop(extension_manager); // Update vector index if operation was successful and vector routing is enabled if result.is_ok() { - let selector = self.router_tool_selector.lock().await.clone(); + let selector = self.tool_route_manager.get_router_tool_selector().await; if ToolRouterIndexManager::is_tool_router_enabled(&selector) { if let Some(selector) = selector { let vector_action = if action == "disable" { "remove" } else { "add" }; @@ -681,7 +630,7 @@ impl Agent { } // If vector tool selection is enabled, index the tools - let selector = self.router_tool_selector.lock().await.clone(); + let selector = self.tool_route_manager.get_router_tool_selector().await; if ToolRouterIndexManager::is_tool_router_enabled(&selector) { if let Some(selector) = selector { let extension_manager = self.extension_manager.read().await; @@ -750,46 +699,18 @@ impl Agent { &self, strategy: Option, ) -> Vec { - let mut prefixed_tools = vec![]; - match strategy { - Some(RouterToolSelectionStrategy::Vector) => { - prefixed_tools.push(router_tools::vector_search_tool()); - } - Some(RouterToolSelectionStrategy::Llm) => { - prefixed_tools.push(router_tools::llm_search_tool()); - } - None => {} - } - - // Get recent tool calls from router tool selector if available - let selector = self.router_tool_selector.lock().await.clone(); - if let Some(selector) = selector { - if let Ok(recent_calls) = selector.get_recent_tool_calls(20).await { - let extension_manager = self.extension_manager.read().await; - // Add recent tool calls to the list, avoiding duplicates - for tool_name in recent_calls { - // Find the tool in the extension manager's tools - if let Ok(extension_tools) = extension_manager.get_prefixed_tools(None).await { - if let Some(tool) = extension_tools.iter().find(|t| t.name == tool_name) { - // Only add if not already in prefixed_tools - if !prefixed_tools.iter().any(|t| t.name == tool.name) { - prefixed_tools.push(tool.clone()); - } - } - } - } - } - } - - prefixed_tools + self.tool_route_manager + .list_tools_for_router(strategy, &self.extension_manager) + .await } pub async fn remove_extension(&self, name: &str) -> Result<()> { let mut extension_manager = self.extension_manager.write().await; extension_manager.remove_extension(name).await?; + drop(extension_manager); // If vector tool selection is enabled, remove tools from the index - let selector = self.router_tool_selector.lock().await.clone(); + let selector = self.tool_route_manager.get_router_tool_selector().await; if ToolRouterIndexManager::is_tool_router_enabled(&selector) { if let Some(selector) = selector { let extension_manager = self.extension_manager.read().await; @@ -938,14 +859,17 @@ impl Agent { } if let Some(response) = response { - let tool_result = self.process_tool_requests(&response, &tools).await; - let ToolProcessingResult { + let ToolCategorizeResult { frontend_requests, remaining_requests, filtered_response, readonly_tools, regular_tools, - } = tool_result; + } = self.categorize_tools(&response, &tools).await; + let requests_to_record: Vec = frontend_requests.iter().chain(remaining_requests.iter()).cloned().collect(); + self.tool_route_manager + .record_tool_requests(&requests_to_record) + .await; yield AgentEvent::Message(filtered_response.clone()); tokio::task::yield_now().await; @@ -1151,67 +1075,15 @@ impl Agent { provider: Option>, reindex_all: Option, ) -> Result<()> { - let config = Config::global(); - let _extension_manager = self.extension_manager.read().await; let provider = match provider { Some(p) => p, None => self.provider().await?, }; - let router_tool_selection_strategy = config - .get_param("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY") - .unwrap_or_else(|_| "default".to_string()); - - let strategy = match router_tool_selection_strategy.to_lowercase().as_str() { - "vector" => Some(RouterToolSelectionStrategy::Vector), - "llm" => Some(RouterToolSelectionStrategy::Llm), - _ => None, - }; - - let selector = match strategy { - Some(RouterToolSelectionStrategy::Vector) => { - let table_name = generate_table_id(); - let selector = create_tool_selector(strategy, provider.clone(), Some(table_name)) - .await - .map_err(|e| anyhow!("Failed to create tool selector: {}", e))?; - Arc::new(selector) - } - Some(RouterToolSelectionStrategy::Llm) => { - let selector = create_tool_selector(strategy, provider.clone(), None) - .await - .map_err(|e| anyhow!("Failed to create tool selector: {}", e))?; - Arc::new(selector) - } - None => return Ok(()), - }; - - // First index platform tools - let extension_manager = self.extension_manager.read().await; - ToolRouterIndexManager::index_platform_tools(&selector, &extension_manager).await?; - - if reindex_all.unwrap_or(false) { - let enabled_extensions = extension_manager.list_extensions().await?; - for extension_name in enabled_extensions { - if let Err(e) = ToolRouterIndexManager::update_extension_tools( - &selector, - &extension_manager, - &extension_name, - "add", - ) - .await - { - error!( - "Failed to index tools for extension {}: {}", - extension_name, e - ); - } - } - } - - // Update the selector - *self.router_tool_selector.lock().await = Some(selector.clone()); - - Ok(()) + // Delegate to ToolRouteManager + self.tool_route_manager + .update_router_tool_selector(provider, reindex_all, &self.extension_manager) + .await } /// Override the system prompt with a custom template diff --git a/crates/goose/src/agents/mod.rs b/crates/goose/src/agents/mod.rs index f272a5039d77..9cd782f3522e 100644 --- a/crates/goose/src/agents/mod.rs +++ b/crates/goose/src/agents/mod.rs @@ -18,6 +18,7 @@ pub mod subagent_execution_tool; pub mod subagent_handler; mod subagent_task_config; mod tool_execution; +mod tool_route_manager; mod tool_router_index_manager; pub(crate) mod tool_vectordb; pub mod types; diff --git a/crates/goose/src/agents/reply_parts.rs b/crates/goose/src/agents/reply_parts.rs index 61452c01eac2..46c134e0306e 100644 --- a/crates/goose/src/agents/reply_parts.rs +++ b/crates/goose/src/agents/reply_parts.rs @@ -6,7 +6,6 @@ use async_stream::try_stream; use futures::stream::StreamExt; use crate::agents::router_tool_selector::RouterToolSelectionStrategy; -use crate::config::Config; use crate::message::{Message, MessageContent, ToolRequest}; use crate::providers::base::{stream_from_single_message, MessageStream, Provider, ProviderUsage}; use crate::providers::errors::ProviderError; @@ -36,16 +35,10 @@ impl Agent { /// Prepares tools and system prompt for a provider request pub async fn prepare_tools_and_prompt(&self) -> anyhow::Result<(Vec, Vec, String)> { // Get tool selection strategy from config - let config = Config::global(); - let router_tool_selection_strategy = config - .get_param("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY") - .unwrap_or_else(|_| "default".to_string()); - - let tool_selection_strategy = match router_tool_selection_strategy.to_lowercase().as_str() { - "vector" => Some(RouterToolSelectionStrategy::Vector), - "llm" => Some(RouterToolSelectionStrategy::Llm), - _ => None, - }; + let tool_selection_strategy = self + .tool_route_manager + .get_router_tool_selection_strategy() + .await; // Get tools from extension manager let mut tools = match tool_selection_strategy { diff --git a/crates/goose/src/agents/tool_route_manager.rs b/crates/goose/src/agents/tool_route_manager.rs new file mode 100644 index 000000000000..08a77157aa13 --- /dev/null +++ b/crates/goose/src/agents/tool_route_manager.rs @@ -0,0 +1,187 @@ +use crate::agents::extension_manager::ExtensionManager; +use crate::agents::router_tool_selector::{ + create_tool_selector, RouterToolSelectionStrategy, RouterToolSelector, +}; +use crate::agents::router_tools::{self}; +use crate::agents::tool_execution::ToolCallResult; +use crate::agents::tool_router_index_manager::ToolRouterIndexManager; +use crate::agents::tool_vectordb::generate_table_id; +use crate::config::Config; +use crate::message::ToolRequest; +use crate::providers::base::Provider; +use anyhow::{anyhow, Result}; +use mcp_core::ToolError; +use rmcp::model::Tool; +use serde_json::Value; +use std::sync::Arc; +use tokio::sync::Mutex; +use tokio::sync::RwLock; +use tracing::error; + +pub struct ToolRouteManager { + router_tool_selector: Mutex>>>, + router_disabled_override: Mutex, +} + +impl ToolRouteManager { + pub fn new() -> Self { + Self { + router_tool_selector: Mutex::new(None), + router_disabled_override: Mutex::new(false), + } + } + + pub async fn disable_router_for_recipe(&self) { + *self.router_disabled_override.lock().await = true; + *self.router_tool_selector.lock().await = None; + } + + pub async fn record_tool_requests(&self, requests: &[ToolRequest]) { + let selector = self.router_tool_selector.lock().await.clone(); + if let Some(selector) = selector { + for request in requests { + if let Ok(tool_call) = &request.tool_call { + if let Err(e) = selector.record_tool_call(&tool_call.name).await { + error!("Failed to record tool call: {}", e); + } + } + } + } + } + + pub async fn dispatch_route_search_tool( + &self, + arguments: Value, + ) -> Result { + let selector = self.router_tool_selector.lock().await.clone(); + match selector.as_ref() { + Some(selector) => match selector.select_tools(arguments).await { + Ok(tools) => Ok(ToolCallResult::from(Ok(tools))), + Err(e) => Err(ToolError::ExecutionError(format!( + "Failed to select tools: {}", + e + ))), + }, + None => Err(ToolError::ExecutionError( + "No tool selector available".to_string(), + )), + } + } + + pub async fn get_router_tool_selection_strategy(&self) -> Option { + if *self.router_disabled_override.lock().await { + return None; + } + + let config = Config::global(); + let router_tool_selection_strategy = config + .get_param("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY") + .unwrap_or_else(|_| "default".to_string()); + + match router_tool_selection_strategy.to_lowercase().as_str() { + "vector" => Some(RouterToolSelectionStrategy::Vector), + "llm" => Some(RouterToolSelectionStrategy::Llm), + _ => None, + } + } + + pub async fn update_router_tool_selector( + &self, + provider: Arc, + reindex_all: Option, + extension_manager: &Arc>, + ) -> Result<()> { + let strategy = self.get_router_tool_selection_strategy().await; + let selector = match strategy { + Some(RouterToolSelectionStrategy::Vector) => { + let table_name = generate_table_id(); + let selector = create_tool_selector(strategy, provider.clone(), Some(table_name)) + .await + .map_err(|e| anyhow!("Failed to create tool selector: {}", e))?; + Arc::new(selector) + } + Some(RouterToolSelectionStrategy::Llm) => { + let selector = create_tool_selector(strategy, provider.clone(), None) + .await + .map_err(|e| anyhow!("Failed to create tool selector: {}", e))?; + Arc::new(selector) + } + None => return Ok(()), + }; + + // First index platform tools + let extension_manager = extension_manager.read().await; + ToolRouterIndexManager::index_platform_tools(&selector, &extension_manager).await?; + + if reindex_all.unwrap_or(false) { + let enabled_extensions = extension_manager.list_extensions().await?; + for extension_name in enabled_extensions { + if let Err(e) = ToolRouterIndexManager::update_extension_tools( + &selector, + &extension_manager, + &extension_name, + "add", + ) + .await + { + error!( + "Failed to index tools for extension {}: {}", + extension_name, e + ); + } + } + } + + // Update the selector + *self.router_tool_selector.lock().await = Some(selector.clone()); + + Ok(()) + } + + pub async fn get_router_tool_selector(&self) -> Option>> { + self.router_tool_selector.lock().await.clone() + } + + pub async fn list_tools_for_router( + &self, + strategy: Option, + extension_manager: &Arc>, + ) -> Vec { + if *self.router_disabled_override.lock().await { + return vec![]; + } + + let mut prefixed_tools = vec![]; + match strategy { + Some(RouterToolSelectionStrategy::Vector) => { + prefixed_tools.push(router_tools::vector_search_tool()); + } + Some(RouterToolSelectionStrategy::Llm) => { + prefixed_tools.push(router_tools::llm_search_tool()); + } + None => {} + } + + // Get recent tool calls from router tool selector if available + let selector = self.router_tool_selector.lock().await.clone(); + if let Some(selector) = selector { + if let Ok(recent_calls) = selector.get_recent_tool_calls(20).await { + let extension_manager = extension_manager.read().await; + // Add recent tool calls to the list, avoiding duplicates + for tool_name in recent_calls { + // Find the tool in the extension manager's tools + if let Ok(extension_tools) = extension_manager.get_prefixed_tools(None).await { + if let Some(tool) = extension_tools.iter().find(|t| t.name == tool_name) { + // Only add if not already in prefixed_tools + if !prefixed_tools.iter().any(|t| t.name == tool.name) { + prefixed_tools.push(tool.clone()); + } + } + } + } + } + } + + prefixed_tools + } +}