Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions crates/goose-server/src/routes/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ pub struct GetToolsQuery {
extension_name: Option<String>,
}

#[derive(Serialize)]
struct ErrorResponse {
error: String,
}

async fn get_versions() -> Json<VersionsResponse> {
let versions = ["goose".to_string()];
let default_version = "goose".to_string();
Expand Down Expand Up @@ -217,12 +222,56 @@ async fn update_agent_provider(
Ok(StatusCode::OK)
}

#[utoipa::path(
post,
path = "/agent/update_router_tool_selector",
responses(
(status = 200, description = "Tool selection strategy updated successfully", body = String),
(status = 500, description = "Internal server error")
)
)]
async fn update_router_tool_selector(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
) -> Result<Json<String>, Json<ErrorResponse>> {
verify_secret_key(&headers, &state).map_err(|_| {
Json(ErrorResponse {
error: "Unauthorized - Invalid or missing API key".to_string(),
})
})?;

let agent = state.get_agent().await.map_err(|e| {
tracing::error!("Failed to get agent: {}", e);
Json(ErrorResponse {
error: format!("Failed to get agent: {}", e),
})
})?;

agent
.update_router_tool_selector(None, Some(true))
.await
.map_err(|e| {
tracing::error!("Failed to update tool selection strategy: {}", e);
Json(ErrorResponse {
error: format!("Failed to update tool selection strategy: {}", e),
})
})?;

Ok(Json(
"Tool selection strategy updated successfully".to_string(),
))
}

pub fn routes(state: Arc<AppState>) -> Router {
Router::new()
.route("/agent/versions", get(get_versions))
.route("/agent/providers", get(list_providers))
.route("/agent/prompt", post(extend_prompt))
.route("/agent/tools", get(get_tools))
.route("/agent/update_provider", post(update_agent_provider))
.route(
"/agent/update_router_tool_selector",
post(update_router_tool_selector),
)
.with_state(state)
}
103 changes: 67 additions & 36 deletions crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,31 @@ impl Agent {
) -> (String, Result<Vec<Content>, ToolError>) {
let mut extension_manager = self.extension_manager.lock().await;

let selector = self.router_tool_selector.lock().await.clone();
if ToolRouterIndexManager::is_tool_router_enabled(&selector) {
if let Some(selector) = selector {
let selector_action = if action == "disable" { "remove" } else { "add" };
let extension_manager = self.extension_manager.lock().await;
let selector = Arc::new(selector);
if let Err(e) = ToolRouterIndexManager::update_extension_tools(
&selector,
&extension_manager,
&extension_name,
selector_action,
)
.await
{
return (
request_id,
Err(ToolError::ExecutionError(format!(
"Failed to update vector index: {}",
e
))),
);
}
}
}

if action == "disable" {
let result = extension_manager
.remove_extension(&extension_name)
Expand Down Expand Up @@ -351,34 +376,6 @@ impl Agent {
})
.map_err(|e| ToolError::ExecutionError(e.to_string()));

// Update vector index if operation was successful and vector routing is enabled
if result.is_ok() {
let selector = self.router_tool_selector.lock().await.clone();
if ToolRouterIndexManager::is_tool_router_enabled(&selector) {
if let Some(selector) = selector {
let vector_action = if action == "disable" { "remove" } else { "add" };
let extension_manager = self.extension_manager.lock().await;
let selector = Arc::new(selector);
if let Err(e) = ToolRouterIndexManager::update_extension_tools(
&selector,
&extension_manager,
&extension_name,
vector_action,
)
.await
{
return (
request_id,
Err(ToolError::ExecutionError(format!(
"Failed to update vector index: {}",
e
))),
);
}
}
}
}

(request_id, result)
}

Expand Down Expand Up @@ -503,9 +500,6 @@ impl Agent {
}

pub async fn remove_extension(&self, name: &str) -> Result<()> {
let mut extension_manager = self.extension_manager.lock().await;
extension_manager.remove_extension(name).await?;

// If vector tool selection is enabled, remove tools from the index
let selector = self.router_tool_selector.lock().await.clone();
if ToolRouterIndexManager::is_tool_router_enabled(&selector) {
Expand All @@ -521,6 +515,9 @@ impl Agent {
}
}

let mut extension_manager = self.extension_manager.lock().await;
extension_manager.remove_extension(name).await?;

Ok(())
}

