Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 0 additions & 6 deletions crates/goose-cli/src/session/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,6 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session {
let agent: Agent = Agent::new();
let new_provider = create(&provider_name, model_config).unwrap();
let _ = agent.update_provider(new_provider).await;

// Initialize router tool selector if vector strategy is enabled
if let Err(e) = agent.initialize_router_tool_selector().await {
output::render_error(&format!("Failed to initialize router tool selector: {}", e));
process::exit(1);
}

// Configure tool monitoring if max_tool_repetitions is set
if let Some(max_repetitions) = session_config.max_tool_repetitions {
Expand Down
1 change: 1 addition & 0 deletions crates/goose/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ jsonwebtoken = "9.3.1"
blake3 = "1.5"
fs2 = "0.4.3"
futures-util = "0.3.31"
downcast-rs = "1.2"

# Vector database for tool selection
lancedb = "0.13"
Expand Down
225 changes: 160 additions & 65 deletions crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,27 +74,6 @@ impl Agent {
router_tool_selector: Mutex::new(None),
}
}

pub async fn initialize_router_tool_selector(&self) -> Result<()> {
let router_tool_selection_strategy = std::env::var("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY")
.ok()
.and_then(|s| {
if s.eq_ignore_ascii_case("vector") {
Some(RouterToolSelectionStrategy::Vector)
} else {
None
}
});

if router_tool_selection_strategy.is_some() {
let selector = create_tool_selector(router_tool_selection_strategy)
.await
.map_err(|e| anyhow!("Failed to create tool selector: {}", e))?;
*self.router_tool_selector.lock().await = Some(selector);
}

Ok(())
}

pub async fn configure_tool_monitor(&self, max_repetitions: Option<u32>) {
let mut tool_monitor = self.tool_monitor.lock().await;
Expand Down Expand Up @@ -292,11 +271,27 @@ impl Agent {
))]
})
.map_err(|e| ToolError::ExecutionError(e.to_string()));

// If vector tool selection is enabled, index the tools
if result.is_ok() {
if let Err(e) = self.index_tools_if_vector_enabled().await {
tracing::error!("Failed to index tools after adding extension: {}", e);
if action == "disable" {
if let Err(e) = self
.index_tools_if_vector_enabled(
Some(extension_name.clone()),
Some("remove"),
false,
)
.await
{
tracing::error!("Failed to remove tools from vector index: {}", e);
}
} else {
if let Err(e) = self
.index_tools_if_vector_enabled(Some(extension_name.clone()), Some("add"), false)
.await
{
tracing::error!("Failed to index tools: {}", e);
}
}
}

Expand Down Expand Up @@ -336,12 +331,17 @@ impl Agent {
extension_manager.add_extension(extension.clone()).await?;
}
};

