Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ jobs:
source ../bin/activate-hermit && cargo test
working-directory: crates

# Add disk space cleanup before linting
- name: Check disk space before cleanup
run: df -h

Expand Down
50 changes: 32 additions & 18 deletions crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use crate::agents::prompt_manager::PromptManager;
use crate::agents::router_tool_selector::{
create_tool_selector, RouterToolSelectionStrategy, RouterToolSelector,
};
use crate::agents::router_tools::ROUTER_VECTOR_SEARCH_TOOL_NAME;
use crate::agents::router_tools::{ROUTER_LLM_SEARCH_TOOL_NAME, ROUTER_VECTOR_SEARCH_TOOL_NAME};
use crate::agents::tool_router_index_manager::ToolRouterIndexManager;
use crate::agents::tool_vectordb::generate_table_id;
use crate::agents::types::SessionConfig;
Expand Down Expand Up @@ -244,7 +244,9 @@ impl Agent {
ToolCallResult::from(Err(ToolError::ExecutionError(
"Frontend tool execution required".to_string(),
)))
} else if tool_call.name == ROUTER_VECTOR_SEARCH_TOOL_NAME {
} else if tool_call.name == ROUTER_VECTOR_SEARCH_TOOL_NAME
|| tool_call.name == ROUTER_LLM_SEARCH_TOOL_NAME
{
let selector = self.router_tool_selector.lock().await.clone();
ToolCallResult::from(if let Some(selector) = selector {
selector.select_tools(tool_call.arguments.clone()).await
Expand Down Expand Up @@ -335,10 +337,11 @@ impl Agent {
// Update vector index if operation was successful and vector routing is enabled
if result.is_ok() {
let selector = self.router_tool_selector.lock().await.clone();
if ToolRouterIndexManager::vector_tool_router_enabled(&selector) {
if ToolRouterIndexManager::is_tool_router_enabled(&selector) {
if let Some(selector) = selector {
let vector_action = if action == "disable" { "remove" } else { "add" };
let extension_manager = self.extension_manager.lock().await;
let selector = Arc::new(selector);
if let Err(e) = ToolRouterIndexManager::update_extension_tools(
&selector,
&extension_manager,
Expand Down Expand Up @@ -398,9 +401,10 @@ impl Agent {

// If vector tool selection is enabled, index the tools
let selector = self.router_tool_selector.lock().await.clone();
if ToolRouterIndexManager::vector_tool_router_enabled(&selector) {
if ToolRouterIndexManager::is_tool_router_enabled(&selector) {
if let Some(selector) = selector {
let extension_manager = self.extension_manager.lock().await;
let selector = Arc::new(selector);
if let Err(e) = ToolRouterIndexManager::update_extension_tools(
&selector,
&extension_manager,
Expand Down Expand Up @@ -452,6 +456,9 @@ impl Agent {
Some(RouterToolSelectionStrategy::Vector) => {
prefixed_tools.push(router_tools::vector_search_tool());
}
Some(RouterToolSelectionStrategy::Llm) => {
prefixed_tools.push(router_tools::llm_search_tool());
}
None => {}
}

Expand Down Expand Up @@ -484,7 +491,7 @@ impl Agent {

// If vector tool selection is enabled, remove tools from the index
let selector = self.router_tool_selector.lock().await.clone();
if ToolRouterIndexManager::vector_tool_router_enabled(&selector) {
if ToolRouterIndexManager::is_tool_router_enabled(&selector) {
if let Some(selector) = selector {
let extension_manager = self.extension_manager.lock().await;
ToolRouterIndexManager::update_extension_tools(
Expand Down Expand Up @@ -777,22 +784,29 @@ impl Agent {

let strategy = match router_tool_selection_strategy.to_lowercase().as_str() {
"vector" => Some(RouterToolSelectionStrategy::Vector),
"llm" => Some(RouterToolSelectionStrategy::Llm),
_ => None,
};

if let Some(strategy) = strategy {
let table_name = generate_table_id();
let selector = create_tool_selector(Some(strategy), provider, table_name)
.await
.map_err(|e| anyhow!("Failed to create tool selector: {}", e))?;

let selector = Arc::new(selector);
*self.router_tool_selector.lock().await = Some(selector.clone());

let extension_manager = self.extension_manager.lock().await;
ToolRouterIndexManager::index_platform_tools(&selector, &extension_manager).await?;
}

let selector = match strategy {
Some(RouterToolSelectionStrategy::Vector) => {
let table_name = generate_table_id();
let selector = create_tool_selector(strategy, provider, Some(table_name))
.await
.map_err(|e| anyhow!("Failed to create tool selector: {}", e))?;
Arc::new(selector)
}
Some(RouterToolSelectionStrategy::Llm) => {
let selector = create_tool_selector(strategy, provider, None)
.await
.map_err(|e| anyhow!("Failed to create tool selector: {}", e))?;
Arc::new(selector)
}
None => return Ok(()),
};
let extension_manager = self.extension_manager.lock().await;
ToolRouterIndexManager::index_platform_tools(&selector, &extension_manager).await?;
*self.router_tool_selector.lock().await = Some(selector.clone());
Ok(())
}

Expand Down
8 changes: 7 additions & 1 deletion crates/goose/src/agents/prompt_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::collections::HashMap;

use crate::agents::extension::ExtensionInfo;
use crate::agents::router_tool_selector::RouterToolSelectionStrategy;
use crate::agents::router_tools::vector_search_tool_prompt;
use crate::agents::router_tools::{llm_search_tool_prompt, vector_search_tool_prompt};
use crate::providers::base::get_current_model;
use crate::{config::Config, prompt_template};

Expand Down Expand Up @@ -92,6 +92,12 @@ impl PromptManager {
Value::String(vector_search_tool_prompt()),
);
}
Some(RouterToolSelectionStrategy::Llm) => {
context.insert(
"tool_selection_strategy",
Value::String(llm_search_tool_prompt()),
);
}
None => {}
}

Expand Down
5 changes: 5 additions & 0 deletions crates/goose/src/agents/reply_parts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ impl Agent {

let tool_selection_strategy = match router_tool_selection_strategy.to_lowercase().as_str() {
"vector" => Some(RouterToolSelectionStrategy::Vector),
"llm" => Some(RouterToolSelectionStrategy::Llm),
_ => None,
};

Expand All @@ -38,6 +39,10 @@ impl Agent {
self.list_tools_for_router(Some(RouterToolSelectionStrategy::Vector))
.await
}
Some(RouterToolSelectionStrategy::Llm) => {
self.list_tools_for_router(Some(RouterToolSelectionStrategy::Llm))
.await
}
_ => self.list_tools(None).await,
};
// Add frontend tools
Expand Down
143 changes: 140 additions & 3 deletions crates/goose/src/agents/router_tool_selector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,21 @@ use mcp_core::{Content, ToolError};
use anyhow::{Context, Result};
use async_trait::async_trait;
use serde_json::Value;
use std::collections::HashMap;
use std::collections::VecDeque;
use std::env;
use std::sync::Arc;
use tokio::sync::RwLock;

use crate::agents::tool_vectordb::ToolVectorDB;
use crate::message::Message;
use crate::model::ModelConfig;
use crate::providers::{self, base::Provider};

#[derive(Debug, Clone, PartialEq)]
pub enum RouterToolSelectionStrategy {
Vector,
Llm,
}

#[async_trait]
Expand Down Expand Up @@ -196,19 +199,153 @@ impl RouterToolSelector for VectorToolSelector {
}
}

pub struct LLMToolSelector {
llm_provider: Arc<dyn Provider>,
tool_strings: Arc<RwLock<HashMap<String, String>>>, // extension_name -> tool_string
recent_tool_calls: Arc<RwLock<VecDeque<String>>>,
}

impl LLMToolSelector {
pub async fn new(provider: Arc<dyn Provider>) -> Result<Self> {
Ok(Self {
llm_provider: provider.clone(),
tool_strings: Arc::new(RwLock::new(HashMap::new())),
recent_tool_calls: Arc::new(RwLock::new(VecDeque::with_capacity(100))),
})
}
}

#[async_trait]
impl RouterToolSelector for LLMToolSelector {
async fn select_tools(&self, params: Value) -> Result<Vec<Content>, ToolError> {
let query = params
.get("query")
.and_then(|v| v.as_str())
.ok_or_else(|| ToolError::InvalidParameters("Missing 'query' parameter".to_string()))?;

let extension_name = params
.get("extension_name")
.and_then(|v| v.as_str())
.map(|s| s.to_string());

// Get relevant tool strings based on extension_name
let tool_strings = self.tool_strings.read().await;
let relevant_tools = if let Some(ext) = &extension_name {
tool_strings.get(ext).cloned()
} else {
// If no extension specified, use all tools
Some(
tool_strings
.values()
.cloned()
.collect::<Vec<String>>()
.join("\n"),
)
};

if let Some(tools) = relevant_tools {
// Use LLM to search through tools
let prompt = format!(
"Given the following tools:\n{}\n\nFind the most relevant tools for the query: {}\n\nReturn the tools in this exact format for each tool:\nTool: <tool_name>\nDescription: <tool_description>\nSchema: <tool_schema>",
tools, query
);
let system_message = Message::user().with_text("You are a tool selection assistant. Your task is to find the most relevant tools based on the user's query.");
let response = self
.llm_provider
.complete(&prompt, &[system_message], &[])
.await
.map_err(|e| ToolError::ExecutionError(format!("Failed to search tools: {}", e)))?;

// Extract just the message content from the response
let (message, _usage) = response;
let text = message.content[0].as_text().unwrap_or_default();

// Split the response into individual tool entries
let tool_entries: Vec<Content> = text
.split("\n\n")
.filter(|entry| entry.trim().starts_with("Tool:"))
.map(|entry| {
Content::Text(TextContent {
text: entry.trim().to_string(),
annotations: None,
})
})
.collect();

Ok(tool_entries)
} else {
Ok(vec![])
}
}

async fn index_tools(&self, tools: &[Tool]) -> Result<(), ToolError> {
let mut tool_strings = self.tool_strings.write().await;

for tool in tools {
let tool_string = format!(
"Tool: {}\nDescription: {}\nSchema: {}",
tool.name,
tool.description,
serde_json::to_string_pretty(&tool.input_schema)
.unwrap_or_else(|_| "{}".to_string())
);

if let Some(extension_name) = tool.name.split("__").next() {
let entry = tool_strings.entry(extension_name.to_string()).or_default();
if !entry.is_empty() {
entry.push_str("\n\n");
}
entry.push_str(&tool_string);
}
}

Ok(())
}

async fn remove_tool(&self, tool_name: &str) -> Result<(), ToolError> {
let mut tool_strings = self.tool_strings.write().await;
if let Some(extension_name) = tool_name.split("__").next() {
tool_strings.remove(extension_name);
}
Ok(())
}

async fn record_tool_call(&self, tool_name: &str) -> Result<(), ToolError> {
let mut recent_calls = self.recent_tool_calls.write().await;
if recent_calls.len() >= 100 {
recent_calls.pop_front();
}
recent_calls.push_back(tool_name.to_string());
Ok(())
}

async fn get_recent_tool_calls(&self, limit: usize) -> Result<Vec<String>, ToolError> {
let recent_calls = self.recent_tool_calls.read().await;
Ok(recent_calls.iter().rev().take(limit).cloned().collect())
}

fn selector_type(&self) -> RouterToolSelectionStrategy {
RouterToolSelectionStrategy::Llm
}
}

// Helper function to create a boxed tool selector
pub async fn create_tool_selector(
strategy: Option<RouterToolSelectionStrategy>,
provider: Arc<dyn Provider>,
table_name: String,
table_name: Option<String>,
) -> Result<Box<dyn RouterToolSelector>> {
match strategy {
Some(RouterToolSelectionStrategy::Vector) => {
let selector = VectorToolSelector::new(provider, table_name).await?;
let selector = VectorToolSelector::new(provider, table_name.unwrap()).await?;
Ok(Box::new(selector))
}
Some(RouterToolSelectionStrategy::Llm) => {
let selector = LLMToolSelector::new(provider).await?;
Ok(Box::new(selector))
}
None => {
let selector = VectorToolSelector::new(provider, table_name).await?;
let selector = LLMToolSelector::new(provider).await?;
Ok(Box::new(selector))
}
}
Expand Down
Loading
Loading