From 6561dd0611605ea7d45514450cc2a0c819b3fc84 Mon Sep 17 00:00:00 2001 From: Wendy Tang Date: Tue, 10 Jun 2025 17:30:32 -0700 Subject: [PATCH] Improve extension filtering and tool router functionality - Enhanced router tool selector with better extension filtering - Updated tool vector database with improved indexing - Modified agent and router tools for better tool selection - Updated desktop OpenAPI specification --- crates/goose/src/agents/agent.rs | 30 +++++++-- .../goose/src/agents/router_tool_selector.rs | 10 ++- crates/goose/src/agents/router_tools.rs | 14 ++-- .../src/agents/tool_router_index_manager.rs | 35 ++++++---- crates/goose/src/agents/tool_vectordb.rs | 67 +++++++++++++++++-- ui/desktop/openapi.json | 2 +- 6 files changed, 120 insertions(+), 38 deletions(-) diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 991603007c87..9ab68c91b2e2 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -246,13 +246,29 @@ impl Agent { ))) } else if tool_call.name == ROUTER_VECTOR_SEARCH_TOOL_NAME { let selector = self.router_tool_selector.lock().await.clone(); - ToolCallResult::from(if let Some(selector) = selector { - selector.select_tools(tool_call.arguments.clone()).await - } else { - Err(ToolError::ExecutionError( - "Encountered vector search error.".to_string(), - )) - }) + let selected_tools = match selector.as_ref() { + Some(selector) => match selector.select_tools(tool_call.arguments.clone()).await { + Ok(tools) => tools, + Err(e) => { + return ( + request_id, + Err(ToolError::ExecutionError(format!( + "Failed to select tools: {}", + e + ))), + ) + } + }, + None => { + return ( + request_id, + Err(ToolError::ExecutionError( + "No tool selector available".to_string(), + )), + ) + } + }; + ToolCallResult::from(Ok(selected_tools)) } else { // Clone the result to ensure no references to extension_manager are returned let result = extension_manager diff --git a/crates/goose/src/agents/router_tool_selector.rs b/crates/goose/src/agents/router_tool_selector.rs index 316197ea37b2..5ca570f89c28 100644 --- a/crates/goose/src/agents/router_tool_selector.rs +++ b/crates/goose/src/agents/router_tool_selector.rs @@ -22,7 +22,7 @@ pub enum RouterToolSelectionStrategy { #[async_trait] pub trait RouterToolSelector: Send + Sync { async fn select_tools(&self, params: Value) -> Result, ToolError>; - async fn index_tools(&self, tools: &[Tool]) -> Result<(), ToolError>; + async fn index_tools(&self, tools: &[Tool], extension_name: &str) -> Result<(), ToolError>; async fn remove_tool(&self, tool_name: &str) -> Result<(), ToolError>; async fn record_tool_call(&self, tool_name: &str) -> Result<(), ToolError>; async fn get_recent_tool_calls(&self, limit: usize) -> Result, ToolError>; @@ -76,6 +76,9 @@ impl RouterToolSelector for VectorToolSelector { let k = params.get("k").and_then(|v| v.as_u64()).unwrap_or(5) as usize; + // Extract extension_name from params if present + let extension_name = params.get("extension_name").and_then(|v| v.as_str()); + // Check if provider supports embeddings if !self.embedding_provider.supports_embeddings() { return Err(ToolError::ExecutionError( @@ -98,7 +101,7 @@ impl RouterToolSelector for VectorToolSelector { let vector_db = self.vector_db.read().await; let tools = vector_db - .search_tools(query_embedding, k) + .search_tools(query_embedding, k, extension_name) .await .map_err(|e| ToolError::ExecutionError(format!("Failed to search tools: {}", e)))?; @@ -119,7 +122,7 @@ impl RouterToolSelector for VectorToolSelector { Ok(selected_tools) } - async fn index_tools(&self, tools: &[Tool]) -> Result<(), ToolError> { + async fn index_tools(&self, tools: &[Tool], extension_name: &str) -> Result<(), ToolError> { let texts_to_embed: Vec = tools .iter() .map(|tool| { @@ -155,6 +158,7 @@ impl RouterToolSelector for VectorToolSelector { description: tool.description.clone(), schema: schema_str, vector, + extension_name: extension_name.to_string(), } }) .collect(); diff --git a/crates/goose/src/agents/router_tools.rs b/crates/goose/src/agents/router_tools.rs index e9ec9a96ceeb..8be668002f3f 100644 --- a/crates/goose/src/agents/router_tools.rs +++ b/crates/goose/src/agents/router_tools.rs @@ -12,18 +12,20 @@ pub fn vector_search_tool() -> Tool { Format a query to search for the most relevant tools based on the user's messages. Pay attention to the keywords in the user's messages, especially the last message and potential tools they are asking for. This tool should be invoked when the user's messages suggest they are asking for a tool to be run. - Examples: - - {"User": "what is the weather in Tokyo?", "Query": "weather in Tokyo"} - - {"User": "read this pdf file for me", "Query": "read pdf file"} - - {"User": "run this command ls -l in the terminal", "Query": "run command in terminal ls -l"} + You have the list of extension names available to you in your system prompt. + Use the extension_name parameter to filter tools by the appropriate extension. + For example, if the user is asking to list the files in the current directory, you filter for the "developer" extension. + Example: {"User": "list the files in the current directory", "Query": "list files in current directory", "Extension Name": "developer", "k": 5} + Extension name is not optional, it is required. "#} .to_string(), json!({ "type": "object", - "required": ["query"], + "required": ["query", "extension_name"], "properties": { "query": {"type": "string", "description": "The query to search for the most relevant tools based on the user's messages"}, - "k": {"type": "integer", "description": "The number of tools to retrieve (defaults to 5)", "default": 5} + "k": {"type": "integer", "description": "The number of tools to retrieve (defaults to 5)", "default": 5}, + "extension_name": {"type": "string", "description": "Name of the extension to filter tools by"} } }), Some(ToolAnnotations { diff --git a/crates/goose/src/agents/tool_router_index_manager.rs b/crates/goose/src/agents/tool_router_index_manager.rs index cdee03263ae1..a0fb0d63e4af 100644 --- a/crates/goose/src/agents/tool_router_index_manager.rs +++ b/crates/goose/src/agents/tool_router_index_manager.rs @@ -26,13 +26,16 @@ impl ToolRouterIndexManager { if !tools.is_empty() { // Index all tools at once - selector.index_tools(&tools).await.map_err(|e| { - anyhow!( - "Failed to index tools for extension {}: {}", - extension_name, - e - ) - })?; + selector + .index_tools(&tools, extension_name) + .await + .map_err(|e| { + anyhow!( + "Failed to index tools for extension {}: {}", + extension_name, + e + ) + })?; tracing::info!( "Indexed {} tools for extension {}", @@ -42,16 +45,20 @@ impl ToolRouterIndexManager { } } "remove" => { - // Get tool names for the extension to remove them + // Remove all tools for this extension let tools = extension_manager .get_prefixed_tools(Some(extension_name.to_string())) .await?; for tool in &tools { - selector - .remove_tool(&tool.name) - .await - .map_err(|e| anyhow!("Failed to remove tool {}: {}", tool.name, e))?; + selector.remove_tool(&tool.name).await.map_err(|e| { + anyhow!( + "Failed to remove tool {} for extension {}: {}", + tool.name, + extension_name, + e + ) + })?; } tracing::info!( @@ -61,7 +68,7 @@ impl ToolRouterIndexManager { ); } _ => { - anyhow::bail!("Invalid action '{}' for tool indexing", action); + return Err(anyhow!("Invalid action: {}", action)); } } @@ -87,7 +94,7 @@ impl ToolRouterIndexManager { // Index all platform tools at once selector - .index_tools(&tools) + .index_tools(&tools, "platform") .await .map_err(|e| anyhow!("Failed to index platform tools: {}", e))?; diff --git a/crates/goose/src/agents/tool_vectordb.rs b/crates/goose/src/agents/tool_vectordb.rs index e46c13b41542..293360e234c8 100644 --- a/crates/goose/src/agents/tool_vectordb.rs +++ b/crates/goose/src/agents/tool_vectordb.rs @@ -18,6 +18,7 @@ pub struct ToolRecord { pub description: String, pub schema: String, pub vector: Vec, + pub extension_name: String, } pub struct ToolVectorDB { @@ -84,12 +85,14 @@ impl ToolVectorDB { ), false, ), + Field::new("extension_name", DataType::Utf8, false), ])); // Create empty table let tool_names = StringArray::from(vec![] as Vec<&str>); let descriptions = StringArray::from(vec![] as Vec<&str>); let schemas = StringArray::from(vec![] as Vec<&str>); + let extension_names = StringArray::from(vec![] as Vec<&str>); // Create empty fixed size list array for vectors let mut vectors_builder = @@ -103,6 +106,7 @@ impl ToolVectorDB { Arc::new(descriptions), Arc::new(schemas), Arc::new(vectors), + Arc::new(extension_names), ], ) .context("Failed to create record batch")?; @@ -163,6 +167,7 @@ impl ToolVectorDB { let tool_names: Vec<&str> = tools.iter().map(|t| t.tool_name.as_str()).collect(); let descriptions: Vec<&str> = tools.iter().map(|t| t.description.as_str()).collect(); let schemas: Vec<&str> = tools.iter().map(|t| t.schema.as_str()).collect(); + let extension_names: Vec<&str> = tools.iter().map(|t| t.extension_name.as_str()).collect(); let vectors_data: Vec>>> = tools .iter() @@ -181,11 +186,13 @@ impl ToolVectorDB { ), false, ), + Field::new("extension_name", DataType::Utf8, false), ])); let tool_names_array = StringArray::from(tool_names); let descriptions_array = StringArray::from(descriptions); let schemas_array = StringArray::from(schemas); + let extension_names_array = StringArray::from(extension_names); // Build vectors array let mut vectors_builder = FixedSizeListBuilder::new(arrow::array::Float32Builder::new(), 1536); @@ -213,6 +220,7 @@ impl ToolVectorDB { Arc::new(descriptions_array), Arc::new(schemas_array), Arc::new(vectors_array), + Arc::new(extension_names_array), ], ) .context("Failed to create record batch")?; @@ -239,7 +247,12 @@ impl ToolVectorDB { Ok(()) } - pub async fn search_tools(&self, query_vector: Vec, k: usize) -> Result> { + pub async fn search_tools( + &self, + query_vector: Vec, + k: usize, + extension_name: Option<&str>, + ) -> Result> { let connection = self.connection.read().await; let table = connection @@ -248,9 +261,11 @@ impl ToolVectorDB { .await .context("Failed to open tools table")?; - let results = table + let search = table .vector_search(query_vector) - .context("Failed to create vector search")? + .context("Failed to create vector search")?; + + let results = search .limit(k) .execute() .await @@ -281,6 +296,13 @@ impl ToolVectorDB { .downcast_ref::() .context("Invalid schema column type")?; + let extension_names = batch + .column_by_name("extension_name") + .context("Missing extension_name column")? + .as_any() + .downcast_ref::() + .context("Invalid extension_name column type")?; + // Get the distance scores let distances = batch .column_by_name("_distance") @@ -292,12 +314,21 @@ impl ToolVectorDB { for i in 0..batch.num_rows() { let tool_name = tool_names.value(i).to_string(); let _distance = distances.value(i); + let ext_name = extension_names.value(i).to_string(); + + // Filter by extension name if provided + if let Some(filter_ext) = extension_name { + if ext_name != filter_ext { + continue; + } + } tools.push(ToolRecord { tool_name, description: descriptions.value(i).to_string(), schema: schemas.value(i).to_string(), vector: vec![], // We don't need to return the vector + extension_name: ext_name, }); } } @@ -356,6 +387,7 @@ mod tests { schema: r#"{"type": "object", "properties": {"path": {"type": "string"}}}"# .to_string(), vector: vec![0.1; 1536], // Mock embedding vector + extension_name: "test_extension".to_string(), }, ToolRecord { tool_name: "test_tool_2".to_string(), @@ -363,6 +395,7 @@ mod tests { schema: r#"{"type": "object", "properties": {"path": {"type": "string"}}}"# .to_string(), vector: vec![0.2; 1536], // Different mock embedding vector + extension_name: "test_extension".to_string(), }, ]; @@ -371,7 +404,7 @@ mod tests { // Search for tools using a query vector similar to test_tool_1 let query_vector = vec![0.1; 1536]; - let results = db.search_tools(query_vector, 2).await?; + let results = db.search_tools(query_vector.clone(), 2, None).await?; // Verify results assert_eq!(results.len(), 2, "Should find both tools"); @@ -384,6 +417,25 @@ mod tests { "Second result should be test_tool_2" ); + // Test filtering by extension name + let results = db + .search_tools(query_vector.clone(), 2, Some("test_extension")) + .await?; + assert_eq!( + results.len(), + 2, + "Should find both tools with test_extension" + ); + + let results = db + .search_tools(query_vector.clone(), 2, Some("nonexistent_extension")) + .await?; + assert_eq!( + results.len(), + 0, + "Should find no tools with nonexistent_extension" + ); + Ok(()) } @@ -397,7 +449,7 @@ mod tests { // Search in empty database let query_vector = vec![0.1; 1536]; - let results = db.search_tools(query_vector, 2).await?; + let results = db.search_tools(query_vector, 2, None).await?; // Verify no results returned assert_eq!(results.len(), 0, "Empty database should return no results"); @@ -419,20 +471,21 @@ mod tests { description: "A test tool that will be deleted".to_string(), schema: r#"{"type": "object", "properties": {"path": {"type": "string"}}}"#.to_string(), vector: vec![0.1; 1536], + extension_name: "test_extension".to_string(), }; db.index_tools(vec![test_tool]).await?; // Verify tool exists let query_vector = vec![0.1; 1536]; - let results = db.search_tools(query_vector.clone(), 1).await?; + let results = db.search_tools(query_vector.clone(), 1, None).await?; assert_eq!(results.len(), 1, "Tool should exist before deletion"); // Delete the tool db.remove_tool("test_tool_to_delete").await?; // Verify tool is gone - let results = db.search_tools(query_vector.clone(), 1).await?; + let results = db.search_tools(query_vector.clone(), 1, None).await?; assert_eq!(results.len(), 0, "Tool should be deleted"); Ok(()) diff --git a/ui/desktop/openapi.json b/ui/desktop/openapi.json index 7ad0e79d4c86..ffd9dca2b427 100644 --- a/ui/desktop/openapi.json +++ b/ui/desktop/openapi.json @@ -10,7 +10,7 @@ "license": { "name": "Apache-2.0" }, - "version": "1.0.26" + "version": "1.0.27" }, "paths": { "/agent/tools": {