Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 15 additions & 23 deletions crates/goose-cli/src/commands/configure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1170,40 +1170,32 @@ pub fn configure_goose_mode_dialog() -> Result<(), Box<dyn Error>> {
pub fn configure_goose_router_strategy_dialog() -> Result<(), Box<dyn Error>> {
let config = Config::global();

// Check if GOOSE_ROUTER_STRATEGY is set as an environment variable
if std::env::var("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY").is_ok() {
let _ = cliclack::log::info("Notice: GOOSE_ROUTER_TOOL_SELECTION_STRATEGY environment variable is set. Configuration will override this.");
// Check if GOOSE_ENABLE_ROUTER is set as an environment variable
if std::env::var("GOOSE_ENABLE_ROUTER").is_ok() {
let _ = cliclack::log::info("Notice: GOOSE_ENABLE_ROUTER environment variable is set. Configuration will override this.");
}

let strategy = cliclack::select("Which router strategy would you like to use?")
let enable_router = cliclack::select("Would you like to enable LLM-based tool routing?")
.item(
"llm",
"LLM Strategy",
"true",
"Enable Router",
"Use LLM-based intelligence to select tools",
)
.item(
"default",
"Default Strategy",
"false",
"Disable Router",
"Use the default tool selection strategy",
)
.interact()?;

match strategy {
"llm" => {
config.set_param(
"GOOSE_ROUTER_TOOL_SELECTION_STRATEGY",
Value::String("llm".to_string()),
)?;
cliclack::outro(
"Set to LLM Strategy - using LLM-based intelligence for tool selection",
)?;
match enable_router {
"true" => {
config.set_param("GOOSE_ENABLE_ROUTER", Value::String("true".to_string()))?;
cliclack::outro("Router enabled - using LLM-based intelligence for tool selection")?;
}
"default" => {
config.set_param(
"GOOSE_ROUTER_TOOL_SELECTION_STRATEGY",
Value::String("default".to_string()),
)?;
cliclack::outro("Set to Default Strategy - using default tool selection")?;
"false" => {
config.set_param("GOOSE_ENABLE_ROUTER", Value::String("false".to_string()))?;
cliclack::outro("Router disabled - using default tool selection")?;
}
_ => unreachable!(),
};
Expand Down
74 changes: 34 additions & 40 deletions crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ use crate::agents::recipe_tools::dynamic_task_tools::{
create_dynamic_task, create_dynamic_task_tool, DYNAMIC_TASK_TOOL_NAME_PREFIX,
};
use crate::agents::retry::{RetryManager, RetryResult};
use crate::agents::router_tool_selector::{RouterToolSelectionStrategy, RouterToolSelector};
use crate::agents::router_tools::ROUTER_LLM_SEARCH_TOOL_NAME;
use crate::agents::sub_recipe_manager::SubRecipeManager;
use crate::agents::subagent_execution_tool::subagent_execute_task_tool::{
Expand Down Expand Up @@ -505,8 +504,8 @@ impl Agent {
extension_name: String,
request_id: String,
) -> (String, Result<Vec<Content>, ToolError>) {
let selector = self.tool_route_manager.get_router_tool_selector().await;
if ToolRouterIndexManager::is_tool_router_enabled(&selector) {
if self.tool_route_manager.is_router_functional().await {
let selector = self.tool_route_manager.get_router_tool_selector().await;
if let Some(selector) = selector {
let selector_action = if action == "disable" { "remove" } else { "add" };
let extension_manager = self.extension_manager.read().await;
Expand Down Expand Up @@ -577,30 +576,28 @@ impl Agent {
.map_err(|e| ToolError::ExecutionError(e.to_string()));

drop(extension_manager);
// Update LLM index if operation was successful and LLM routing is enabled
if result.is_ok() {
// Update LLM index if operation was successful and LLM routing is functional
if result.is_ok() && self.tool_route_manager.is_router_functional().await {
let selector = self.tool_route_manager.get_router_tool_selector().await;
if ToolRouterIndexManager::is_tool_router_enabled(&selector) {
if let Some(selector) = selector {
let llm_action = if action == "disable" { "remove" } else { "add" };
let extension_manager = self.extension_manager.read().await;
let selector = Arc::new(selector);
if let Err(e) = ToolRouterIndexManager::update_extension_tools(
&selector,
&extension_manager,
&extension_name,
llm_action,
)
.await
{
return (
request_id,
Err(ToolError::ExecutionError(format!(
"Failed to update LLM index: {}",
e
))),
);
}
if let Some(selector) = selector {
let llm_action = if action == "disable" { "remove" } else { "add" };
let extension_manager = self.extension_manager.read().await;
let selector = Arc::new(selector);
if let Err(e) = ToolRouterIndexManager::update_extension_tools(
&selector,
&extension_manager,
&extension_name,
llm_action,
)
.await
{
return (
request_id,
Err(ToolError::ExecutionError(format!(
"Failed to update LLM index: {}",
e
))),
);
}
}
}
Expand Down Expand Up @@ -641,10 +638,9 @@ impl Agent {
}
}

// If LLM tool selection is enabled, index the tools
let selector: Option<Arc<Box<dyn RouterToolSelector>>> =
self.tool_route_manager.get_router_tool_selector().await;
if ToolRouterIndexManager::is_tool_router_enabled(&selector) {
// If LLM tool selection is functional, index the tools
if self.tool_route_manager.is_router_functional().await {
let selector = self.tool_route_manager.get_router_tool_selector().await;
if let Some(selector) = selector {
let extension_manager = self.extension_manager.read().await;
let selector = Arc::new(selector);
Expand Down Expand Up @@ -708,12 +704,9 @@ impl Agent {
prefixed_tools
}

pub async fn list_tools_for_router(
&self,
strategy: Option<RouterToolSelectionStrategy>,
) -> Vec<Tool> {
pub async fn list_tools_for_router(&self) -> Vec<Tool> {
self.tool_route_manager
.list_tools_for_router(strategy, &self.extension_manager)
.list_tools_for_router(&self.extension_manager)
.await
}

Expand All @@ -722,9 +715,9 @@ impl Agent {
extension_manager.remove_extension(name).await?;
drop(extension_manager);

// If LLM tool selection is enabled, remove tools from the index
let selector = self.tool_route_manager.get_router_tool_selector().await;
if ToolRouterIndexManager::is_tool_router_enabled(&selector) {
// If LLM tool selection is functional, remove tools from the index
if self.tool_route_manager.is_router_functional().await {
let selector = self.tool_route_manager.get_router_tool_selector().await;
if let Some(selector) = selector {
let extension_manager = self.extension_manager.read().await;
ToolRouterIndexManager::update_extension_tools(
Expand Down Expand Up @@ -1257,13 +1250,14 @@ impl Agent {
let model_config = provider.get_model_config();
let model_name = &model_config.model_name;

let router_enabled = self.tool_route_manager.is_router_enabled().await;
let prompt_manager = self.prompt_manager.lock().await;
let system_prompt = prompt_manager.build_system_prompt(
extensions_info,
self.frontend_instructions.lock().await.clone(),
extension_manager.suggest_disable_extensions_prompt().await,
Some(model_name),
None,
router_enabled,
);

let recipe_prompt = prompt_manager.get_recipe_prompt().await;
Expand Down Expand Up @@ -1422,7 +1416,7 @@ mod tests {

let prompt_manager = agent.prompt_manager.lock().await;
let system_prompt =
prompt_manager.build_system_prompt(vec![], None, Value::Null, None, None);
prompt_manager.build_system_prompt(vec![], None, Value::Null, None, false);

let final_output_tool_ref = agent.final_output_tool.lock().await;
let final_output_tool_system_prompt =
Expand Down
16 changes: 6 additions & 10 deletions crates/goose/src/agents/prompt_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use serde_json::Value;
use std::collections::HashMap;

use crate::agents::extension::ExtensionInfo;
use crate::agents::router_tool_selector::RouterToolSelectionStrategy;
use crate::agents::router_tools::llm_search_tool_prompt;
use crate::providers::base::get_current_model;
use crate::{config::Config, prompt_template};
Expand Down Expand Up @@ -69,7 +68,7 @@ impl PromptManager {
frontend_instructions: Option<String>,
suggest_disable_extensions_prompt: Value,
model_name: Option<&str>,
tool_selection_strategy: Option<RouterToolSelectionStrategy>,
router_enabled: bool,
) -> String {
let mut context: HashMap<&str, Value> = HashMap::new();
let mut extensions_info = extensions_info.clone();
Expand All @@ -85,14 +84,11 @@ impl PromptManager {

context.insert("extensions", serde_json::to_value(extensions_info).unwrap());

match tool_selection_strategy {
Some(RouterToolSelectionStrategy::Llm) => {
context.insert(
"tool_selection_strategy",
Value::String(llm_search_tool_prompt()),
);
}
None => {}
if router_enabled {
context.insert(
"tool_selection_strategy",
Value::String(llm_search_tool_prompt()),
);
}

context.insert(
Expand Down
24 changes: 10 additions & 14 deletions crates/goose/src/agents/reply_parts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use async_stream::try_stream;
use futures::stream::StreamExt;

use super::super::agents::Agent;
use crate::agents::router_tool_selector::RouterToolSelectionStrategy;
use crate::conversation::message::{Message, MessageContent, ToolRequest};
use crate::conversation::Conversation;
use crate::providers::base::{stream_from_single_message, MessageStream, Provider, ProviderUsage};
Expand Down Expand Up @@ -34,20 +33,17 @@ async fn toolshim_postprocess(
impl Agent {
/// Prepares tools and system prompt for a provider request
pub async fn prepare_tools_and_prompt(&self) -> anyhow::Result<(Vec<Tool>, Vec<Tool>, String)> {
// Get tool selection strategy from config
let tool_selection_strategy = self
.tool_route_manager
.get_router_tool_selection_strategy()
.await;
// Get router enabled status
let router_enabled = self.tool_route_manager.is_router_enabled().await;

// Get tools from extension manager
let mut tools = match tool_selection_strategy {
Some(RouterToolSelectionStrategy::Llm) => {
self.list_tools_for_router(Some(RouterToolSelectionStrategy::Llm))
.await
}
_ => self.list_tools(None).await,
};
let mut tools = self.list_tools_for_router().await;

// If router is disabled and no tools were returned, fall back to regular tools
if !router_enabled && tools.is_empty() {
tools = self.list_tools(None).await;
}

// Add frontend tools
let frontend_tools = self.frontend_tools.lock().await;
for frontend_tool in frontend_tools.values() {
Expand All @@ -69,7 +65,7 @@ impl Agent {
self.frontend_instructions.lock().await.clone(),
extension_manager.suggest_disable_extensions_prompt().await,
Some(model_name),
tool_selection_strategy,
router_enabled,
);

// Handle toolshim if enabled
Expand Down
23 changes: 2 additions & 21 deletions crates/goose/src/agents/router_tool_selector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,13 @@ struct ToolSelectorContext {
query: String,
}

#[derive(Debug, Clone, PartialEq)]
pub enum RouterToolSelectionStrategy {
Llm,
}

#[async_trait]
pub trait RouterToolSelector: Send + Sync {
async fn select_tools(&self, params: Value) -> Result<Vec<Content>, ToolError>;
async fn index_tools(&self, tools: &[Tool], extension_name: &str) -> Result<(), ToolError>;
async fn remove_tool(&self, tool_name: &str) -> Result<(), ToolError>;
async fn record_tool_call(&self, tool_name: &str) -> Result<(), ToolError>;
async fn get_recent_tool_calls(&self, limit: usize) -> Result<Vec<String>, ToolError>;
fn selector_type(&self) -> RouterToolSelectionStrategy;
}

pub struct LLMToolSelector {
Expand Down Expand Up @@ -166,25 +160,12 @@ impl RouterToolSelector for LLMToolSelector {
let recent_calls = self.recent_tool_calls.read().await;
Ok(recent_calls.iter().rev().take(limit).cloned().collect())
}

fn selector_type(&self) -> RouterToolSelectionStrategy {
RouterToolSelectionStrategy::Llm
}
}

// Helper function to create a boxed tool selector
pub async fn create_tool_selector(
strategy: Option<RouterToolSelectionStrategy>,
provider: Arc<dyn Provider>,
) -> Result<Box<dyn RouterToolSelector>> {
match strategy {
Some(RouterToolSelectionStrategy::Llm) => {
let selector = LLMToolSelector::new(provider).await?;
Ok(Box::new(selector))
}
None => {
let selector = LLMToolSelector::new(provider).await?;
Ok(Box::new(selector))
}
}
let selector = LLMToolSelector::new(provider).await?;
Ok(Box::new(selector))
}
Loading