Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/goose-cli/src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1173,6 +1173,8 @@ impl Session {
}
}
_ = tokio::signal::ctrl_c() => {
self.agent.cancel_all_subagent_executions().await;

drop(stream);
if let Err(e) = self.handle_interrupted_messages(true).await {
eprintln!("Error handling interruption: {}", e);
Expand Down
4 changes: 4 additions & 0 deletions crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,10 @@ impl Agent {
*scheduler_service = Some(scheduler);
}

pub async fn cancel_all_subagent_executions(&self) {
self.tasks_manager.cancel_all_executions().await;
}

/// Get a reference count clone to the provider
pub async fn provider(&self) -> Result<Arc<dyn Provider>, anyhow::Error> {
match &*self.provider.lock().await {
Expand Down
34 changes: 32 additions & 2 deletions crates/goose/src/agents/subagent_execution_tool/executor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,43 @@ use tokio::time::Instant;
const EXECUTION_STATUS_COMPLETED: &str = "completed";
const DEFAULT_MAX_WORKERS: usize = 10;

/// Sets up cancellation handling for a task execution tracker
async fn setup_cancellation_handling(
cancellation_token: tokio_util::sync::CancellationToken,
task_execution_tracker: Arc<TaskExecutionTracker>,
) {
cancellation_token.cancelled().await;
task_execution_tracker.mark_cancelled();
}

pub async fn execute_single_task(
task: &Task,
notifier: mpsc::Sender<JsonRpcMessage>,
task_config: TaskConfig,
cancellation_token: tokio_util::sync::CancellationToken,
) -> ExecutionResponse {
let start_time = Instant::now();
let task_execution_tracker = Arc::new(TaskExecutionTracker::new(
vec![task.clone()],
DisplayMode::SingleTaskOutput,
notifier,
));
let result = process_task(task, task_execution_tracker.clone(), task_config).await;

// Complete the task in the tracker
let cancellation_future =
setup_cancellation_handling(cancellation_token, task_execution_tracker.clone());

let result = tokio::select! {
result = process_task(task, task_execution_tracker.clone(), task_config) => result,
_ = cancellation_future => {
crate::agents::subagent_execution_tool::task_types::TaskResult {
task_id: task.id.clone(),
status: crate::agents::subagent_execution_tool::task_types::TaskStatus::Failed,
data: None,
error: Some("Task execution cancelled".to_string()),
}
}
};

task_execution_tracker
.complete_task(&result.task_id, result.clone())
.await;
Expand All @@ -48,12 +71,19 @@ pub async fn execute_tasks_in_parallel(
tasks: Vec<Task>,
notifier: mpsc::Sender<JsonRpcMessage>,
task_config: TaskConfig,
cancellation_token: tokio_util::sync::CancellationToken,
) -> ExecutionResponse {
let task_execution_tracker = Arc::new(TaskExecutionTracker::new(
tasks.clone(),
DisplayMode::MultipleTasksOutput,
notifier,
));

tokio::spawn(setup_cancellation_handling(
cancellation_token,
task_execution_tracker.clone(),
));

let start_time = Instant::now();
let task_count = tasks.len();

Expand Down
13 changes: 10 additions & 3 deletions crates/goose/src/agents/subagent_execution_tool/lib/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub async fn execute_tasks(
notifier: mpsc::Sender<JsonRpcMessage>,
task_config: TaskConfig,
tasks_manager: &TasksManager,
cancellation_token: tokio_util::sync::CancellationToken,
) -> Result<Value, String> {
let task_ids: Vec<String> = serde_json::from_value(
input
Expand All @@ -31,7 +32,8 @@ pub async fn execute_tasks(
match execution_mode {
ExecutionMode::Sequential => {
if task_count == 1 {
let response = execute_single_task(&tasks[0], notifier, task_config).await;
let response =
execute_single_task(&tasks[0], notifier, task_config, cancellation_token).await;
handle_response(response)
} else {
Err("Sequential execution mode requires exactly one task".to_string())
Expand All @@ -47,8 +49,13 @@ pub async fn execute_tasks(
}
))
} else {
let response: ExecutionResponse =
execute_tasks_in_parallel(tasks, notifier.clone(), task_config).await;
let response: ExecutionResponse = execute_tasks_in_parallel(
tasks,
notifier.clone(),
task_config,
cancellation_token,
)
.await;
handle_response(response)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use crate::agents::{
use mcp_core::protocol::JsonRpcMessage;
use tokio::sync::mpsc;
use tokio_stream;
use tokio_util::sync::CancellationToken;

pub const SUBAGENT_EXECUTE_TASK_TOOL_NAME: &str = "subagent__execute_task";
pub fn create_subagent_execute_task_tool() -> Tool {
Expand Down Expand Up @@ -66,29 +67,41 @@ pub async fn run_tasks(
tasks_manager: &TasksManager,
) -> ToolCallResult {
let (notification_tx, notification_rx) = mpsc::channel::<JsonRpcMessage>(100);
let cancellation_token = CancellationToken::new();

// Register the execution with the tasks manager
tasks_manager
.register_execution(cancellation_token.clone())
.await;

let tasks_manager_clone = tasks_manager.clone();
let cancellation_token_clone = cancellation_token.clone();
let result_future = async move {
let execute_data_clone = execute_data.clone();
let execution_mode = execute_data_clone
let execution_mode = execute_data
.get("execution_mode")
.and_then(|v| serde_json::from_value::<ExecutionMode>(v.clone()).ok())
.unwrap_or_default();

match execute_tasks(
execute_data,
execution_mode,
notification_tx,
task_config,
&tasks_manager_clone,
)
.await
{
Ok(result) => {
let output = serde_json::to_string(&result).unwrap();
Ok(vec![Content::text(output)])
tokio::select! {
result = execute_tasks(
execute_data,
execution_mode,
notification_tx,
task_config,
&tasks_manager_clone,
cancellation_token_clone.clone(),
) => {
match result {
Ok(result) => {
let output = serde_json::to_string(&result).unwrap();
Ok(vec![Content::text(output)])
}
Err(e) => Err(ToolError::ExecutionError(e.to_string())),
}
}
_ = cancellation_token_clone.cancelled() => {
Err(ToolError::ExecutionError("Task execution cancelled".to_string()))
}
Err(e) => Err(ToolError::ExecutionError(e.to_string())),
}
};

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use mcp_core::protocol::{JsonRpcMessage, JsonRpcNotification};
use serde_json::json;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::{mpsc, RwLock};
use tokio::time::{sleep, Duration, Instant};
Expand Down Expand Up @@ -61,6 +62,7 @@ pub struct TaskExecutionTracker {
last_refresh: Arc<RwLock<Instant>>,
notifier: mpsc::Sender<JsonRpcMessage>,
display_mode: DisplayMode,
is_cancelled: Arc<AtomicBool>,
}

impl TaskExecutionTracker {
Expand Down Expand Up @@ -92,6 +94,7 @@ impl TaskExecutionTracker {
last_refresh: Arc::new(RwLock::new(Instant::now())),
notifier,
display_mode,
is_cancelled: Arc::new(AtomicBool::new(false)),
}
}

Expand Down Expand Up @@ -162,7 +165,9 @@ impl TaskExecutionTracker {
})),
}))
{
tracing::warn!("Failed to send live output notification: {}", e);
if !self.should_suppress_error(&e) {
tracing::warn!("Failed to send live output notification: {}", e);
}
}
}
DisplayMode::MultipleTasksOutput => {
Expand Down Expand Up @@ -235,7 +240,9 @@ impl TaskExecutionTracker {
})),
}))
{
tracing::warn!("Failed to send tasks update notification: {}", e);
if !self.should_suppress_error(&e) {
tracing::warn!("Failed to send tasks update notification: {}", e);
}
}
}

Expand Down Expand Up @@ -296,10 +303,36 @@ impl TaskExecutionTracker {
})),
}))
{
tracing::warn!("Failed to send tasks complete notification: {}", e);
if !self.should_suppress_error(&e) {
tracing::warn!("Failed to send tasks complete notification: {}", e);
}
}