Expand Down Expand Up @@ -809,12 +806,23 @@ impl Agent {
/// Update the provider used by this agent
pub async fn update_provider(&self, provider: Arc<dyn Provider>) -> Result<()> {
*self.provider.lock().await = Some(provider.clone());
self.update_router_tool_selector(provider).await?;
self.update_router_tool_selector(Some(provider), None)
.await?;
Ok(())
}

async fn update_router_tool_selector(&self, provider: Arc<dyn Provider>) -> Result<()> {
pub async fn update_router_tool_selector(
&self,
provider: Option<Arc<dyn Provider>>,
reindex_all: Option<bool>,
) -> Result<()> {
let config = Config::global();
let extension_manager = self.extension_manager.lock().await;
let provider = match provider {
Some(p) => p,
None => self.provider().await?,
};

let router_tool_selection_strategy = config
.get_param("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY")
.unwrap_or_else(|_| "default".to_string());
Expand All @@ -828,21 +836,44 @@ impl Agent {
let selector = match strategy {
Some(RouterToolSelectionStrategy::Vector) => {
let table_name = generate_table_id();
let selector = create_tool_selector(strategy, provider, Some(table_name))
let selector = create_tool_selector(strategy, provider.clone(), Some(table_name))
.await
.map_err(|e| anyhow!("Failed to create tool selector: {}", e))?;
Arc::new(selector)
}
Some(RouterToolSelectionStrategy::Llm) => {
let selector = create_tool_selector(strategy, provider, None)
let selector = create_tool_selector(strategy, provider.clone(), None)
.await
.map_err(|e| anyhow!("Failed to create tool selector: {}", e))?;
Arc::new(selector)
}
None => return Ok(()),
};
let extension_manager = self.extension_manager.lock().await;

// First index platform tools
ToolRouterIndexManager::index_platform_tools(&selector, &extension_manager).await?;

if reindex_all.unwrap_or(false) {
let enabled_extensions = extension_manager.list_extensions().await?;
for extension_name in enabled_extensions {
if let Err(e) = ToolRouterIndexManager::update_extension_tools(
&selector,
&extension_manager,
&extension_name,
"add",
)
.await
{
tracing::error!(
"Failed to index tools for extension {}: {}",
extension_name,
e
);
}
}
}

// Update the selector
*self.router_tool_selector.lock().await = Some(selector.clone());
Ok(())
}
Expand Down
44 changes: 35 additions & 9 deletions crates/goose/src/agents/router_tool_selector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,12 +166,36 @@ impl RouterToolSelector for VectorToolSelector {
})
.collect();

// Index all tools at once
// Get vector_db lock
let vector_db = self.vector_db.read().await;
vector_db
.index_tools(tool_records)
.await
.map_err(|e| ToolError::ExecutionError(format!("Failed to index tools: {}", e)))?;

// Filter out tools that already exist in the database
let mut new_tool_records = Vec::new();
for record in tool_records {
// Check if tool exists by searching for it
let existing_tools = vector_db
.search_tools(record.vector.clone(), 1, Some(&record.extension_name))
.await
.map_err(|e| {
ToolError::ExecutionError(format!("Failed to search for existing tools: {}", e))
})?;

// Only add if no exact match found
if !existing_tools
.iter()
.any(|t| t.tool_name == record.tool_name)
{
new_tool_records.push(record);
}
}

// Only index if there are new tools to add
if !new_tool_records.is_empty() {
vector_db
.index_tools(new_tool_records)
.await
.map_err(|e| ToolError::ExecutionError(format!("Failed to index tools: {}", e)))?;
}

Ok(())
}
Expand Down Expand Up @@ -282,7 +306,7 @@ impl RouterToolSelector for LLMToolSelector {
}
}

async fn index_tools(&self, tools: &[Tool], _extension_name: &str) -> Result<(), ToolError> {
async fn index_tools(&self, tools: &[Tool], extension_name: &str) -> Result<(), ToolError> {
let mut tool_strings = self.tool_strings.write().await;

for tool in tools {
Expand All @@ -294,8 +318,11 @@ impl RouterToolSelector for LLMToolSelector {
.unwrap_or_else(|_| "{}".to_string())
);

if let Some(extension_name) = tool.name.split("__").next() {
let entry = tool_strings.entry(extension_name.to_string()).or_default();
// Use the provided extension_name instead of parsing from tool name
let entry = tool_strings.entry(extension_name.to_string()).or_default();

// Check if this tool already exists in the entry
if !entry.contains(&format!("Tool: {}", tool.name)) {
if !entry.is_empty() {
entry.push_str("\n\n");
}
Expand All @@ -305,7 +332,6 @@ impl RouterToolSelector for LLMToolSelector {

Ok(())
}

async fn remove_tool(&self, tool_name: &str) -> Result<(), ToolError> {
let mut tool_strings = self.tool_strings.write().await;
if let Some(extension_name) = tool_name.split("__").next() {
Expand Down
17 changes: 15 additions & 2 deletions crates/goose/src/agents/router_tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,26 @@ pub fn vector_search_tool() -> Tool {
}

pub fn vector_search_tool_prompt() -> String {
r#"# Tool Selection Instructions
format!(
r#"# Tool Selection Instructions
Important: the user has opted to dynamically enable tools, so although an extension could be enabled, \
please invoke the vector search tool to actually retrieve the most relevant tools to use according to the user's messages.
For example, if the user has 3 extensions enabled, but they are asking for a tool to read a pdf file, \
you would invoke the vector_search tool to find the most relevant read pdf tool.
By dynamically enabling tools, you (Goose) as the agent save context window space and allow the user to dynamically retrieve the most relevant tools.
Be sure to format the query to search rather than pass in the user's messages directly."#.to_string()
Be sure to format the query to search rather than pass in the user's messages directly.
In addition to the extension names available to you, you also have platform extension tools available to you.
The platform extension contains the following tools:
- {}
- {}
- {}
- {}
"#,
PLATFORM_SEARCH_AVAILABLE_EXTENSIONS_TOOL_NAME,
PLATFORM_MANAGE_EXTENSIONS_TOOL_NAME,
PLATFORM_READ_RESOURCE_TOOL_NAME,
PLATFORM_LIST_RESOURCES_TOOL_NAME
)
}

pub fn llm_search_tool() -> Tool {
Expand Down
Loading
Loading