Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions crates/goose-cli/src/commands/configure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,11 @@ pub fn remove_extension_dialog() -> Result<(), Box<dyn Error>> {
pub async fn configure_settings_dialog() -> Result<(), Box<dyn Error>> {
let setting_type = cliclack::select("What setting would you like to configure?")
.item("goose_mode", "Goose Mode", "Configure Goose mode")
.item(
"goose_router_strategy",
"Router Tool Selection Strategy",
"Configure the strategy for selecting tools to use",
)
.item(
"tool_permission",
"Tool Permission",
Expand All @@ -850,6 +855,9 @@ pub async fn configure_settings_dialog() -> Result<(), Box<dyn Error>> {
"goose_mode" => {
configure_goose_mode_dialog()?;
}
"goose_router_strategy" => {
configure_goose_router_strategy_dialog()?;
}
"tool_permission" => {
configure_tool_permissions_dialog().await.and(Ok(()))?;
}
Expand Down Expand Up @@ -921,6 +929,49 @@ pub fn configure_goose_mode_dialog() -> Result<(), Box<dyn Error>> {
Ok(())
}

pub fn configure_goose_router_strategy_dialog() -> Result<(), Box<dyn Error>> {
let config = Config::global();

// Check if GOOSE_ROUTER_STRATEGY is set as an environment variable
if std::env::var("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY").is_ok() {
let _ = cliclack::log::info("Notice: GOOSE_ROUTER_TOOL_SELECTION_STRATEGY environment variable is set. Configuration will override this.");
}

let strategy = cliclack::select("Which router strategy would you like to use?")
.item(
"vector",
"Vector Strategy",
"Use vector-based similarity to select tools",
)
.item(
"default",
"Default Strategy",
"Use the default tool selection strategy",
)
.interact()?;

match strategy {
"vector" => {
config.set_param(
"GOOSE_ROUTER_TOOL_SELECTION_STRATEGY",
Value::String("vector".to_string()),
)?;
cliclack::outro(
"Set to Vector Strategy - using vector-based similarity for tool selection",
)?;
}
"default" => {
config.set_param(
"GOOSE_ROUTER_TOOL_SELECTION_STRATEGY",
Value::String("default".to_string()),
)?;
cliclack::outro("Set to Default Strategy - using default tool selection")?;
}
_ => unreachable!(),
};
Ok(())
}

pub fn configure_tool_output_dialog() -> Result<(), Box<dyn Error>> {
let config = Config::global();
// Check if GOOSE_CLI_MIN_PRIORITY is set as an environment variable
Expand Down
29 changes: 16 additions & 13 deletions crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -665,17 +665,17 @@ impl Agent {
}

async fn update_router_tool_selector(&self, provider: Arc<dyn Provider>) -> Result<()> {
let router_tool_selection_strategy = std::env::var("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY")
.ok()
.and_then(|s| {
if s.eq_ignore_ascii_case("vector") {
Some(RouterToolSelectionStrategy::Vector)
} else {
None
}
});
let config = Config::global();
let router_tool_selection_strategy = config
.get_param("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY")
.unwrap_or_else(|_| "default".to_string());

if let Some(strategy) = router_tool_selection_strategy {
let strategy = match router_tool_selection_strategy.to_lowercase().as_str() {
"vector" => Some(RouterToolSelectionStrategy::Vector),
_ => None,
};

if let Some(strategy) = strategy {
let selector = create_tool_selector(Some(strategy), provider)
.await
.map_err(|e| anyhow!("Failed to create tool selector: {}", e))?;
Expand Down Expand Up @@ -763,9 +763,12 @@ impl Agent {
reindex_all: bool,
) -> Result<()> {
// Only proceed if vector strategy is enabled
let is_vector_enabled = std::env::var("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY")
.map(|s| s.eq_ignore_ascii_case("vector"))
.unwrap_or(false);
let config = Config::global();
let router_tool_selection_strategy = config
.get_param("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY")
.unwrap_or_else(|_| "default".to_string());

let is_vector_enabled = router_tool_selection_strategy.eq_ignore_ascii_case("vector");

if !is_vector_enabled {
return Ok(());
Expand Down
22 changes: 12 additions & 10 deletions crates/goose/src/agents/reply_parts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::collections::HashSet;
use std::sync::Arc;

use crate::agents::router_tool_selector::RouterToolSelectionStrategy;
use crate::config::Config;
use crate::message::{Message, MessageContent, ToolRequest};
use crate::providers::base::{Provider, ProviderUsage};
use crate::providers::errors::ProviderError;
Expand All @@ -19,16 +20,17 @@ impl Agent {
pub(crate) async fn prepare_tools_and_prompt(
&self,
) -> anyhow::Result<(Vec<Tool>, Vec<Tool>, String)> {
// Get tool selection strategy
let tool_selection_strategy = std::env::var("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY")
.ok()
.and_then(|s| {
if s.eq_ignore_ascii_case("vector") {
Some(RouterToolSelectionStrategy::Vector)
} else {
None
}
});
// Get tool selection strategy from config
let config = Config::global();
let router_tool_selection_strategy = config
.get_param("GOOSE_ROUTER_TOOL_SELECTION_STRATEGY")
.unwrap_or_else(|_| "default".to_string());

let tool_selection_strategy = match router_tool_selection_strategy.to_lowercase().as_str() {
"vector" => Some(RouterToolSelectionStrategy::Vector),
_ => None,
};

// Get tools from extension manager
let mut tools = match tool_selection_strategy {
Some(RouterToolSelectionStrategy::Vector) => {
Expand Down
2 changes: 1 addition & 1 deletion crates/goose/src/agents/router_tool_selector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ pub async fn create_tool_selector(
let selector = VectorToolSelector::new(provider).await?;
Ok(Box::new(selector))
}
_ => {
None => {
let selector = VectorToolSelector::new(provider).await?;
Ok(Box::new(selector))
}
Expand Down