// Brief delay to ensure completion notification is processed
sleep(Duration::from_millis(COMPLETION_NOTIFICATION_DELAY_MS)).await;
if !self.is_cancelled() {
sleep(Duration::from_millis(COMPLETION_NOTIFICATION_DELAY_MS)).await;
}
}

fn is_channel_closed(
&self,
error: &tokio::sync::mpsc::error::TrySendError<JsonRpcMessage>,
) -> bool {
matches!(error, tokio::sync::mpsc::error::TrySendError::Closed(_))
}

fn should_suppress_error(
&self,
error: &tokio::sync::mpsc::error::TrySendError<JsonRpcMessage>,
) -> bool {
self.is_cancelled() && self.is_channel_closed(error)
}

pub fn is_cancelled(&self) -> bool {
self.is_cancelled.load(Ordering::SeqCst)
}

pub fn mark_cancelled(&self) {
self.is_cancelled.store(true, Ordering::SeqCst);
}
}
58 changes: 58 additions & 0 deletions crates/goose/src/agents/subagent_execution_tool/tasks_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ use anyhow::Result;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;

use crate::agents::subagent_execution_tool::task_types::Task;

#[derive(Debug, Clone)]
pub struct TasksManager {
tasks: Arc<RwLock<HashMap<String, Task>>>,
active_tokens: Arc<RwLock<Vec<CancellationToken>>>,
}

