-
Notifications
You must be signed in to change notification settings - Fork 2.4k
chord: refactor tool route #3732
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
8c875ef
b5db1e4
ec5b207
b6dbf5e
79e398e
68121a2
0be1476
9155263
e002ad2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,17 +21,15 @@ use crate::agents::recipe_tools::dynamic_task_tools::{ | |
| create_dynamic_task, create_dynamic_task_tool, DYNAMIC_TASK_TOOL_NAME_PREFIX, | ||
| }; | ||
| use crate::agents::retry::{RetryManager, RetryResult}; | ||
| use crate::agents::router_tool_selector::{ | ||
| create_tool_selector, RouterToolSelectionStrategy, RouterToolSelector, | ||
| }; | ||
| use crate::agents::router_tool_selector::RouterToolSelectionStrategy; | ||
| use crate::agents::router_tools::{ROUTER_LLM_SEARCH_TOOL_NAME, ROUTER_VECTOR_SEARCH_TOOL_NAME}; | ||
| use crate::agents::sub_recipe_manager::SubRecipeManager; | ||
| use crate::agents::subagent_execution_tool::subagent_execute_task_tool::{ | ||
| self, SUBAGENT_EXECUTE_TASK_TOOL_NAME, | ||
| }; | ||
| use crate::agents::subagent_execution_tool::tasks_manager::TasksManager; | ||
| use crate::agents::tool_route_manager::ToolRouteManager; | ||
| use crate::agents::tool_router_index_manager::ToolRouterIndexManager; | ||
| use crate::agents::tool_vectordb::generate_table_id; | ||
| use crate::agents::types::SessionConfig; | ||
| use crate::agents::types::{FrontendTool, ToolResultReceiver}; | ||
| use crate::config::{Config, ExtensionConfigManager, PermissionManager}; | ||
|
|
@@ -54,7 +52,6 @@ use tracing::{debug, error, info, instrument}; | |
|
|
||
| use super::final_output_tool::FinalOutputTool; | ||
| use super::platform_tools; | ||
| use super::router_tools; | ||
| use super::tool_execution::{ToolCallResult, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE}; | ||
| use crate::agents::subagent_task_config::TaskConfig; | ||
| use crate::conversation_fixer::{debug_conversation_fix, ConversationFixer}; | ||
|
|
@@ -72,8 +69,7 @@ pub struct ReplyContext { | |
| pub config: &'static Config, | ||
| } | ||
|
|
||
| /// Result of processing tool requests | ||
| pub struct ToolProcessingResult { | ||
| pub struct ToolCategorizeResult { | ||
| pub frontend_requests: Vec<ToolRequest>, | ||
| pub remaining_requests: Vec<ToolRequest>, | ||
| pub filtered_response: Message, | ||
|
|
@@ -96,7 +92,7 @@ pub struct Agent { | |
| pub(super) tool_result_tx: mpsc::Sender<(String, ToolResult<Vec<Content>>)>, | ||
| pub(super) tool_result_rx: ToolResultReceiver, | ||
| pub(super) tool_monitor: Arc<Mutex<Option<ToolMonitor>>>, | ||
| pub(super) router_tool_selector: Mutex<Option<Arc<Box<dyn RouterToolSelector>>>>, | ||
| pub(super) tool_route_manager: ToolRouteManager, | ||
| pub(super) scheduler_service: Mutex<Option<Arc<dyn SchedulerTrait>>>, | ||
| pub(super) retry_manager: RetryManager, | ||
| } | ||
|
|
@@ -171,7 +167,7 @@ impl Agent { | |
| tool_result_tx: tool_tx, | ||
| tool_result_rx: Arc::new(Mutex::new(tool_rx)), | ||
| tool_monitor, | ||
| router_tool_selector: Mutex::new(None), | ||
| tool_route_manager: ToolRouteManager::new(), | ||
| scheduler_service: Mutex::new(None), | ||
| retry_manager, | ||
| } | ||
|
|
@@ -246,38 +242,18 @@ impl Agent { | |
| }) | ||
| } | ||
|
|
||
| /// Process tool requests by categorizing them and recording them in the router selector | ||
| async fn process_tool_requests( | ||
| async fn categorize_tools( | ||
| &self, | ||
| response: &Message, | ||
| tools: &[rmcp::model::Tool], | ||
| ) -> ToolProcessingResult { | ||
| ) -> ToolCategorizeResult { | ||
| let (readonly_tools, regular_tools) = Self::categorize_tools_by_annotation(tools); | ||
|
|
||
| // Categorize tool requests | ||
| let (frontend_requests, remaining_requests, filtered_response) = | ||
| self.categorize_tool_requests(response).await; | ||
|
|
||
| // Record tool calls in the router selector | ||
| let selector = self.router_tool_selector.lock().await.clone(); | ||
| if let Some(selector) = selector { | ||
| for request in &frontend_requests { | ||
| if let Ok(tool_call) = &request.tool_call { | ||
| if let Err(e) = selector.record_tool_call(&tool_call.name).await { | ||
| error!("Failed to record frontend tool call: {}", e); | ||
| } | ||
| } | ||
| } | ||
| for request in &remaining_requests { | ||
| if let Ok(tool_call) = &request.tool_call { | ||
| if let Err(e) = selector.record_tool_call(&tool_call.name).await { | ||
| error!("Failed to record tool call: {}", e); | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| ToolProcessingResult { | ||
| ToolCategorizeResult { | ||
| frontend_requests, | ||
| remaining_requests, | ||
| filtered_response, | ||
|
|
@@ -336,6 +312,10 @@ impl Agent { | |
| *scheduler_service = Some(scheduler); | ||
| } | ||
|
|
||
| pub async fn disable_router_for_recipe(&self) { | ||
| self.tool_route_manager.disable_router_for_recipe().await; | ||
| } | ||
|
|
||
| /// Get a reference count clone to the provider | ||
| pub async fn provider(&self) -> Result<Arc<dyn Provider>, anyhow::Error> { | ||
| match &*self.provider.lock().await { | ||
|
|
@@ -476,43 +456,14 @@ impl Agent { | |
| } else if tool_call.name == ROUTER_VECTOR_SEARCH_TOOL_NAME | ||
| || tool_call.name == ROUTER_LLM_SEARCH_TOOL_NAME | ||
| { | ||
| let selector = self.router_tool_selector.lock().await.clone(); | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Moved the logic to tool_route_manager |
||
| let mut selected_tools = match selector.as_ref() { | ||
| Some(selector) => match selector.select_tools(tool_call.arguments.clone()).await { | ||
| Ok(tools) => tools, | ||
| Err(e) => { | ||
| return ( | ||
| request_id, | ||
| Err(ToolError::ExecutionError(format!( | ||
| "Failed to select tools: {}", | ||
| e | ||
| ))), | ||
| ) | ||
| } | ||
| }, | ||
| None => { | ||
| return ( | ||
| request_id, | ||
| Err(ToolError::ExecutionError( | ||
| "No tool selector available".to_string(), | ||
| )), | ||
| ) | ||
| } | ||
| }; | ||
|
|
||
| // Append final_output tool if present (for structured output recipes, [Issue #3700](https://github.com/block/goose/issues/3700) | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is no longer needed as we used disabled the vector or llm strategy for recipe |
||
| if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() { | ||
| let tool = final_output_tool.tool(); | ||
| let tool_content = Content::text(format!( | ||
| "Tool: {}\nDescription: {}\nSchema: {}", | ||
| tool.name, | ||
| tool.description.unwrap_or_default(), | ||
| serde_json::to_string_pretty(&tool.input_schema).unwrap_or_default() | ||
| )); | ||
| selected_tools.push(tool_content); | ||
| match self | ||
| .tool_route_manager | ||
| .dispatch_route_search_tool(tool_call.arguments) | ||
| .await | ||
| { | ||
| Ok(tool_result) => tool_result, | ||
| Err(e) => return (request_id, Err(e)), | ||
| } | ||
|
|
||
| ToolCallResult::from(Ok(selected_tools)) | ||
| } else { | ||
| // Clone the result to ensure no references to extension_manager are returned | ||
| let result = extension_manager | ||
|
|
@@ -542,9 +493,7 @@ impl Agent { | |
| extension_name: String, | ||
| request_id: String, | ||
| ) -> (String, Result<Vec<Content>, ToolError>) { | ||
| let mut extension_manager = self.extension_manager.write().await; | ||
|
|
||
| let selector = self.router_tool_selector.lock().await.clone(); | ||
| let selector = self.tool_route_manager.get_router_tool_selector().await; | ||
| if ToolRouterIndexManager::is_tool_router_enabled(&selector) { | ||
| if let Some(selector) = selector { | ||
| let selector_action = if action == "disable" { "remove" } else { "add" }; | ||
|
|
@@ -568,7 +517,7 @@ impl Agent { | |
| } | ||
| } | ||
| } | ||
|
|
||
| let mut extension_manager = self.extension_manager.write().await; | ||
| if action == "disable" { | ||
| let result = extension_manager | ||
| .remove_extension(&extension_name) | ||
|
|
@@ -604,7 +553,6 @@ impl Agent { | |
| ) | ||
| } | ||
| }; | ||
|
|
||
| let result = extension_manager | ||
| .add_extension(config) | ||
| .await | ||
|
|
@@ -616,9 +564,10 @@ impl Agent { | |
| }) | ||
| .map_err(|e| ToolError::ExecutionError(e.to_string())); | ||
|
|
||
| drop(extension_manager); | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. need to release the lock as later on it uses read lock
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice |
||
| // Update vector index if operation was successful and vector routing is enabled | ||
| if result.is_ok() { | ||
| let selector = self.router_tool_selector.lock().await.clone(); | ||
| let selector = self.tool_route_manager.get_router_tool_selector().await; | ||
| if ToolRouterIndexManager::is_tool_router_enabled(&selector) { | ||
| if let Some(selector) = selector { | ||
| let vector_action = if action == "disable" { "remove" } else { "add" }; | ||
|
|
@@ -681,7 +630,7 @@ impl Agent { | |
| } | ||
|
|
||
| // If vector tool selection is enabled, index the tools | ||
| let selector = self.router_tool_selector.lock().await.clone(); | ||
| let selector = self.tool_route_manager.get_router_tool_selector().await; | ||
| if ToolRouterIndexManager::is_tool_router_enabled(&selector) { | ||
| if let Some(selector) = selector { | ||
| let extension_manager = self.extension_manager.read().await; | ||
|
|
@@ -750,46 +699,18 @@ impl Agent { | |
| &self, | ||
| strategy: Option<RouterToolSelectionStrategy>, | ||
| ) -> Vec<Tool> { | ||
| let mut prefixed_tools = vec![]; | ||
| match strategy { | ||
| Some(RouterToolSelectionStrategy::Vector) => { | ||
| prefixed_tools.push(router_tools::vector_search_tool()); | ||
| } | ||
| Some(RouterToolSelectionStrategy::Llm) => { | ||
| prefixed_tools.push(router_tools::llm_search_tool()); | ||
| } | ||
| None => {} | ||
| } | ||
|
|
||
| // Get recent tool calls from router tool selector if available | ||
| let selector = self.router_tool_selector.lock().await.clone(); | ||
| if let Some(selector) = selector { | ||
| if let Ok(recent_calls) = selector.get_recent_tool_calls(20).await { | ||
| let extension_manager = self.extension_manager.read().await; | ||
| // Add recent tool calls to the list, avoiding duplicates | ||
| for tool_name in recent_calls { | ||
| // Find the tool in the extension manager's tools | ||
| if let Ok(extension_tools) = extension_manager.get_prefixed_tools(None).await { | ||
| if let Some(tool) = extension_tools.iter().find(|t| t.name == tool_name) { | ||
| // Only add if not already in prefixed_tools | ||
| if !prefixed_tools.iter().any(|t| t.name == tool.name) { | ||
| prefixed_tools.push(tool.clone()); | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| prefixed_tools | ||
| self.tool_route_manager | ||
| .list_tools_for_router(strategy, &self.extension_manager) | ||
| .await | ||
| } | ||
|
|
||
| pub async fn remove_extension(&self, name: &str) -> Result<()> { | ||
| let mut extension_manager = self.extension_manager.write().await; | ||
| extension_manager.remove_extension(name).await?; | ||
| drop(extension_manager); | ||
|
|
||
| // If vector tool selection is enabled, remove tools from the index | ||
| let selector = self.router_tool_selector.lock().await.clone(); | ||
| let selector = self.tool_route_manager.get_router_tool_selector().await; | ||
| if ToolRouterIndexManager::is_tool_router_enabled(&selector) { | ||
| if let Some(selector) = selector { | ||
| let extension_manager = self.extension_manager.read().await; | ||
|
|
@@ -938,14 +859,17 @@ impl Agent { | |
| } | ||
|
|
||
| if let Some(response) = response { | ||
| let tool_result = self.process_tool_requests(&response, &tools).await; | ||
| let ToolProcessingResult { | ||
| let ToolCategorizeResult { | ||
| frontend_requests, | ||
| remaining_requests, | ||
| filtered_response, | ||
| readonly_tools, | ||
| regular_tools, | ||
| } = tool_result; | ||
| } = self.categorize_tools(&response, &tools).await; | ||
| let requests_to_record: Vec<ToolRequest> = frontend_requests.iter().chain(remaining_requests.iter()).cloned().collect(); | ||
| self.tool_route_manager | ||
| .record_tool_requests(&requests_to_record) | ||
| .await; | ||
|
|
||
| yield AgentEvent::Message(filtered_response.clone()); | ||
| tokio::task::yield_now().await; | ||
|
|
@@ -1151,67 +1075,15 @@ impl Agent { | |
| provider: Option<Arc<dyn Provider>>, | ||
| reindex_all: Option<bool>, | ||
| ) -> Result<()> { | ||
| let config = Config::global(); | ||
| let _extension_manager = self.extension_manager.read().await; | ||
| let provider = match provider { | ||
| Some(p) => p, | ||
| None => self.provider().await?, | ||
| }; | ||
|
|
||
| let router_tool_selection_strategy = config | ||
| .get_param("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY") | ||
| .unwrap_or_else(|_| "default".to_string()); | ||
|
|
||
| let strategy = match router_tool_selection_strategy.to_lowercase().as_str() { | ||
| "vector" => Some(RouterToolSelectionStrategy::Vector), | ||
| "llm" => Some(RouterToolSelectionStrategy::Llm), | ||
| _ => None, | ||
| }; | ||
|
|
||
| let selector = match strategy { | ||
| Some(RouterToolSelectionStrategy::Vector) => { | ||
| let table_name = generate_table_id(); | ||
| let selector = create_tool_selector(strategy, provider.clone(), Some(table_name)) | ||
| .await | ||
| .map_err(|e| anyhow!("Failed to create tool selector: {}", e))?; | ||
| Arc::new(selector) | ||
| } | ||
| Some(RouterToolSelectionStrategy::Llm) => { | ||
| let selector = create_tool_selector(strategy, provider.clone(), None) | ||
| .await | ||
| .map_err(|e| anyhow!("Failed to create tool selector: {}", e))?; | ||
| Arc::new(selector) | ||
| } | ||
| None => return Ok(()), | ||
| }; | ||
|
|
||
| // First index platform tools | ||
| let extension_manager = self.extension_manager.read().await; | ||
| ToolRouterIndexManager::index_platform_tools(&selector, &extension_manager).await?; | ||
|
|
||
| if reindex_all.unwrap_or(false) { | ||
| let enabled_extensions = extension_manager.list_extensions().await?; | ||
| for extension_name in enabled_extensions { | ||
| if let Err(e) = ToolRouterIndexManager::update_extension_tools( | ||
| &selector, | ||
| &extension_manager, | ||
| &extension_name, | ||
| "add", | ||
| ) | ||
| .await | ||
| { | ||
| error!( | ||
| "Failed to index tools for extension {}: {}", | ||
| extension_name, e | ||
| ); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // Update the selector | ||
| *self.router_tool_selector.lock().await = Some(selector.clone()); | ||
|
|
||
| Ok(()) | ||
| // Delegate to ToolRouteManager | ||
| self.tool_route_manager | ||
| .update_router_tool_selector(provider, reindex_all, &self.extension_manager) | ||
| .await | ||
| } | ||
|
|
||
| /// Override the system prompt with a custom template | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
moved this logic outside of this function