diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 1635a65d919e..e79f954f6d08 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -1173,6 +1173,8 @@ impl Session { } } _ = tokio::signal::ctrl_c() => { + self.agent.cancel_all_subagent_executions().await; + drop(stream); if let Err(e) = self.handle_interrupted_messages(true).await { eprintln!("Error handling interruption: {}", e); diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 1fbc70e8492a..0dde0a9d134c 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -219,6 +219,10 @@ impl Agent { *scheduler_service = Some(scheduler); } + pub async fn cancel_all_subagent_executions(&self) { + self.tasks_manager.cancel_all_executions().await; + } + /// Get a reference count clone to the provider pub async fn provider(&self) -> Result, anyhow::Error> { match &*self.provider.lock().await { diff --git a/crates/goose/src/agents/subagent_execution_tool/executor/mod.rs b/crates/goose/src/agents/subagent_execution_tool/executor/mod.rs index 9a71ad4ad7f9..03e5c5a4c41b 100644 --- a/crates/goose/src/agents/subagent_execution_tool/executor/mod.rs +++ b/crates/goose/src/agents/subagent_execution_tool/executor/mod.rs @@ -16,10 +16,20 @@ use tokio::time::Instant; const EXECUTION_STATUS_COMPLETED: &str = "completed"; const DEFAULT_MAX_WORKERS: usize = 10; +/// Sets up cancellation handling for a task execution tracker +async fn setup_cancellation_handling( + cancellation_token: tokio_util::sync::CancellationToken, + task_execution_tracker: Arc, +) { + cancellation_token.cancelled().await; + task_execution_tracker.mark_cancelled(); +} + pub async fn execute_single_task( task: &Task, notifier: mpsc::Sender, task_config: TaskConfig, + cancellation_token: tokio_util::sync::CancellationToken, ) -> ExecutionResponse { let start_time = Instant::now(); let task_execution_tracker = Arc::new(TaskExecutionTracker::new( @@ -27,9 +37,22 @@ pub async fn execute_single_task( DisplayMode::SingleTaskOutput, notifier, )); - let result = process_task(task, task_execution_tracker.clone(), task_config).await; - // Complete the task in the tracker + let cancellation_future = + setup_cancellation_handling(cancellation_token, task_execution_tracker.clone()); + + let result = tokio::select! { + result = process_task(task, task_execution_tracker.clone(), task_config) => result, + _ = cancellation_future => { + crate::agents::subagent_execution_tool::task_types::TaskResult { + task_id: task.id.clone(), + status: crate::agents::subagent_execution_tool::task_types::TaskStatus::Failed, + data: None, + error: Some("Task execution cancelled".to_string()), + } + } + }; + task_execution_tracker .complete_task(&result.task_id, result.clone()) .await; @@ -48,12 +71,19 @@ pub async fn execute_tasks_in_parallel( tasks: Vec, notifier: mpsc::Sender, task_config: TaskConfig, + cancellation_token: tokio_util::sync::CancellationToken, ) -> ExecutionResponse { let task_execution_tracker = Arc::new(TaskExecutionTracker::new( tasks.clone(), DisplayMode::MultipleTasksOutput, notifier, )); + + tokio::spawn(setup_cancellation_handling( + cancellation_token, + task_execution_tracker.clone(), + )); + let start_time = Instant::now(); let task_count = tasks.len(); diff --git a/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs b/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs index faa2bcd5a578..8144b754f135 100644 --- a/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs +++ b/crates/goose/src/agents/subagent_execution_tool/lib/mod.rs @@ -16,6 +16,7 @@ pub async fn execute_tasks( notifier: mpsc::Sender, task_config: TaskConfig, tasks_manager: &TasksManager, + cancellation_token: tokio_util::sync::CancellationToken, ) -> Result { let task_ids: Vec = serde_json::from_value( input @@ -31,7 +32,8 @@ pub async fn execute_tasks( match execution_mode { ExecutionMode::Sequential => { if task_count == 1 { - let response = execute_single_task(&tasks[0], notifier, task_config).await; + let response = + execute_single_task(&tasks[0], notifier, task_config, cancellation_token).await; handle_response(response) } else { Err("Sequential execution mode requires exactly one task".to_string()) @@ -47,8 +49,13 @@ pub async fn execute_tasks( } )) } else { - let response: ExecutionResponse = - execute_tasks_in_parallel(tasks, notifier.clone(), task_config).await; + let response: ExecutionResponse = execute_tasks_in_parallel( + tasks, + notifier.clone(), + task_config, + cancellation_token, + ) + .await; handle_response(response) } } diff --git a/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs b/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs index f3860253eb78..d7d0910f3212 100644 --- a/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs +++ b/crates/goose/src/agents/subagent_execution_tool/subagent_execute_task_tool.rs @@ -11,6 +11,7 @@ use crate::agents::{ use mcp_core::protocol::JsonRpcMessage; use tokio::sync::mpsc; use tokio_stream; +use tokio_util::sync::CancellationToken; pub const SUBAGENT_EXECUTE_TASK_TOOL_NAME: &str = "subagent__execute_task"; pub fn create_subagent_execute_task_tool() -> Tool { @@ -66,29 +67,41 @@ pub async fn run_tasks( tasks_manager: &TasksManager, ) -> ToolCallResult { let (notification_tx, notification_rx) = mpsc::channel::(100); + let cancellation_token = CancellationToken::new(); + + // Register the execution with the tasks manager + tasks_manager + .register_execution(cancellation_token.clone()) + .await; let tasks_manager_clone = tasks_manager.clone(); + let cancellation_token_clone = cancellation_token.clone(); let result_future = async move { - let execute_data_clone = execute_data.clone(); - let execution_mode = execute_data_clone + let execution_mode = execute_data .get("execution_mode") .and_then(|v| serde_json::from_value::(v.clone()).ok()) .unwrap_or_default(); - match execute_tasks( - execute_data, - execution_mode, - notification_tx, - task_config, - &tasks_manager_clone, - ) - .await - { - Ok(result) => { - let output = serde_json::to_string(&result).unwrap(); - Ok(vec![Content::text(output)]) + tokio::select! { + result = execute_tasks( + execute_data, + execution_mode, + notification_tx, + task_config, + &tasks_manager_clone, + cancellation_token_clone.clone(), + ) => { + match result { + Ok(result) => { + let output = serde_json::to_string(&result).unwrap(); + Ok(vec![Content::text(output)]) + } + Err(e) => Err(ToolError::ExecutionError(e.to_string())), + } + } + _ = cancellation_token_clone.cancelled() => { + Err(ToolError::ExecutionError("Task execution cancelled".to_string())) } - Err(e) => Err(ToolError::ExecutionError(e.to_string())), } }; diff --git a/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs b/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs index c720459e01ae..70194a2385c0 100644 --- a/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs +++ b/crates/goose/src/agents/subagent_execution_tool/task_execution_tracker.rs @@ -1,6 +1,7 @@ use mcp_core::protocol::{JsonRpcMessage, JsonRpcNotification}; use serde_json::json; use std::collections::HashMap; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use tokio::sync::{mpsc, RwLock}; use tokio::time::{sleep, Duration, Instant}; @@ -61,6 +62,7 @@ pub struct TaskExecutionTracker { last_refresh: Arc>, notifier: mpsc::Sender, display_mode: DisplayMode, + is_cancelled: Arc, } impl TaskExecutionTracker { @@ -92,6 +94,7 @@ impl TaskExecutionTracker { last_refresh: Arc::new(RwLock::new(Instant::now())), notifier, display_mode, + is_cancelled: Arc::new(AtomicBool::new(false)), } } @@ -162,7 +165,9 @@ impl TaskExecutionTracker { })), })) { - tracing::warn!("Failed to send live output notification: {}", e); + if !self.should_suppress_error(&e) { + tracing::warn!("Failed to send live output notification: {}", e); + } } } DisplayMode::MultipleTasksOutput => { @@ -235,7 +240,9 @@ impl TaskExecutionTracker { })), })) { - tracing::warn!("Failed to send tasks update notification: {}", e); + if !self.should_suppress_error(&e) { + tracing::warn!("Failed to send tasks update notification: {}", e); + } } } @@ -296,10 +303,36 @@ impl TaskExecutionTracker { })), })) { - tracing::warn!("Failed to send tasks complete notification: {}", e); + if !self.should_suppress_error(&e) { + tracing::warn!("Failed to send tasks complete notification: {}", e); + } } // Brief delay to ensure completion notification is processed - sleep(Duration::from_millis(COMPLETION_NOTIFICATION_DELAY_MS)).await; + if !self.is_cancelled() { + sleep(Duration::from_millis(COMPLETION_NOTIFICATION_DELAY_MS)).await; + } + } + + fn is_channel_closed( + &self, + error: &tokio::sync::mpsc::error::TrySendError, + ) -> bool { + matches!(error, tokio::sync::mpsc::error::TrySendError::Closed(_)) + } + + fn should_suppress_error( + &self, + error: &tokio::sync::mpsc::error::TrySendError, + ) -> bool { + self.is_cancelled() && self.is_channel_closed(error) + } + + pub fn is_cancelled(&self) -> bool { + self.is_cancelled.load(Ordering::SeqCst) + } + + pub fn mark_cancelled(&self) { + self.is_cancelled.store(true, Ordering::SeqCst); } } diff --git a/crates/goose/src/agents/subagent_execution_tool/tasks_manager.rs b/crates/goose/src/agents/subagent_execution_tool/tasks_manager.rs index 4864994b7a0d..f43884cc9cd3 100644 --- a/crates/goose/src/agents/subagent_execution_tool/tasks_manager.rs +++ b/crates/goose/src/agents/subagent_execution_tool/tasks_manager.rs @@ -2,12 +2,14 @@ use anyhow::Result; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::RwLock; +use tokio_util::sync::CancellationToken; use crate::agents::subagent_execution_tool::task_types::Task; #[derive(Debug, Clone)] pub struct TasksManager { tasks: Arc>>, + active_tokens: Arc>>, } impl Default for TasksManager { @@ -20,6 +22,7 @@ impl TasksManager { pub fn new() -> Self { Self { tasks: Arc::new(RwLock::new(HashMap::new())), + active_tokens: Arc::new(RwLock::new(Vec::new())), } } @@ -50,6 +53,22 @@ impl TasksManager { } Ok(tasks) } + + pub async fn register_execution(&self, cancellation_token: CancellationToken) { + let mut tokens = self.active_tokens.write().await; + tokens.retain(|token| !token.is_cancelled()); + tokens.push(cancellation_token); + } + + pub async fn cancel_all_executions(&self) { + let mut tokens = self.active_tokens.write().await; + + for token in tokens.iter() { + token.cancel(); + } + + tokens.clear(); + } } #[cfg(test)] @@ -100,4 +119,43 @@ mod tests { assert_eq!(task1.unwrap().id, "task1"); assert_eq!(task2.unwrap().id, "task2"); } + + #[tokio::test] + async fn test_cancellation_token_tracking() { + let manager = TasksManager::new(); + + let token1 = CancellationToken::new(); + let token2 = CancellationToken::new(); + + manager.register_execution(token1.clone()).await; + manager.register_execution(token2.clone()).await; + + assert!(!token1.is_cancelled()); + assert!(!token2.is_cancelled()); + + manager.cancel_all_executions().await; + + assert!(token1.is_cancelled()); + assert!(token2.is_cancelled()); + } + + #[tokio::test] + async fn test_automatic_cleanup_on_register() { + let manager = TasksManager::new(); + + let token1 = CancellationToken::new(); + let token2 = CancellationToken::new(); + + manager.register_execution(token1.clone()).await; + manager.register_execution(token2.clone()).await; + + token1.cancel(); + + let token3 = CancellationToken::new(); + manager.register_execution(token3.clone()).await; + + let tokens = manager.active_tokens.read().await; + assert_eq!(tokens.len(), 2); + assert!(!tokens.iter().any(|t| t.is_cancelled())); + } } diff --git a/crates/goose/src/agents/subagent_execution_tool/workers.rs b/crates/goose/src/agents/subagent_execution_tool/workers.rs index 4ae0ab250737..f3597d871811 100644 --- a/crates/goose/src/agents/subagent_execution_tool/workers.rs +++ b/crates/goose/src/agents/subagent_execution_tool/workers.rs @@ -31,7 +31,9 @@ async fn worker_loop(state: Arc, _worker_id: usize, task_config: Ta .await; if let Err(e) = state.result_sender.send(result).await { - tracing::error!("Worker failed to send result: {}", e); + if !state.task_execution_tracker.is_cancelled() { + tracing::error!("Worker failed to send result: {}", e); + } break; } }