// If vector tool selection is enabled, index the tools
if let Err(e) = self.index_tools_if_vector_enabled().await {
return Err(ExtensionError::SetupError(
format!("Failed to index tools for extension {}: {}", extension.name(), e),
));
if let Err(e) = self
.index_tools_if_vector_enabled(Some(extension.name()), Some("add"), false)
.await
{
return Err(ExtensionError::SetupError(format!(
"Failed to index tools for extension {}: {}",
extension.name(),
e
)));
}

Ok(())
Expand Down Expand Up @@ -398,6 +398,14 @@ impl Agent {
.remove_extension(name)
.await
.expect("Failed to remove extension");

// If vector tool selection is enabled, remove tools from the index
if let Err(e) = self
.index_tools_if_vector_enabled(Some(name.to_string()), Some("remove"), false)
.await
{
tracing::error!("Failed to remove tools from vector index: {}", e);
}
}

pub async fn list_extensions(&self) -> Vec<String> {
Expand Down Expand Up @@ -621,7 +629,29 @@ impl Agent {

/// Update the provider used by this agent
pub async fn update_provider(&self, provider: Arc<dyn Provider>) -> Result<()> {
*self.provider.lock().await = Some(provider);
*self.provider.lock().await = Some(provider.clone());
self.update_router_tool_selector(provider).await?;
Ok(())
}

async fn update_router_tool_selector(&self, provider: Arc<dyn Provider>) -> Result<()> {
let router_tool_selection_strategy = std::env::var("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY")
.ok()
.and_then(|s| {
if s.eq_ignore_ascii_case("vector") {
Some(RouterToolSelectionStrategy::Vector)
} else {
None
}
});

if router_tool_selection_strategy.is_some() {
let selector = create_tool_selector(router_tool_selection_strategy, provider)
.await
.map_err(|e| anyhow!("Failed to create tool selector: {}", e))?;
*self.router_tool_selector.lock().await = Some(selector);
}

Ok(())
}

Expand Down Expand Up @@ -688,47 +718,112 @@ impl Agent {
}
}

async fn index_tools_if_vector_enabled(&self) -> Result<()> {
async fn index_tools_if_vector_enabled(
&self,
extension_name: Option<String>,
action: Option<&str>,
reindex_all: bool,
) -> Result<()> {
// Only proceed if vector strategy is enabled
let is_vector_enabled = std::env::var("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY")
.map(|s| s.eq_ignore_ascii_case("vector"))
.unwrap_or(false);

if !is_vector_enabled {
return Ok(());
}

let router_tool_selector = self.router_tool_selector.lock().await;
if let Some(selector) = router_tool_selector.as_ref() {
// Get all tools from extension manager
let extension_manager = self.extension_manager.lock().await;
let tools = extension_manager.get_prefixed_tools(None).await?;

// Clear existing tools and re-index all
selector.clear_tools().await
.map_err(|e| anyhow!("Failed to clear tools: {}", e))?;

// Index each tool
for tool in &tools {
let schema_str = serde_json::to_string_pretty(&tool.input_schema)
.unwrap_or_else(|_| "{}".to_string());

selector.index_tool(
tool.name.clone(),
tool.description.clone(),
schema_str,
).await
.map_err(|e| anyhow!("Failed to index tool {}: {}", tool.name, e))?;

if reindex_all {
// Clear and reindex everything
selector
.clear_tools()
.await
.map_err(|e| anyhow!("Failed to clear tools: {}", e))?;

// Index all extension tools
let all_tools = extension_manager.get_prefixed_tools(None).await?;
for tool in &all_tools {
let schema_str = serde_json::to_string_pretty(&tool.input_schema)
.unwrap_or_else(|_| "{}".to_string());

selector
.index_tool(tool.name.clone(), tool.description.clone(), schema_str)
.await
.map_err(|e| anyhow!("Failed to index tool {}: {}", tool.name, e))?;
}

// Index all frontend tools
let frontend_tools = self.frontend_tools.lock().await;
for frontend_tool in frontend_tools.values() {
let schema_str = serde_json::to_string_pretty(&frontend_tool.tool.input_schema)
.unwrap_or_else(|_| "{}".to_string());

selector
.index_tool(
frontend_tool.tool.name.clone(),
frontend_tool.tool.description.clone(),
schema_str,
)
.await
.map_err(|e| {
anyhow!(
"Failed to index frontend tool {}: {}",
frontend_tool.tool.name,
e
)
})?;
}

tracing::info!("Reindexed all tools for vector search");
return Ok(());
}

// Also index frontend tools
let frontend_tools = self.frontend_tools.lock().await;
for frontend_tool in frontend_tools.values() {
let schema_str = serde_json::to_string_pretty(&frontend_tool.tool.input_schema)
.unwrap_or_else(|_| "{}".to_string());

selector.index_tool(
frontend_tool.tool.name.clone(),
frontend_tool.tool.description.clone(),
schema_str,
).await
.map_err(|e| anyhow!("Failed to index frontend tool {}: {}", frontend_tool.tool.name, e))?;

// Handle specific extension operations
if let (Some(ext_name), Some(act)) = (extension_name, action) {
match act {
"add" => {
// Get tools for specific extension
let tools = extension_manager
.get_prefixed_tools(Some(ext_name.clone()))
.await?;
for tool in &tools {
let schema_str = serde_json::to_string_pretty(&tool.input_schema)
.unwrap_or_else(|_| "{}".to_string());

selector
.index_tool(tool.name.clone(), tool.description.clone(), schema_str)
.await
.map_err(|e| {
anyhow!("Failed to index tool {}: {}", tool.name, e)
})?;
}
tracing::info!("Indexed {} tools for extension {}", tools.len(), ext_name);
}
"remove" => {
// Get tool names for the extension to remove them
let tools = extension_manager
.get_prefixed_tools(Some(ext_name.clone()))
.await?;
for tool in &tools {
selector.remove_tool(&tool.name).await.map_err(|e| {
anyhow!("Failed to remove tool {}: {}", tool.name, e)
})?;
}
tracing::info!("Removed {} tools for extension {}", tools.len(), ext_name);
}
_ => {
anyhow::bail!("Invalid action '{}' for tool indexing", act);
}
}
} else {
anyhow::bail!("Extension name and action required for tool indexing");
}

tracing::info!("Indexed {} tools for vector search", tools.len() + frontend_tools.len());
}

Ok(())
}

Expand Down
Loading