Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/goose-cli/src/session/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ pub async fn build_session(session_config: SessionBuilderConfig) -> Session {
// Extensions need to be added after the session is created because we change directory when resuming a session
// If we get extensions_override, only run those extensions and none other
let extensions_to_run: Vec<_> = if let Some(extensions) = session_config.extensions_override {
agent.disable_router_for_recipe().await;
extensions.into_iter().collect()
} else {
ExtensionConfigManager::get_all()
Expand Down
29 changes: 16 additions & 13 deletions crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ pub struct Agent {
pub(super) tool_result_rx: ToolResultReceiver,
pub(super) tool_monitor: Arc<Mutex<Option<ToolMonitor>>>,
pub(super) router_tool_selector: Mutex<Option<Arc<Box<dyn RouterToolSelector>>>>,
pub(super) router_disabled_override: Mutex<bool>,
pub(super) scheduler_service: Mutex<Option<Arc<dyn SchedulerTrait>>>,
pub(super) retry_manager: RetryManager,
}
Expand Down Expand Up @@ -172,6 +173,7 @@ impl Agent {
tool_result_rx: Arc::new(Mutex::new(tool_rx)),
tool_monitor,
router_tool_selector: Mutex::new(None),
router_disabled_override: Mutex::new(false),
scheduler_service: Mutex::new(None),
retry_manager,
}
Expand Down Expand Up @@ -336,6 +338,11 @@ impl Agent {
*scheduler_service = Some(scheduler);
}

pub async fn disable_router_for_recipe(&self) {
*self.router_disabled_override.lock().await = true;
*self.router_tool_selector.lock().await = None;
}

/// Get a reference count clone to the provider
pub async fn provider(&self) -> Result<Arc<dyn Provider>, anyhow::Error> {
match &*self.provider.lock().await {
Expand Down Expand Up @@ -477,7 +484,7 @@ impl Agent {
|| tool_call.name == ROUTER_LLM_SEARCH_TOOL_NAME
{
let selector = self.router_tool_selector.lock().await.clone();
let mut selected_tools = match selector.as_ref() {
let selected_tools = match selector.as_ref() {
Some(selector) => match selector.select_tools(tool_call.arguments.clone()).await {
Ok(tools) => tools,
Err(e) => {
Expand All @@ -500,18 +507,6 @@ impl Agent {
}
};

// Append final_output tool if present (for structured output recipes, [Issue #3700](https://github.com/block/goose/issues/3700)
if let Some(final_output_tool) = self.final_output_tool.lock().await.as_ref() {
let tool = final_output_tool.tool();
let tool_content = Content::text(format!(
"Tool: {}\nDescription: {}\nSchema: {}",
tool.name,
tool.description.unwrap_or_default(),
serde_json::to_string_pretty(&tool.input_schema).unwrap_or_default()
));
selected_tools.push(tool_content);
}

ToolCallResult::from(Ok(selected_tools))
} else {
// Clone the result to ensure no references to extension_manager are returned
Expand Down Expand Up @@ -750,6 +745,10 @@ impl Agent {
&self,
strategy: Option<RouterToolSelectionStrategy>,
) -> Vec<Tool> {
if *self.router_disabled_override.lock().await {
return vec![];
}

let mut prefixed_tools = vec![];
match strategy {
Some(RouterToolSelectionStrategy::Vector) => {
Expand Down Expand Up @@ -1151,6 +1150,10 @@ impl Agent {
provider: Option<Arc<dyn Provider>>,
reindex_all: Option<bool>,
) -> Result<()> {
if *self.router_disabled_override.lock().await {
return Ok(());
}

let config = Config::global();
let _extension_manager = self.extension_manager.read().await;
let provider = match provider {
Expand Down
23 changes: 14 additions & 9 deletions crates/goose/src/agents/reply_parts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,21 @@ impl Agent {
};

// Get tools from extension manager
let mut tools = match tool_selection_strategy {
Some(RouterToolSelectionStrategy::Vector) => {
self.list_tools_for_router(Some(RouterToolSelectionStrategy::Vector))
.await
}
Some(RouterToolSelectionStrategy::Llm) => {
self.list_tools_for_router(Some(RouterToolSelectionStrategy::Llm))
.await
let mut tools = if *self.router_disabled_override.lock().await {
// If router is disabled, use regular tools
self.list_tools(None).await
} else {
match tool_selection_strategy {
Some(RouterToolSelectionStrategy::Vector) => {
self.list_tools_for_router(Some(RouterToolSelectionStrategy::Vector))
.await
}
Some(RouterToolSelectionStrategy::Llm) => {
self.list_tools_for_router(Some(RouterToolSelectionStrategy::Llm))
.await
}
_ => self.list_tools(None).await,
}
_ => self.list_tools(None).await,
};
// Add frontend tools
let frontend_tools = self.frontend_tools.lock().await;
Expand Down
Loading