diff --git a/crates/goose-cli/src/commands/configure.rs b/crates/goose-cli/src/commands/configure.rs index bd3ca18117ef..7e49c50c70f4 100644 --- a/crates/goose-cli/src/commands/configure.rs +++ b/crates/goose-cli/src/commands/configure.rs @@ -824,6 +824,11 @@ pub fn remove_extension_dialog() -> Result<(), Box> { pub async fn configure_settings_dialog() -> Result<(), Box> { let setting_type = cliclack::select("What setting would you like to configure?") .item("goose_mode", "Goose Mode", "Configure Goose mode") + .item( + "goose_router_strategy", + "Router Tool Selection Strategy", + "Configure the strategy for selecting tools to use", + ) .item( "tool_permission", "Tool Permission", @@ -850,6 +855,9 @@ pub async fn configure_settings_dialog() -> Result<(), Box> { "goose_mode" => { configure_goose_mode_dialog()?; } + "goose_router_strategy" => { + configure_goose_router_strategy_dialog()?; + } "tool_permission" => { configure_tool_permissions_dialog().await.and(Ok(()))?; } @@ -921,6 +929,49 @@ pub fn configure_goose_mode_dialog() -> Result<(), Box> { Ok(()) } +pub fn configure_goose_router_strategy_dialog() -> Result<(), Box> { + let config = Config::global(); + + // Check if GOOSE_ROUTER_STRATEGY is set as an environment variable + if std::env::var("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY").is_ok() { + let _ = cliclack::log::info("Notice: GOOSE_ROUTER_TOOL_SELECTION_STRATEGY environment variable is set. Configuration will override this."); + } + + let strategy = cliclack::select("Which router strategy would you like to use?") + .item( + "vector", + "Vector Strategy", + "Use vector-based similarity to select tools", + ) + .item( + "default", + "Default Strategy", + "Use the default tool selection strategy", + ) + .interact()?; + + match strategy { + "vector" => { + config.set_param( + "GOOSE_ROUTER_TOOL_SELECTION_STRATEGY", + Value::String("vector".to_string()), + )?; + cliclack::outro( + "Set to Vector Strategy - using vector-based similarity for tool selection", + )?; + } + "default" => { + config.set_param( + "GOOSE_ROUTER_TOOL_SELECTION_STRATEGY", + Value::String("default".to_string()), + )?; + cliclack::outro("Set to Default Strategy - using default tool selection")?; + } + _ => unreachable!(), + }; + Ok(()) +} + pub fn configure_tool_output_dialog() -> Result<(), Box> { let config = Config::global(); // Check if GOOSE_CLI_MIN_PRIORITY is set as an environment variable diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index a97d5eb0769b..537340437ab2 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -665,17 +665,17 @@ impl Agent { } async fn update_router_tool_selector(&self, provider: Arc) -> Result<()> { - let router_tool_selection_strategy = std::env::var("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY") - .ok() - .and_then(|s| { - if s.eq_ignore_ascii_case("vector") { - Some(RouterToolSelectionStrategy::Vector) - } else { - None - } - }); + let config = Config::global(); + let router_tool_selection_strategy = config + .get_param("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY") + .unwrap_or_else(|_| "default".to_string()); - if let Some(strategy) = router_tool_selection_strategy { + let strategy = match router_tool_selection_strategy.to_lowercase().as_str() { + "vector" => Some(RouterToolSelectionStrategy::Vector), + _ => None, + }; + + if let Some(strategy) = strategy { let selector = create_tool_selector(Some(strategy), provider) .await .map_err(|e| anyhow!("Failed to create tool selector: {}", e))?; @@ -763,9 +763,12 @@ impl Agent { reindex_all: bool, ) -> Result<()> { // Only proceed if vector strategy is enabled - let is_vector_enabled = std::env::var("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY") - .map(|s| s.eq_ignore_ascii_case("vector")) - .unwrap_or(false); + let config = Config::global(); + let router_tool_selection_strategy = config + .get_param("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY") + .unwrap_or_else(|_| "default".to_string()); + + let is_vector_enabled = router_tool_selection_strategy.eq_ignore_ascii_case("vector"); if !is_vector_enabled { return Ok(()); diff --git a/crates/goose/src/agents/reply_parts.rs b/crates/goose/src/agents/reply_parts.rs index 72515c4a52ec..5833d54482ab 100644 --- a/crates/goose/src/agents/reply_parts.rs +++ b/crates/goose/src/agents/reply_parts.rs @@ -3,6 +3,7 @@ use std::collections::HashSet; use std::sync::Arc; use crate::agents::router_tool_selector::RouterToolSelectionStrategy; +use crate::config::Config; use crate::message::{Message, MessageContent, ToolRequest}; use crate::providers::base::{Provider, ProviderUsage}; use crate::providers::errors::ProviderError; @@ -19,16 +20,17 @@ impl Agent { pub(crate) async fn prepare_tools_and_prompt( &self, ) -> anyhow::Result<(Vec, Vec, String)> { - // Get tool selection strategy - let tool_selection_strategy = std::env::var("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY") - .ok() - .and_then(|s| { - if s.eq_ignore_ascii_case("vector") { - Some(RouterToolSelectionStrategy::Vector) - } else { - None - } - }); + // Get tool selection strategy from config + let config = Config::global(); + let router_tool_selection_strategy = config + .get_param("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY") + .unwrap_or_else(|_| "default".to_string()); + + let tool_selection_strategy = match router_tool_selection_strategy.to_lowercase().as_str() { + "vector" => Some(RouterToolSelectionStrategy::Vector), + _ => None, + }; + // Get tools from extension manager let mut tools = match tool_selection_strategy { Some(RouterToolSelectionStrategy::Vector) => { diff --git a/crates/goose/src/agents/router_tool_selector.rs b/crates/goose/src/agents/router_tool_selector.rs index ad654f09c56d..efb0898e828b 100644 --- a/crates/goose/src/agents/router_tool_selector.rs +++ b/crates/goose/src/agents/router_tool_selector.rs @@ -173,7 +173,7 @@ pub async fn create_tool_selector( let selector = VectorToolSelector::new(provider).await?; Ok(Box::new(selector)) } - _ => { + None => { let selector = VectorToolSelector::new(provider).await?; Ok(Box::new(selector)) }