Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,13 +246,29 @@ impl Agent {
)))
} else if tool_call.name == ROUTER_VECTOR_SEARCH_TOOL_NAME {
let selector = self.router_tool_selector.lock().await.clone();
ToolCallResult::from(if let Some(selector) = selector {
selector.select_tools(tool_call.arguments.clone()).await
} else {
Err(ToolError::ExecutionError(
"Encountered vector search error.".to_string(),
))
})
let selected_tools = match selector.as_ref() {
Some(selector) => match selector.select_tools(tool_call.arguments.clone()).await {
Ok(tools) => tools,
Err(e) => {
return (
request_id,
Err(ToolError::ExecutionError(format!(
"Failed to select tools: {}",
e
))),
)
}
},
None => {
return (
request_id,
Err(ToolError::ExecutionError(
"No tool selector available".to_string(),
)),
)
}
};
ToolCallResult::from(Ok(selected_tools))
} else {
// Clone the result to ensure no references to extension_manager are returned
let result = extension_manager
Expand Down
10 changes: 7 additions & 3 deletions crates/goose/src/agents/router_tool_selector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pub enum RouterToolSelectionStrategy {
#[async_trait]
pub trait RouterToolSelector: Send + Sync {
async fn select_tools(&self, params: Value) -> Result<Vec<Content>, ToolError>;
async fn index_tools(&self, tools: &[Tool]) -> Result<(), ToolError>;
async fn index_tools(&self, tools: &[Tool], extension_name: &str) -> Result<(), ToolError>;
async fn remove_tool(&self, tool_name: &str) -> Result<(), ToolError>;
async fn record_tool_call(&self, tool_name: &str) -> Result<(), ToolError>;
async fn get_recent_tool_calls(&self, limit: usize) -> Result<Vec<String>, ToolError>;
Expand Down Expand Up @@ -76,6 +76,9 @@ impl RouterToolSelector for VectorToolSelector {

let k = params.get("k").and_then(|v| v.as_u64()).unwrap_or(5) as usize;

// Extract extension_name from params if present
let extension_name = params.get("extension_name").and_then(|v| v.as_str());

// Check if provider supports embeddings
if !self.embedding_provider.supports_embeddings() {
return Err(ToolError::ExecutionError(
Expand All @@ -98,7 +101,7 @@ impl RouterToolSelector for VectorToolSelector {

let vector_db = self.vector_db.read().await;
let tools = vector_db
.search_tools(query_embedding, k)
.search_tools(query_embedding, k, extension_name)
.await
.map_err(|e| ToolError::ExecutionError(format!("Failed to search tools: {}", e)))?;

Expand All @@ -119,7 +122,7 @@ impl RouterToolSelector for VectorToolSelector {
Ok(selected_tools)
}

async fn index_tools(&self, tools: &[Tool]) -> Result<(), ToolError> {
async fn index_tools(&self, tools: &[Tool], extension_name: &str) -> Result<(), ToolError> {
let texts_to_embed: Vec<String> = tools
.iter()
.map(|tool| {
Expand Down Expand Up @@ -155,6 +158,7 @@ impl RouterToolSelector for VectorToolSelector {
description: tool.description.clone(),
schema: schema_str,
vector,
extension_name: extension_name.to_string(),
}
})
.collect();
Expand Down
14 changes: 8 additions & 6 deletions crates/goose/src/agents/router_tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,20 @@ pub fn vector_search_tool() -> Tool {
Format a query to search for the most relevant tools based on the user's messages.
Pay attention to the keywords in the user's messages, especially the last message and potential tools they are asking for.
This tool should be invoked when the user's messages suggest they are asking for a tool to be run.
Examples:
- {"User": "what is the weather in Tokyo?", "Query": "weather in Tokyo"}
- {"User": "read this pdf file for me", "Query": "read pdf file"}
- {"User": "run this command ls -l in the terminal", "Query": "run command in terminal ls -l"}
You have the list of extension names available to you in your system prompt.
Use the extension_name parameter to filter tools by the appropriate extension.
For example, if the user is asking to list the files in the current directory, you filter for the "developer" extension.
Example: {"User": "list the files in the current directory", "Query": "list files in current directory", "Extension Name": "developer", "k": 5}
Extension name is not optional, it is required.
"#}
.to_string(),
json!({
"type": "object",
"required": ["query"],
"required": ["query", "extension_name"],
"properties": {
"query": {"type": "string", "description": "The query to search for the most relevant tools based on the user's messages"},
"k": {"type": "integer", "description": "The number of tools to retrieve (defaults to 5)", "default": 5}
"k": {"type": "integer", "description": "The number of tools to retrieve (defaults to 5)", "default": 5},
"extension_name": {"type": "string", "description": "Name of the extension to filter tools by"}
}
}),
Some(ToolAnnotations {
Expand Down
35 changes: 21 additions & 14 deletions crates/goose/src/agents/tool_router_index_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,16 @@ impl ToolRouterIndexManager {

if !tools.is_empty() {
// Index all tools at once
selector.index_tools(&tools).await.map_err(|e| {
anyhow!(
"Failed to index tools for extension {}: {}",
extension_name,
e
)
})?;
selector
.index_tools(&tools, extension_name)
.await
.map_err(|e| {
anyhow!(
"Failed to index tools for extension {}: {}",
extension_name,
e
)
})?;

tracing::info!(
"Indexed {} tools for extension {}",
Expand All @@ -42,16 +45,20 @@ impl ToolRouterIndexManager {
}
}
"remove" => {
// Get tool names for the extension to remove them
// Remove all tools for this extension
let tools = extension_manager
.get_prefixed_tools(Some(extension_name.to_string()))
.await?;

for tool in &tools {
selector
.remove_tool(&tool.name)
.await
.map_err(|e| anyhow!("Failed to remove tool {}: {}", tool.name, e))?;
selector.remove_tool(&tool.name).await.map_err(|e| {
anyhow!(
"Failed to remove tool {} for extension {}: {}",
tool.name,
extension_name,
e
)
})?;
}

tracing::info!(
Expand All @@ -61,7 +68,7 @@ impl ToolRouterIndexManager {
);
}
_ => {
anyhow::bail!("Invalid action '{}' for tool indexing", action);
return Err(anyhow!("Invalid action: {}", action));
}
}

Expand All @@ -87,7 +94,7 @@ impl ToolRouterIndexManager {

// Index all platform tools at once
selector
.index_tools(&tools)
.index_tools(&tools, "platform")
.await
.map_err(|e| anyhow!("Failed to index platform tools: {}", e))?;

Expand Down
67 changes: 60 additions & 7 deletions crates/goose/src/agents/tool_vectordb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub struct ToolRecord {
pub description: String,
pub schema: String,
pub vector: Vec<f32>,
pub extension_name: String,
}

pub struct ToolVectorDB {
Expand Down Expand Up @@ -84,12 +85,14 @@ impl ToolVectorDB {
),
false,
),
Field::new("extension_name", DataType::Utf8, false),
]));

// Create empty table
let tool_names = StringArray::from(vec![] as Vec<&str>);
let descriptions = StringArray::from(vec![] as Vec<&str>);
let schemas = StringArray::from(vec![] as Vec<&str>);
let extension_names = StringArray::from(vec![] as Vec<&str>);

// Create empty fixed size list array for vectors
let mut vectors_builder =
Expand All @@ -103,6 +106,7 @@ impl ToolVectorDB {
Arc::new(descriptions),
Arc::new(schemas),
Arc::new(vectors),
Arc::new(extension_names),
],
)
.context("Failed to create record batch")?;
Expand Down Expand Up @@ -163,6 +167,7 @@ impl ToolVectorDB {
let tool_names: Vec<&str> = tools.iter().map(|t| t.tool_name.as_str()).collect();
let descriptions: Vec<&str> = tools.iter().map(|t| t.description.as_str()).collect();
let schemas: Vec<&str> = tools.iter().map(|t| t.schema.as_str()).collect();
let extension_names: Vec<&str> = tools.iter().map(|t| t.extension_name.as_str()).collect();

let vectors_data: Vec<Option<Vec<Option<f32>>>> = tools
.iter()
Expand All @@ -181,11 +186,13 @@ impl ToolVectorDB {
),
false,
),
Field::new("extension_name", DataType::Utf8, false),
]));

let tool_names_array = StringArray::from(tool_names);
let descriptions_array = StringArray::from(descriptions);
let schemas_array = StringArray::from(schemas);
let extension_names_array = StringArray::from(extension_names);
// Build vectors array
let mut vectors_builder =
FixedSizeListBuilder::new(arrow::array::Float32Builder::new(), 1536);
Expand Down Expand Up @@ -213,6 +220,7 @@ impl ToolVectorDB {
Arc::new(descriptions_array),
Arc::new(schemas_array),
Arc::new(vectors_array),
Arc::new(extension_names_array),
],
)
.context("Failed to create record batch")?;
Expand All @@ -239,7 +247,12 @@ impl ToolVectorDB {
Ok(())
}

pub async fn search_tools(&self, query_vector: Vec<f32>, k: usize) -> Result<Vec<ToolRecord>> {
pub async fn search_tools(
&self,
query_vector: Vec<f32>,
k: usize,
extension_name: Option<&str>,
) -> Result<Vec<ToolRecord>> {
let connection = self.connection.read().await;

let table = connection
Expand All @@ -248,9 +261,11 @@ impl ToolVectorDB {
.await
.context("Failed to open tools table")?;

let results = table
let search = table
.vector_search(query_vector)
.context("Failed to create vector search")?
.context("Failed to create vector search")?;

let results = search
.limit(k)
.execute()
.await
Expand Down Expand Up @@ -281,6 +296,13 @@ impl ToolVectorDB {
.downcast_ref::<StringArray>()
.context("Invalid schema column type")?;

let extension_names = batch
.column_by_name("extension_name")
.context("Missing extension_name column")?
.as_any()
.downcast_ref::<StringArray>()
.context("Invalid extension_name column type")?;

// Get the distance scores
let distances = batch
.column_by_name("_distance")
Expand All @@ -292,12 +314,21 @@ impl ToolVectorDB {
for i in 0..batch.num_rows() {
let tool_name = tool_names.value(i).to_string();
let _distance = distances.value(i);
let ext_name = extension_names.value(i).to_string();

// Filter by extension name if provided
if let Some(filter_ext) = extension_name {
if ext_name != filter_ext {
continue;
}
}

tools.push(ToolRecord {
tool_name,
description: descriptions.value(i).to_string(),
schema: schemas.value(i).to_string(),
vector: vec![], // We don't need to return the vector
extension_name: ext_name,
});
}
}
Expand Down Expand Up @@ -356,13 +387,15 @@ mod tests {
schema: r#"{"type": "object", "properties": {"path": {"type": "string"}}}"#
.to_string(),
vector: vec![0.1; 1536], // Mock embedding vector
extension_name: "test_extension".to_string(),
},
ToolRecord {
tool_name: "test_tool_2".to_string(),
description: "A test tool for writing files".to_string(),
schema: r#"{"type": "object", "properties": {"path": {"type": "string"}}}"#
.to_string(),
vector: vec![0.2; 1536], // Different mock embedding vector
extension_name: "test_extension".to_string(),
},
];

Expand All @@ -371,7 +404,7 @@ mod tests {

// Search for tools using a query vector similar to test_tool_1
let query_vector = vec![0.1; 1536];
let results = db.search_tools(query_vector, 2).await?;
let results = db.search_tools(query_vector.clone(), 2, None).await?;

// Verify results
assert_eq!(results.len(), 2, "Should find both tools");
Expand All @@ -384,6 +417,25 @@ mod tests {
"Second result should be test_tool_2"
);

// Test filtering by extension name
let results = db
.search_tools(query_vector.clone(), 2, Some("test_extension"))
.await?;
assert_eq!(
results.len(),
2,
"Should find both tools with test_extension"
);

let results = db
.search_tools(query_vector.clone(), 2, Some("nonexistent_extension"))
.await?;
assert_eq!(
results.len(),
0,
"Should find no tools with nonexistent_extension"
);

Ok(())
}

Expand All @@ -397,7 +449,7 @@ mod tests {

// Search in empty database
let query_vector = vec![0.1; 1536];
let results = db.search_tools(query_vector, 2).await?;
let results = db.search_tools(query_vector, 2, None).await?;

// Verify no results returned
assert_eq!(results.len(), 0, "Empty database should return no results");
Expand All @@ -419,20 +471,21 @@ mod tests {
description: "A test tool that will be deleted".to_string(),
schema: r#"{"type": "object", "properties": {"path": {"type": "string"}}}"#.to_string(),
vector: vec![0.1; 1536],
extension_name: "test_extension".to_string(),
};

db.index_tools(vec![test_tool]).await?;

// Verify tool exists
let query_vector = vec![0.1; 1536];
let results = db.search_tools(query_vector.clone(), 1).await?;
let results = db.search_tools(query_vector.clone(), 1, None).await?;
assert_eq!(results.len(), 1, "Tool should exist before deletion");

// Delete the tool
db.remove_tool("test_tool_to_delete").await?;

// Verify tool is gone
let results = db.search_tools(query_vector.clone(), 1).await?;
let results = db.search_tools(query_vector.clone(), 1, None).await?;
assert_eq!(results.len(), 0, "Tool should be deleted");

Ok(())
Expand Down
2 changes: 1 addition & 1 deletion ui/desktop/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"license": {
"name": "Apache-2.0"
},
"version": "1.0.26"
"version": "1.0.27"
},
"paths": {
"/agent/tools": {
Expand Down
Loading