From 7d5d18e7d0643478e62a1b8ae39b63c0aed9ea33 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 23 Jul 2025 13:18:11 +1000 Subject: [PATCH 1/5] reuse the cancellation token in the agent level --- crates/goose/src/agents/agent.rs | 13 ++ .../agents/sub_recipe_execution_tool/mod.rs | 11 -- .../agents/sub_recipe_execution_tool/tasks.rs | 186 ------------------ .../sub_recipe_execution_tool/workers.rs | 31 --- .../subagent_execution_tool/executor/mod.rs | 22 ++- .../agents/subagent_execution_tool/lib/mod.rs | 14 +- .../subagent_execute_task_tool.rs | 3 + .../task_execution_tracker.rs | 55 +++++- .../subagent_execution_tool/task_types.rs | 2 + .../agents/subagent_execution_tool/tasks.rs | 55 +++++- .../agents/subagent_execution_tool/workers.rs | 39 ++-- 11 files changed, 174 insertions(+), 257 deletions(-) delete mode 100644 crates/goose/src/agents/sub_recipe_execution_tool/mod.rs delete mode 100644 crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs delete mode 100644 crates/goose/src/agents/sub_recipe_execution_tool/workers.rs diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 40fe5563eee0..13f16a53d124 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -75,6 +75,7 @@ pub struct Agent { pub(super) tool_monitor: Arc>>, pub(super) router_tool_selector: Mutex>>>, pub(super) scheduler_service: Mutex>>, + pub(super) current_cancellation_token: Arc>>, pub(super) mcp_tx: Mutex>, pub(super) mcp_notification_rx: Arc>>, pub(super) retry_manager: RetryManager, @@ -154,6 +155,7 @@ impl Agent { tool_monitor, router_tool_selector: Mutex::new(None), scheduler_service: Mutex::new(None), + current_cancellation_token: Arc::new(Mutex::new(None)), // Initialize with MCP notification support mcp_tx: Mutex::new(mcp_tx), mcp_notification_rx: Arc::new(Mutex::new(mcp_rx)), @@ -345,10 +347,15 @@ impl Agent { let task_config = TaskConfig::new(provider, Some(Arc::clone(&self.extension_manager)), mcp_tx); + + // Get the current cancellation token + let current_token = self.current_cancellation_token.lock().await.clone(); + subagent_execute_task_tool::run_tasks( tool_call.arguments.clone(), task_config, &self.tasks_manager, + current_token, ) .await } else if tool_call.name == DYNAMIC_TASK_TOOL_NAME_PREFIX { @@ -719,6 +726,12 @@ impl Agent { session: Option, cancel_token: Option, ) -> Result>> { + // Store the cancellation token for use in tool calls + { + let mut current_token = self.current_cancellation_token.lock().await; + *current_token = cancel_token.clone(); + } + let mut messages = messages.to_vec(); let initial_messages = messages.clone(); let reply_span = tracing::Span::current(); diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs b/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs deleted file mode 100644 index 49fcc194c56a..000000000000 --- a/crates/goose/src/agents/sub_recipe_execution_tool/mod.rs +++ /dev/null @@ -1,11 +0,0 @@ -mod executor; -pub mod lib; -pub mod notification_events; -pub mod sub_recipe_execute_task_tool; -mod task_execution_tracker; -mod task_types; -mod tasks; -pub mod tasks_manager; -pub mod utils; -mod workers; - diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs b/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs deleted file mode 100644 index 66f67729e69e..000000000000 --- a/crates/goose/src/agents/sub_recipe_execution_tool/tasks.rs +++ /dev/null @@ -1,186 +0,0 @@ -use serde_json::Value; -use std::process::Stdio; -use std::sync::Arc; -use tokio::io::{AsyncBufReadExt, BufReader}; -use tokio::process::Command; - -use crate::agents::sub_recipe_execution_tool::task_execution_tracker::TaskExecutionTracker; -use crate::agents::sub_recipe_execution_tool::task_types::{Task, TaskResult, TaskStatus}; - -pub async fn process_task( - task: &Task, - task_execution_tracker: Arc, -) -> TaskResult { - match get_task_result(task.clone(), task_execution_tracker).await { - Ok(data) => TaskResult { - task_id: task.id.clone(), - status: TaskStatus::Completed, - data: Some(data), - error: None, - }, - Err(error) => TaskResult { - task_id: task.id.clone(), - status: TaskStatus::Failed, - data: None, - error: Some(error), - }, - } -} - -async fn get_task_result( - task: Task, - task_execution_tracker: Arc, -) -> Result { - let (command, output_identifier) = build_command(&task)?; - let (stdout_output, stderr_output, success) = run_command( - command, - &output_identifier, - &task.id, - task_execution_tracker, - ) - .await?; - - if success { - process_output(stdout_output) - } else { - Err(format!("Command failed:\n{}", stderr_output)) - } -} - -fn build_command(task: &Task) -> Result<(Command, String), String> { - let task_error = |field: &str| format!("Task {}: Missing {}", task.id, field); - - let mut output_identifier = task.id.clone(); - let mut command = if task.task_type == "sub_recipe" { - let sub_recipe_name = task - .get_sub_recipe_name() - .ok_or_else(|| task_error("sub_recipe name"))?; - let path = task - .get_sub_recipe_path() - .ok_or_else(|| task_error("sub_recipe path"))?; - let command_parameters = task - .get_command_parameters() - .ok_or_else(|| task_error("command_parameters"))?; - - output_identifier = format!("sub-recipe {}", sub_recipe_name); - let mut cmd = Command::new("goose"); - cmd.arg("run").arg("--recipe").arg(path).arg("--no-session"); - - for (key, value) in command_parameters { - let key_str = key.to_string(); - let value_str = value.as_str().unwrap_or(&value.to_string()).to_string(); - cmd.arg("--params") - .arg(format!("{}={}", key_str, value_str)); - } - cmd - } else { - let text = task - .get_text_instruction() - .ok_or_else(|| task_error("text_instruction"))?; - let mut cmd = Command::new("goose"); - cmd.arg("run").arg("--text").arg(text); - cmd - }; - - command.stdout(Stdio::piped()); - command.stderr(Stdio::piped()); - Ok((command, output_identifier)) -} - -async fn run_command( - mut command: Command, - output_identifier: &str, - task_id: &str, - task_execution_tracker: Arc, -) -> Result<(String, String, bool), String> { - let mut child = command - .spawn() - .map_err(|e| format!("Failed to spawn goose: {}", e))?; - - let stdout = child.stdout.take().expect("Failed to capture stdout"); - let stderr = child.stderr.take().expect("Failed to capture stderr"); - - let stdout_task = spawn_output_reader( - stdout, - output_identifier, - false, - task_id, - task_execution_tracker.clone(), - ); - let stderr_task = spawn_output_reader( - stderr, - output_identifier, - true, - task_id, - task_execution_tracker.clone(), - ); - - let status = child - .wait() - .await - .map_err(|e| format!("Failed to wait for process: {}", e))?; - - let stdout_output = stdout_task.await.unwrap(); - let stderr_output = stderr_task.await.unwrap(); - - Ok((stdout_output, stderr_output, status.success())) -} - -fn spawn_output_reader( - reader: impl tokio::io::AsyncRead + Unpin + Send + 'static, - output_identifier: &str, - is_stderr: bool, - task_id: &str, - task_execution_tracker: Arc, -) -> tokio::task::JoinHandle { - let output_identifier = output_identifier.to_string(); - let task_id = task_id.to_string(); - tokio::spawn(async move { - let mut buffer = String::new(); - let mut lines = BufReader::new(reader).lines(); - while let Ok(Some(line)) = lines.next_line().await { - buffer.push_str(&line); - buffer.push('\n'); - - if !is_stderr { - task_execution_tracker - .send_live_output(&task_id, &line) - .await; - } else { - tracing::warn!("Task stderr [{}]: {}", output_identifier, line); - } - } - buffer - }) -} - -fn extract_json_from_line(line: &str) -> Option { - let start = line.find('{')?; - let end = line.rfind('}')?; - - if start >= end { - return None; - } - - let potential_json = &line[start..=end]; - if serde_json::from_str::(potential_json).is_ok() { - Some(potential_json.to_string()) - } else { - None - } -} - -fn process_output(stdout_output: String) -> Result { - let last_line = stdout_output - .lines() - .filter(|line| !line.trim().is_empty()) - .next_back() - .unwrap_or(""); - - if let Some(json_string) = extract_json_from_line(last_line) { - Ok(Value::String(json_string)) - } else { - Ok(Value::String(stdout_output)) - } -} - diff --git a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs b/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs deleted file mode 100644 index 89473f7c6a65..000000000000 --- a/crates/goose/src/agents/sub_recipe_execution_tool/workers.rs +++ /dev/null @@ -1,31 +0,0 @@ -use crate::agents::sub_recipe_execution_tool::task_types::{SharedState, Task}; -use crate::agents::sub_recipe_execution_tool::tasks::process_task; -use std::sync::Arc; - -async fn receive_task(state: &SharedState) -> Option { - let mut receiver = state.task_receiver.lock().await; - receiver.recv().await -} - -pub fn spawn_worker(state: Arc, worker_id: usize) -> tokio::task::JoinHandle<()> { - state.increment_active_workers(); - - tokio::spawn(async move { - worker_loop(state, worker_id).await; - }) -} - -async fn worker_loop(state: Arc, _worker_id: usize) { - while let Some(task) = receive_task(&state).await { - state.task_execution_tracker.start_task(&task.id).await; - let result = process_task(&task, state.task_execution_tracker.clone()).await; - - if let Err(e) = state.result_sender.send(result).await { - tracing::error!("Worker failed to send result: {}", e); - break; - } - } - - state.decrement_active_workers(); -} - diff --git a/crates/goose/src/agents/subagent_execution_tool/executor/mod.rs b/crates/goose/src/agents/subagent_execution_tool/executor/mod.rs index 14726a75470c..665bf9e14b64 100644 --- a/crates/goose/src/agents/subagent_execution_tool/executor/mod.rs +++ b/crates/goose/src/agents/subagent_execution_tool/executor/mod.rs @@ -13,6 +13,7 @@ use std::sync::Arc; use tokio::sync::mpsc; use tokio::sync::mpsc::Sender; use tokio::time::Instant; +use tokio_util::sync::CancellationToken; const EXECUTION_STATUS_COMPLETED: &str = "completed"; const DEFAULT_MAX_WORKERS: usize = 10; @@ -21,14 +22,22 @@ pub async fn execute_single_task( task: &Task, notifier: mpsc::Sender, task_config: TaskConfig, + cancellation_token: Option, ) -> ExecutionResponse { let start_time = Instant::now(); let task_execution_tracker = Arc::new(TaskExecutionTracker::new( vec![task.clone()], DisplayMode::SingleTaskOutput, notifier, + cancellation_token.clone(), )); - let result = process_task(task, task_execution_tracker.clone(), task_config).await; + let result = process_task( + task, + task_execution_tracker.clone(), + task_config, + cancellation_token.unwrap_or_default(), + ) + .await; // Complete the task in the tracker task_execution_tracker @@ -49,11 +58,13 @@ pub async fn execute_tasks_in_parallel( tasks: Vec, notifier: Sender, task_config: TaskConfig, + cancellation_token: Option, ) -> ExecutionResponse { let task_execution_tracker = Arc::new(TaskExecutionTracker::new( tasks.clone(), DisplayMode::MultipleTasksOutput, notifier, + cancellation_token.clone(), )); let start_time = Instant::now(); let task_count = tasks.len(); @@ -71,7 +82,12 @@ pub async fn execute_tasks_in_parallel( return create_error_response(e); } - let shared_state = create_shared_state(task_rx, result_tx, task_execution_tracker.clone()); + let shared_state = create_shared_state( + task_rx, + result_tx, + task_execution_tracker.clone(), + cancellation_token.unwrap_or_default(), + ); let worker_count = std::cmp::min(task_count, DEFAULT_MAX_WORKERS); let mut worker_handles = Vec::new(); @@ -135,12 +151,14 @@ fn create_shared_state( task_rx: mpsc::Receiver, result_tx: mpsc::Sender, task_execution_tracker: Arc, + cancellation_token: CancellationToken, ) -> Arc { Arc::new(SharedState { task_receiver: Arc::new(tokio::sync::Mutex::new(task_rx)), result_sender: result_tx, active_workers: Arc::new(AtomicUsize::new(0)), task_execution_tracker, + cancellation_token, }) } diff --git a/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs b/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs index d6f431ede629..172ad03c6146 100644 --- a/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs +++ b/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs @@ -9,6 +9,7 @@ use crate::agents::subagent_task_config::TaskConfig; use rmcp::model::JsonRpcMessage; use serde_json::{json, Value}; use tokio::sync::mpsc::Sender; +use tokio_util::sync::CancellationToken; pub async fn execute_tasks( input: Value, @@ -16,6 +17,7 @@ pub async fn execute_tasks( notifier: Sender, task_config: TaskConfig, tasks_manager: &TasksManager, + cancellation_token: Option, ) -> Result { let task_ids: Vec = serde_json::from_value( input @@ -31,7 +33,8 @@ pub async fn execute_tasks( match execution_mode { ExecutionMode::Sequential => { if task_count == 1 { - let response = execute_single_task(&tasks[0], notifier, task_config).await; + let response = + execute_single_task(&tasks[0], notifier, task_config, cancellation_token).await; handle_response(response) } else { Err("Sequential execution mode requires exactly one task".to_string()) @@ -47,8 +50,13 @@ pub async fn execute_tasks( } )) } else { - let response: ExecutionResponse = - execute_tasks_in_parallel(tasks, notifier.clone(), task_config).await; + let response: ExecutionResponse = execute_tasks_in_parallel( + tasks, + notifier.clone(), + task_config, + cancellation_token, + ) + .await; handle_response(response) } } diff --git a/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs b/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs index fc400dad326b..e06da4061566 100644 --- a/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs +++ b/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs @@ -11,6 +11,7 @@ use crate::agents::{ use rmcp::model::JsonRpcMessage; use tokio::sync::mpsc; use tokio_stream; +use tokio_util::sync::CancellationToken; pub const SUBAGENT_EXECUTE_TASK_TOOL_NAME: &str = "subagent__execute_task"; pub fn create_subagent_execute_task_tool() -> Tool { @@ -64,6 +65,7 @@ pub async fn run_tasks( execute_data: Value, task_config: TaskConfig, tasks_manager: &TasksManager, + cancellation_token: Option, ) -> ToolCallResult { let (notification_tx, notification_rx) = mpsc::channel::(100); @@ -81,6 +83,7 @@ pub async fn run_tasks( notification_tx, task_config, &tasks_manager_clone, + cancellation_token, ) .await { diff --git a/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs b/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs index dab102392799..2d501ebb6504 100644 --- a/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs +++ b/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs @@ -4,6 +4,7 @@ use std::collections::HashMap; use std::sync::Arc; use tokio::sync::{mpsc, RwLock}; use tokio::time::{sleep, Duration, Instant}; +use tokio_util::sync::CancellationToken; use crate::agents::subagent_execution_tool::notification_events::{ FailedTaskInfo, TaskCompletionStats, TaskExecutionNotificationEvent, TaskExecutionStats, @@ -62,6 +63,7 @@ pub struct TaskExecutionTracker { last_refresh: Arc>, notifier: mpsc::Sender, display_mode: DisplayMode, + cancellation_token: Option, } impl TaskExecutionTracker { @@ -69,6 +71,7 @@ impl TaskExecutionTracker { tasks: Vec, display_mode: DisplayMode, notifier: Sender, + cancellation_token: Option, ) -> Self { let task_map = tasks .into_iter() @@ -93,6 +96,7 @@ impl TaskExecutionTracker { last_refresh: Arc::new(RwLock::new(Instant::now())), notifier, display_mode, + cancellation_token, } } @@ -166,7 +170,14 @@ impl TaskExecutionTracker { }, })) { - tracing::warn!("Failed to send live output notification: {}", e); + // Only log warning if not cancelled (channel close is expected during cancellation) + if let Some(ref token) = self.cancellation_token { + if !token.is_cancelled() { + tracing::warn!("Failed to send live output notification: {}", e); + } + } else { + tracing::warn!("Failed to send live output notification: {}", e); + } } } DisplayMode::MultipleTasksOutput => { @@ -197,6 +208,13 @@ impl TaskExecutionTracker { } async fn send_tasks_update(&self) { + // Check if we're cancelled before sending notifications + if let Some(ref token) = self.cancellation_token { + if token.is_cancelled() { + return; + } + } + let tasks = self.tasks.read().await; let task_list: Vec<_> = tasks.values().collect(); let (total, pending, running, completed, failed) = count_by_status(&tasks); @@ -242,7 +260,14 @@ impl TaskExecutionTracker { }, })) { - tracing::warn!("Failed to send tasks update notification: {}", e); + // Only log warning if not cancelled (channel close is expected during cancellation) + if let Some(ref token) = self.cancellation_token { + if !token.is_cancelled() { + tracing::warn!("Failed to send tasks update notification: {}", e); + } + } else { + tracing::warn!("Failed to send tasks update notification: {}", e); + } } } @@ -276,6 +301,13 @@ impl TaskExecutionTracker { } pub async fn send_tasks_complete(&self) { + // Check if we're cancelled before sending notifications + if let Some(ref token) = self.cancellation_token { + if token.is_cancelled() { + return; + } + } + let tasks = self.tasks.read().await; let (total, _, _, completed, failed) = count_by_status(&tasks); @@ -306,10 +338,23 @@ impl TaskExecutionTracker { }, })) { - tracing::warn!("Failed to send tasks complete notification: {}", e); + // Only log warning if not cancelled (channel close is expected during cancellation) + if let Some(ref token) = self.cancellation_token { + if !token.is_cancelled() { + tracing::warn!("Failed to send tasks complete notification: {}", e); + } + } else { + tracing::warn!("Failed to send tasks complete notification: {}", e); + } } - // Brief delay to ensure completion notification is processed - sleep(Duration::from_millis(COMPLETION_NOTIFICATION_DELAY_MS)).await; + // Brief delay to ensure completion notification is processed (skip if cancelled) + if let Some(ref token) = self.cancellation_token { + if !token.is_cancelled() { + sleep(Duration::from_millis(COMPLETION_NOTIFICATION_DELAY_MS)).await; + } + } else { + sleep(Duration::from_millis(COMPLETION_NOTIFICATION_DELAY_MS)).await; + } } } diff --git a/crates/goose/src/agents/subagent_execution_tool/task_types.rs b/crates/goose/src/agents/subagent_execution_tool/task_types.rs index 796491f624f2..6bdcce33a7f9 100644 --- a/crates/goose/src/agents/subagent_execution_tool/task_types.rs +++ b/crates/goose/src/agents/subagent_execution_tool/task_types.rs @@ -3,6 +3,7 @@ use serde_json::{Map, Value}; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use tokio::sync::mpsc; +use tokio_util::sync::CancellationToken; use crate::agents::subagent_execution_tool::task_execution_tracker::TaskExecutionTracker; @@ -117,6 +118,7 @@ pub struct SharedState { pub result_sender: mpsc::Sender, pub active_workers: Arc, pub task_execution_tracker: Arc, + pub cancellation_token: CancellationToken, } impl SharedState { diff --git a/crates/goose/src/agents/subagent_execution_tool/tasks.rs b/crates/goose/src/agents/subagent_execution_tool/tasks.rs index a330711e0a0a..89f399cb13cb 100644 --- a/crates/goose/src/agents/subagent_execution_tool/tasks.rs +++ b/crates/goose/src/agents/subagent_execution_tool/tasks.rs @@ -4,6 +4,7 @@ use std::process::Stdio; use std::sync::Arc; use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::process::Command; +use tokio_util::sync::CancellationToken; use crate::agents::subagent_execution_tool::task_execution_tracker::TaskExecutionTracker; use crate::agents::subagent_execution_tool::task_types::{Task, TaskResult, TaskStatus}; @@ -14,8 +15,16 @@ pub async fn process_task( task: &Task, task_execution_tracker: Arc, task_config: TaskConfig, + cancellation_token: CancellationToken, ) -> TaskResult { - match get_task_result(task.clone(), task_execution_tracker, task_config).await { + match get_task_result( + task.clone(), + task_execution_tracker, + task_config, + cancellation_token, + ) + .await + { Ok(data) => TaskResult { task_id: task.id.clone(), status: TaskStatus::Completed, @@ -35,10 +44,17 @@ async fn get_task_result( task: Task, task_execution_tracker: Arc, task_config: TaskConfig, + cancellation_token: CancellationToken, ) -> Result { if task.task_type == "text_instruction" { // Handle text_instruction tasks using subagent system - handle_text_instruction_task(task, task_execution_tracker, task_config).await + handle_text_instruction_task( + task, + task_execution_tracker, + task_config, + cancellation_token, + ) + .await } else { // Handle sub_recipe tasks using command execution let (command, output_identifier) = build_command(&task)?; @@ -47,6 +63,7 @@ async fn get_task_result( &output_identifier, &task.id, task_execution_tracker, + cancellation_token, ) .await?; @@ -62,6 +79,7 @@ async fn handle_text_instruction_task( task: Task, task_execution_tracker: Arc, task_config: TaskConfig, + cancellation_token: CancellationToken, ) -> Result { let text_instruction = task .get_text_instruction() @@ -76,7 +94,15 @@ async fn handle_text_instruction_task( // "instructions": "You are a helpful assistant. Execute the given task and provide a clear, concise response.", }); - match run_complete_subagent_task(task_arguments, task_config).await { + // Use tokio::select to race between subagent execution and cancellation + let result = tokio::select! { + result = run_complete_subagent_task(task_arguments, task_config) => result, + _ = cancellation_token.cancelled() => { + return Err("Task cancelled".to_string()); + } + }; + + match result { Ok(contents) => { // Extract the text content from the result let result_text = contents @@ -141,6 +167,7 @@ async fn run_command( output_identifier: &str, task_id: &str, task_execution_tracker: Arc, + cancellation_token: CancellationToken, ) -> Result<(String, String, bool), String> { let mut child = command .spawn() @@ -164,15 +191,27 @@ async fn run_command( task_execution_tracker.clone(), ); - let status = child - .wait() - .await - .map_err(|e| format!("Failed to wait for process: {}", e))?; + // Use tokio::select to race between process completion and cancellation + let result = tokio::select! { + _ = cancellation_token.cancelled() => { + // Kill the child process + if let Err(e) = child.kill().await { + tracing::warn!("Failed to kill child process: {}", e); + } + // Abort the output reading tasks + stdout_task.abort(); + stderr_task.abort(); + return Err("Command cancelled".to_string()); + } + status_result = child.wait() => { + status_result.map_err(|e| format!("Failed to wait for process: {}", e))? + } + }; let stdout_output = stdout_task.await.unwrap(); let stderr_output = stderr_task.await.unwrap(); - Ok((stdout_output, stderr_output, status.success())) + Ok((stdout_output, stderr_output, result.success())) } fn spawn_output_reader( diff --git a/crates/goose/src/agents/subagent_execution_tool/workers.rs b/crates/goose/src/agents/subagent_execution_tool/workers.rs index 4ae0ab250737..d28808cf5d59 100644 --- a/crates/goose/src/agents/subagent_execution_tool/workers.rs +++ b/crates/goose/src/agents/subagent_execution_tool/workers.rs @@ -21,18 +21,35 @@ pub fn spawn_worker( } async fn worker_loop(state: Arc, _worker_id: usize, task_config: TaskConfig) { - while let Some(task) = receive_task(&state).await { - state.task_execution_tracker.start_task(&task.id).await; - let result = process_task( - &task, - state.task_execution_tracker.clone(), - task_config.clone(), - ) - .await; + loop { + tokio::select! { + task_option = receive_task(&state) => { + match task_option { + Some(task) => { + state.task_execution_tracker.start_task(&task.id).await; + let result = process_task( + &task, + state.task_execution_tracker.clone(), + task_config.clone(), + state.cancellation_token.clone(), + ) + .await; - if let Err(e) = state.result_sender.send(result).await { - tracing::error!("Worker failed to send result: {}", e); - break; + if let Err(e) = state.result_sender.send(result).await { + // Only log error if not cancelled (channel close is expected during cancellation) + if !state.cancellation_token.is_cancelled() { + tracing::error!("Worker failed to send result: {}", e); + } + break; + } + } + None => break, // No more tasks + } + } + _ = state.cancellation_token.cancelled() => { + tracing::debug!("Worker cancelled"); + break; + } } } From 553745e4382ef7d60a44fe65a7008a7d78a49f25 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 23 Jul 2025 13:43:26 +1000 Subject: [PATCH 2/5] clean up --- crates/goose/src/agents/agent.rs | 2 - .../task_execution_tracker.rs | 129 ++++++------------ 2 files changed, 42 insertions(+), 89 deletions(-) diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 13f16a53d124..970a1df61273 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -348,7 +348,6 @@ impl Agent { let task_config = TaskConfig::new(provider, Some(Arc::clone(&self.extension_manager)), mcp_tx); - // Get the current cancellation token let current_token = self.current_cancellation_token.lock().await.clone(); subagent_execute_task_tool::run_tasks( @@ -726,7 +725,6 @@ impl Agent { session: Option, cancel_token: Option, ) -> Result>> { - // Store the cancellation token for use in tool calls { let mut current_token = self.current_cancellation_token.lock().await; *current_token = cancel_token.clone(); diff --git a/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs b/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs index 2d501ebb6504..7d6854b3e229 100644 --- a/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs +++ b/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs @@ -3,7 +3,7 @@ use rmcp::object; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::{mpsc, RwLock}; -use tokio::time::{sleep, Duration, Instant}; +use tokio::time::{Duration, Instant}; use tokio_util::sync::CancellationToken; use crate::agents::subagent_execution_tool::notification_events::{ @@ -22,7 +22,6 @@ pub enum DisplayMode { } const THROTTLE_INTERVAL_MS: u64 = 250; -const COMPLETION_NOTIFICATION_DELAY_MS: u64 = 500; fn format_task_metadata(task_info: &TaskInfo) -> String { if let Some(params) = task_info.task.get_command_parameters() { @@ -100,6 +99,40 @@ impl TaskExecutionTracker { } } + fn is_cancelled(&self) -> bool { + self.cancellation_token + .as_ref() + .is_some_and(|t| t.is_cancelled()) + } + + fn log_notification_error( + &self, + error: &mpsc::error::TrySendError, + context: &str, + ) { + if !self.is_cancelled() { + tracing::warn!("Failed to send {} notification: {}", context, error); + } + } + + fn try_send_notification(&self, event: TaskExecutionNotificationEvent, context: &str) { + if let Err(e) = self + .notifier + .try_send(JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: JsonRpcVersion2_0, + notification: Notification { + method: "notifications/message".to_string(), + params: object!({ + "data": event.to_notification_data() + }), + extensions: Default::default(), + }, + })) + { + self.log_notification_error(&e, context); + } + } + pub async fn start_task(&self, task_id: &str) { let mut tasks = self.tasks.write().await; if let Some(task_info) = tasks.get_mut(task_id) { @@ -157,28 +190,7 @@ impl TaskExecutionTracker { formatted_line, ); - if let Err(e) = - self.notifier - .try_send(JsonRpcMessage::Notification(JsonRpcNotification { - jsonrpc: JsonRpcVersion2_0, - notification: Notification { - method: "notifications/message".to_string(), - params: object!({ - "data": event.to_notification_data() - }), - extensions: Default::default(), - }, - })) - { - // Only log warning if not cancelled (channel close is expected during cancellation) - if let Some(ref token) = self.cancellation_token { - if !token.is_cancelled() { - tracing::warn!("Failed to send live output notification: {}", e); - } - } else { - tracing::warn!("Failed to send live output notification: {}", e); - } - } + self.try_send_notification(event, "live output"); } DisplayMode::MultipleTasksOutput => { let mut tasks = self.tasks.write().await; @@ -208,11 +220,8 @@ impl TaskExecutionTracker { } async fn send_tasks_update(&self) { - // Check if we're cancelled before sending notifications - if let Some(ref token) = self.cancellation_token { - if token.is_cancelled() { - return; - } + if self.is_cancelled() { + return; } let tasks = self.tasks.read().await; @@ -247,28 +256,7 @@ impl TaskExecutionTracker { let event = TaskExecutionNotificationEvent::tasks_update(stats, event_tasks); - if let Err(e) = self - .notifier - .try_send(JsonRpcMessage::Notification(JsonRpcNotification { - jsonrpc: JsonRpcVersion2_0, - notification: Notification { - method: "notifications/message".to_string(), - params: object!({ - "data": event.to_notification_data() - }), - extensions: Default::default(), - }, - })) - { - // Only log warning if not cancelled (channel close is expected during cancellation) - if let Some(ref token) = self.cancellation_token { - if !token.is_cancelled() { - tracing::warn!("Failed to send tasks update notification: {}", e); - } - } else { - tracing::warn!("Failed to send tasks update notification: {}", e); - } - } + self.try_send_notification(event, "tasks update"); } pub async fn refresh_display(&self) { @@ -301,11 +289,8 @@ impl TaskExecutionTracker { } pub async fn send_tasks_complete(&self) { - // Check if we're cancelled before sending notifications - if let Some(ref token) = self.cancellation_token { - if token.is_cancelled() { - return; - } + if self.is_cancelled() { + return; } let tasks = self.tasks.read().await; @@ -325,36 +310,6 @@ impl TaskExecutionTracker { let event = TaskExecutionNotificationEvent::tasks_complete(stats, failed_tasks); - if let Err(e) = self - .notifier - .try_send(JsonRpcMessage::Notification(JsonRpcNotification { - jsonrpc: JsonRpcVersion2_0, - notification: Notification { - method: "notifications/message".to_string(), - params: object!({ - "data": event.to_notification_data() - }), - extensions: Default::default(), - }, - })) - { - // Only log warning if not cancelled (channel close is expected during cancellation) - if let Some(ref token) = self.cancellation_token { - if !token.is_cancelled() { - tracing::warn!("Failed to send tasks complete notification: {}", e); - } - } else { - tracing::warn!("Failed to send tasks complete notification: {}", e); - } - } - - // Brief delay to ensure completion notification is processed (skip if cancelled) - if let Some(ref token) = self.cancellation_token { - if !token.is_cancelled() { - sleep(Duration::from_millis(COMPLETION_NOTIFICATION_DELAY_MS)).await; - } - } else { - sleep(Duration::from_millis(COMPLETION_NOTIFICATION_DELAY_MS)).await; - } + self.try_send_notification(event, "tasks complete"); } } From 5ad676b8bb3bb053ee7d99533f7951296db4ff7a Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 23 Jul 2025 13:56:13 +1000 Subject: [PATCH 3/5] removed comments --- crates/goose/src/agents/subagent_execution_tool/tasks.rs | 3 --- 1 file changed, 3 deletions(-) diff --git a/crates/goose/src/agents/subagent_execution_tool/tasks.rs b/crates/goose/src/agents/subagent_execution_tool/tasks.rs index 89f399cb13cb..4ecd5b628ffa 100644 --- a/crates/goose/src/agents/subagent_execution_tool/tasks.rs +++ b/crates/goose/src/agents/subagent_execution_tool/tasks.rs @@ -94,7 +94,6 @@ async fn handle_text_instruction_task( // "instructions": "You are a helpful assistant. Execute the given task and provide a clear, concise response.", }); - // Use tokio::select to race between subagent execution and cancellation let result = tokio::select! { result = run_complete_subagent_task(task_arguments, task_config) => result, _ = cancellation_token.cancelled() => { @@ -191,10 +190,8 @@ async fn run_command( task_execution_tracker.clone(), ); - // Use tokio::select to race between process completion and cancellation let result = tokio::select! { _ = cancellation_token.cancelled() => { - // Kill the child process if let Err(e) = child.kill().await { tracing::warn!("Failed to kill child process: {}", e); } From 3b1fb1eb88b178a49c46831aa0f9aa08c14142a6 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 23 Jul 2025 21:51:18 +1000 Subject: [PATCH 4/5] removed setting cancellation token on agent --- crates/goose/src/agents/agent.rs | 15 ++++----------- crates/goose/src/agents/tool_execution.rs | 4 +++- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 970a1df61273..c0e83fb8063d 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -75,7 +75,6 @@ pub struct Agent { pub(super) tool_monitor: Arc>>, pub(super) router_tool_selector: Mutex>>>, pub(super) scheduler_service: Mutex>>, - pub(super) current_cancellation_token: Arc>>, pub(super) mcp_tx: Mutex>, pub(super) mcp_notification_rx: Arc>>, pub(super) retry_manager: RetryManager, @@ -155,7 +154,6 @@ impl Agent { tool_monitor, router_tool_selector: Mutex::new(None), scheduler_service: Mutex::new(None), - current_cancellation_token: Arc::new(Mutex::new(None)), // Initialize with MCP notification support mcp_tx: Mutex::new(mcp_tx), mcp_notification_rx: Arc::new(Mutex::new(mcp_rx)), @@ -275,6 +273,7 @@ impl Agent { &self, tool_call: mcp_core::tool::ToolCall, request_id: String, + cancellation_token: Option, ) -> (String, Result) { // Check if this tool call should be allowed based on repetition monitoring if let Some(monitor) = self.tool_monitor.lock().await.as_mut() { @@ -348,13 +347,11 @@ impl Agent { let task_config = TaskConfig::new(provider, Some(Arc::clone(&self.extension_manager)), mcp_tx); - let current_token = self.current_cancellation_token.lock().await.clone(); - subagent_execute_task_tool::run_tasks( tool_call.arguments.clone(), task_config, &self.tasks_manager, - current_token, + cancellation_token, ) .await } else if tool_call.name == DYNAMIC_TASK_TOOL_NAME_PREFIX { @@ -725,11 +722,6 @@ impl Agent { session: Option, cancel_token: Option, ) -> Result>> { - { - let mut current_token = self.current_cancellation_token.lock().await; - *current_token = cancel_token.clone(); - } - let mut messages = messages.to_vec(); let initial_messages = messages.clone(); let reply_span = tracing::Span::current(); @@ -925,7 +917,7 @@ impl Agent { for request in &permission_check_result.approved { if let Ok(tool_call) = request.tool_call.clone() { let (req_id, tool_result) = self - .dispatch_tool_call(tool_call, request.id.clone()) + .dispatch_tool_call(tool_call, request.id.clone(), cancel_token.clone()) .await; tool_futures.push(( @@ -962,6 +954,7 @@ impl Agent { tool_futures_arc.clone(), &mut permission_manager, message_tool_response.clone(), + cancel_token.clone(), ); while let Some(msg) = tool_approval_stream.try_next().await? { diff --git a/crates/goose/src/agents/tool_execution.rs b/crates/goose/src/agents/tool_execution.rs index 9af001fe7666..bc9f4292f72f 100644 --- a/crates/goose/src/agents/tool_execution.rs +++ b/crates/goose/src/agents/tool_execution.rs @@ -6,6 +6,7 @@ use futures::stream::{self, BoxStream}; use futures::{Stream, StreamExt}; use rmcp::model::JsonRpcMessage; use tokio::sync::Mutex; +use tokio_util::sync::CancellationToken; use crate::config::permission::PermissionLevel; use crate::config::PermissionManager; @@ -53,6 +54,7 @@ impl Agent { tool_futures: Arc>>, permission_manager: &'a mut PermissionManager, message_tool_response: Arc>, + cancellation_token: Option, ) -> BoxStream<'a, anyhow::Result> { try_stream! { for request in tool_requests { @@ -69,7 +71,7 @@ impl Agent { while let Some((req_id, confirmation)) = rx.recv().await { if req_id == request.id { if confirmation.permission == Permission::AllowOnce || confirmation.permission == Permission::AlwaysAllow { - let (req_id, tool_result) = self.dispatch_tool_call(tool_call.clone(), request.id.clone()).await; + let (req_id, tool_result) = self.dispatch_tool_call(tool_call.clone(), request.id.clone(), cancellation_token.clone()).await; let mut futures = tool_futures.lock().await; futures.push((req_id, match tool_result { From 1cae70c6972a77ec72f537f1b78b6632181674f0 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Wed, 23 Jul 2025 22:06:39 +1000 Subject: [PATCH 5/5] fixed test compilation --- crates/goose/tests/agent.rs | 2 +- crates/goose/tests/private_tests.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/goose/tests/agent.rs b/crates/goose/tests/agent.rs index ab8b8cb155d8..497ebcaab715 100644 --- a/crates/goose/tests/agent.rs +++ b/crates/goose/tests/agent.rs @@ -603,7 +603,7 @@ mod final_output_tool_tests { }), ); let (_, result) = agent - .dispatch_tool_call(tool_call, "request_id".to_string()) + .dispatch_tool_call(tool_call, "request_id".to_string(), None) .await; assert!(result.is_ok(), "Tool call should succeed"); diff --git a/crates/goose/tests/private_tests.rs b/crates/goose/tests/private_tests.rs index d2ec7a06e8ae..e23d0c09e319 100644 --- a/crates/goose/tests/private_tests.rs +++ b/crates/goose/tests/private_tests.rs @@ -885,7 +885,7 @@ async fn test_schedule_tool_dispatch() { }; let (request_id, result) = agent - .dispatch_tool_call(tool_call, "test_dispatch".to_string()) + .dispatch_tool_call(tool_call, "test_dispatch".to_string(), None) .await; assert_eq!(request_id, "test_dispatch"); assert!(result.is_ok());