impl Default for TasksManager {
Expand All @@ -20,6 +22,7 @@ impl TasksManager {
pub fn new() -> Self {
Self {
tasks: Arc::new(RwLock::new(HashMap::new())),
active_tokens: Arc::new(RwLock::new(Vec::new())),
}
}

Expand Down Expand Up @@ -50,6 +53,22 @@ impl TasksManager {
}
Ok(tasks)
}

pub async fn register_execution(&self, cancellation_token: CancellationToken) {
let mut tokens = self.active_tokens.write().await;
tokens.retain(|token| !token.is_cancelled());
Copy link

Copilot AI Jul 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The cleanup of cancelled tokens happens on every registration. For high-frequency registrations, this could become inefficient. Consider implementing a periodic cleanup strategy or cleanup threshold to avoid O(n) operations on every registration.

Copilot uses AI. Check for mistakes.
Copy link
Collaborator Author

@lifeizhou-ap lifeizhou-ap Jul 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is not high-frequency registrations

tokens.push(cancellation_token);
}

pub async fn cancel_all_executions(&self) {
let mut tokens = self.active_tokens.write().await;

for token in tokens.iter() {
token.cancel();
}

tokens.clear();
}
}

#[cfg(test)]
Expand Down Expand Up @@ -100,4 +119,43 @@ mod tests {
assert_eq!(task1.unwrap().id, "task1");
assert_eq!(task2.unwrap().id, "task2");
}

#[tokio::test]
async fn test_cancellation_token_tracking() {
let manager = TasksManager::new();

let token1 = CancellationToken::new();
let token2 = CancellationToken::new();

manager.register_execution(token1.clone()).await;
manager.register_execution(token2.clone()).await;

assert!(!token1.is_cancelled());
assert!(!token2.is_cancelled());

manager.cancel_all_executions().await;

assert!(token1.is_cancelled());
assert!(token2.is_cancelled());
}

#[tokio::test]
async fn test_automatic_cleanup_on_register() {
let manager = TasksManager::new();

let token1 = CancellationToken::new();
let token2 = CancellationToken::new();

manager.register_execution(token1.clone()).await;
manager.register_execution(token2.clone()).await;

token1.cancel();

let token3 = CancellationToken::new();
manager.register_execution(token3.clone()).await;

let tokens = manager.active_tokens.read().await;
assert_eq!(tokens.len(), 2);
assert!(!tokens.iter().any(|t| t.is_cancelled()));
}
}
4 changes: 3 additions & 1 deletion crates/goose/src/agents/subagent_execution_tool/workers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ async fn worker_loop(state: Arc<SharedState>, _worker_id: usize, task_config: Ta
.await;

if let Err(e) = state.result_sender.send(result).await {
tracing::error!("Worker failed to send result: {}", e);
if !state.task_execution_tracker.is_cancelled() {
tracing::error!("Worker failed to send result: {}", e);
}
break;
}
}
Expand Down
Loading