diff --git a/Cargo.lock b/Cargo.lock index 9e0bfa5c8889..22173c044526 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2824,6 +2824,7 @@ dependencies = [ "indoc", "keyring", "lazy_static", + "libc", "lopdf", "lru", "mcp-core", @@ -2848,6 +2849,7 @@ dependencies = [ "thiserror 1.0.69", "tokio", "tokio-stream", + "tokio-util", "tracing", "tracing-appender", "tracing-subscriber", @@ -6790,9 +6792,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.15" +version = "0.7.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "66a539a9ad6d5d281510d5bd368c973d636c02dbf8a67300bfb6b950696ad7df" +checksum = "14307c986784f72ef81c89db7d9e28d6ac26d16213b109ea501696195e6e3ce5" dependencies = [ "bytes", "futures-core", diff --git a/crates/goose-mcp/Cargo.toml b/crates/goose-mcp/Cargo.toml index 9480e73c82b2..386e8cacf644 100644 --- a/crates/goose-mcp/Cargo.toml +++ b/crates/goose-mcp/Cargo.toml @@ -71,10 +71,12 @@ tree-sitter-kotlin = "0.3.8" devgen-tree-sitter-swift = "0.21.0" streaming-iterator = "0.1" rayon = "1.10" +libc = "0.2" # TODO: Fork mpatch or replace with a custom implementation using `similar` crate # for fuzzy patch matching. Current crate has limited maintenance (single maintainer, # ~1000 downloads). Pinned to exact version to prevent supply chain attacks. mpatch = "=0.2.0" +tokio-util = "0.7.16" [dev-dependencies] diff --git a/crates/goose-mcp/src/developer/rmcp_developer.rs b/crates/goose-mcp/src/developer/rmcp_developer.rs index 33c2fe491558..fe94b8766130 100644 --- a/crates/goose-mcp/src/developer/rmcp_developer.rs +++ b/crates/goose-mcp/src/developer/rmcp_developer.rs @@ -5,13 +5,13 @@ use indoc::{formatdoc, indoc}; use rmcp::{ handler::server::{router::tool::ToolRouter, wrapper::Parameters}, model::{ - CallToolResult, Content, ErrorCode, ErrorData, GetPromptRequestParam, GetPromptResult, - Implementation, ListPromptsResult, LoggingLevel, LoggingMessageNotificationParam, - PaginatedRequestParam, Prompt, PromptArgument, PromptMessage, PromptMessageRole, Role, - ServerCapabilities, ServerInfo, + CallToolResult, CancelledNotificationParam, Content, ErrorCode, ErrorData, + GetPromptRequestParam, GetPromptResult, Implementation, ListPromptsResult, LoggingLevel, + LoggingMessageNotificationParam, PaginatedRequestParam, Prompt, PromptArgument, + PromptMessage, PromptMessageRole, Role, ServerCapabilities, ServerInfo, }, schemars::JsonSchema, - service::RequestContext, + service::{NotificationContext, RequestContext}, tool, tool_handler, tool_router, RoleServer, ServerHandler, }; use serde::{Deserialize, Serialize}; @@ -20,21 +20,23 @@ use std::{ future::Future, io::Cursor, path::{Path, PathBuf}, - process::Stdio, sync::{Arc, Mutex}, }; use xcap::{Monitor, Window}; use tokio::{ io::{AsyncBufReadExt, BufReader}, - process::Command, + sync::RwLock, }; use tokio_stream::{wrappers::SplitStream, StreamExt as _}; +use tokio_util::sync::CancellationToken; use super::analyze::{types::AnalyzeParams, CodeAnalyzer}; use super::editor_models::{create_editor_model, EditorModel}; use super::goose_hints::load_hints::{load_hint_files, GOOSE_HINTS_FILENAME}; -use super::shell::{expand_path, get_shell_config, is_absolute_path}; +use super::shell::{ + configure_shell_command, expand_path, get_shell_config, is_absolute_path, kill_process_group, +}; use super::text_editor::{ text_editor_insert, text_editor_replace, text_editor_undo, text_editor_view, text_editor_write, }; @@ -173,6 +175,10 @@ pub struct DeveloperServer { editor_model: Option, prompts: HashMap, code_analyzer: CodeAnalyzer, + #[cfg(test)] + pub running_processes: Arc>>, + #[cfg(not(test))] + running_processes: Arc>>, } #[tool_handler(router = self.tool_router)] @@ -503,6 +509,27 @@ impl ServerHandler for DeveloperServer { ))), } } + + /// Called when the client cancels a specific request. + /// This method cancels the running process associated with the given request_id. + #[allow(clippy::manual_async_fn)] + fn on_cancelled( + &self, + notification: CancelledNotificationParam, + _context: NotificationContext, + ) -> impl Future + Send + '_ { + async move { + let request_id = notification.request_id.to_string(); + let processes = self.running_processes.read().await; + + if let Some(token) = processes.get(&request_id) { + token.cancel(); + tracing::debug!("Found process for request {}, cancelling token", request_id); + } else { + tracing::warn!("No process found for request ID: {}", request_id); + } + } + } } impl Default for DeveloperServer { @@ -528,6 +555,7 @@ impl DeveloperServer { editor_model, prompts: load_prompt_files(), code_analyzer: CodeAnalyzer::new(), + running_processes: Arc::new(RwLock::new(HashMap::new())), } } @@ -811,7 +839,7 @@ impl DeveloperServer { /// this tool does not run indefinitely. #[tool( name = "shell", - description = "Execute a command in the shell. Returns output and error concatenated. Avoid commands with large output, use background commands for long-running processes." + description = "Execute a command in the shell.This will return the output and error concatenated into a single string, as you would see from running on the command line. There will also be an indication of if the command succeeded or failed. Avoid commands that produce a large amount of output, and consider piping those outputs to files. If you need to run a long lived command, background it - e.g. `uvicorn main:app &` so that this tool does not run indefinitely." )] pub async fn shell( &self, @@ -821,12 +849,38 @@ impl DeveloperServer { let params = params.0; let command = ¶ms.command; let peer = context.peer; + let request_id = context.id; // Validate the shell command self.validate_shell_command(command)?; + let cancellation_token = CancellationToken::new(); + // Track the process using the request ID + { + let mut processes = self.running_processes.write().await; + let request_id_str = request_id.to_string(); + processes.insert(request_id_str.clone(), cancellation_token.clone()); + } + // Execute the command and capture output - let output_str = self.execute_shell_command(command, &peer).await?; + let output_result = self + .execute_shell_command(command, &peer, cancellation_token.clone()) + .await; + + // Clean up the process from tracking + { + let mut processes = self.running_processes.write().await; + let request_id_str = request_id.to_string(); + let was_present = processes.remove(&request_id_str).is_some(); + if !was_present { + tracing::warn!( + "Process for request_id {} was not in tracking map when trying to remove", + request_id + ); + } + } + + let output_str = output_result?; // Validate output size self.validate_shell_output_size(command, &output_str)?; @@ -893,38 +947,55 @@ impl DeveloperServer { &self, command: &str, peer: &rmcp::service::Peer, + cancellation_token: CancellationToken, ) -> Result { // Get platform-specific shell configuration let shell_config = get_shell_config(); - // Execute the command using platform-specific shell - let mut child = Command::new(&shell_config.executable) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .stdin(Stdio::null()) - .kill_on_drop(true) - .env("GOOSE_TERMINAL", "1") - .args(&shell_config.args) - .arg(command) + let mut child = configure_shell_command(&shell_config, command) .spawn() .map_err(|e| ErrorData::new(ErrorCode::INTERNAL_ERROR, e.to_string(), None))?; - // Stream the output - let output_str = self - .stream_shell_output( - child.stdout.take().unwrap(), - child.stderr.take().unwrap(), - peer.clone(), - ) - .await?; + let pid = child.id(); + if let Some(pid) = pid { + tracing::debug!("Shell process spawned with PID: {}", pid); + } else { + tracing::warn!("Shell process spawned but PID not available"); + } - // Wait for the command to complete - child - .wait() - .await - .map_err(|e| ErrorData::new(ErrorCode::INTERNAL_ERROR, e.to_string(), None))?; + // Stream the output and wait for completion with cancellation support + let output_task = self.stream_shell_output( + child.stdout.take().unwrap(), + child.stderr.take().unwrap(), + peer.clone(), + ); + + tokio::select! { + output_result = output_task => { + // Wait for the process to complete + let _exit_status = child.wait().await.map_err(|e| ErrorData::new(ErrorCode::INTERNAL_ERROR, e.to_string(), None))?; + output_result + } + _ = cancellation_token.cancelled() => { + tracing::info!("Cancellation token triggered! Attempting to kill process and all child processes"); + + // Kill the process and its children using platform-specific approach + match kill_process_group(&mut child, pid).await { + Ok(_) => { + tracing::debug!("Successfully killed shell process and child processes"); + } + Err(e) => { + tracing::error!("Failed to kill shell process and child processes: {}", e); + } + } - Ok(output_str) + Err(ErrorData::new( + ErrorCode::INTERNAL_ERROR, + "Shell command was cancelled by user".to_string(), + None, + )) + } + } } /// Stream shell output in real-time and return the combined output. @@ -1336,11 +1407,16 @@ impl DeveloperServer { mod tests { use super::*; use rmcp::handler::server::wrapper::Parameters; - use rmcp::model::NumberOrString; - use rmcp::service::serve_directly; + use rmcp::model::{CancelledNotificationParam, NumberOrString}; + use rmcp::service::{serve_directly, NotificationContext}; + use rmcp::ServerHandler; use serial_test::serial; - use std::fs; + use std::{ + fs, + time::{Duration, Instant}, + }; use tempfile::TempDir; + use tokio::time::timeout; fn create_test_server() -> DeveloperServer { DeveloperServer::new() @@ -3310,9 +3386,9 @@ Additional instructions here. std::env::remove_var("CONTEXT_FILE_NAMES"); } - #[tokio::test] + #[test] #[serial] - async fn test_resolve_path_absolute() { + fn test_resolve_path_absolute() { let temp_dir = tempfile::tempdir().unwrap(); std::env::set_current_dir(&temp_dir).unwrap(); @@ -3393,4 +3469,237 @@ Additional instructions here. let content = fs::read_to_string(&absolute_path).unwrap(); assert_eq!(content.trim(), "Relative path test"); } + + #[test] + #[serial] + #[cfg(unix)] // Unix-specific test using sleep command + fn test_shell_command_cancellation() { + run_shell_test(|| async { + let server = create_test_server(); + let running_service = serve_directly(server.clone(), create_test_transport(), None); + let peer = running_service.peer().clone(); + + let request_id = NumberOrString::Number(123); + + let context = RequestContext { + ct: Default::default(), + id: request_id.clone(), + meta: Default::default(), + extensions: Default::default(), + peer: peer.clone(), + }; + + // Start a long-running shell command in the background + let server_clone = server.clone(); + let shell_task = tokio::spawn(async move { + server_clone + .shell( + Parameters(ShellParams { + command: "sleep 30".to_string(), + }), + context, + ) + .await + }); + + // Give the command a moment to start + tokio::time::sleep(Duration::from_millis(200)).await; + + // Verify the process is tracked + { + let processes = server.running_processes.read().await; + assert!(processes.contains_key("123"), "Process should be tracked"); + } + + let start_time = Instant::now(); + + // Cancel the command + let cancel_params = CancelledNotificationParam { + request_id: request_id, + reason: Some("test cancellation".to_string()), + }; + + let notification_context = NotificationContext { + peer: peer.clone(), + meta: Default::default(), + extensions: Default::default(), + }; + + server + .on_cancelled(cancel_params, notification_context) + .await; + + // Wait for the shell task to complete + let result = timeout(Duration::from_secs(5), shell_task).await; + let elapsed = start_time.elapsed(); + + // Verify the task completed due to cancellation (not timeout) + assert!(result.is_ok(), "Shell task should complete within timeout"); + let task_result = result.unwrap(); + assert!(task_result.is_ok(), "Shell task should not panic"); + + // Verify the command was cancelled quickly (much less than 30 seconds) + assert!( + elapsed < Duration::from_secs(5), + "Command should be cancelled quickly, took {:?}", + elapsed + ); + + // Verify the process is no longer tracked + { + let processes = server.running_processes.read().await; + assert!( + !processes.contains_key("123"), + "Process should be removed from tracking" + ); + } + + cleanup_test_service(running_service, peer); + }); + } + + #[test] + #[serial] + #[cfg(unix)] // Unix-specific test using shell commands + fn test_child_process_cancellation() { + run_shell_test(|| async { + let server = create_test_server(); + let running_service = serve_directly(server.clone(), create_test_transport(), None); + let peer = running_service.peer().clone(); + + let request_id = NumberOrString::Number(456); + + let context = RequestContext { + ct: Default::default(), + id: request_id.clone(), + meta: Default::default(), + extensions: Default::default(), + peer: peer.clone(), + }; + + // Start a command that spawns child processes + let server_clone = server.clone(); + let shell_task = tokio::spawn(async move { + server_clone + .shell( + Parameters(ShellParams { + command: "bash -c 'sleep 60 & wait'".to_string(), + }), + context, + ) + .await + }); + + // Give the command time to start and spawn child processes + tokio::time::sleep(Duration::from_millis(300)).await; + + let start_time = Instant::now(); + + // Cancel the command + let cancel_params = CancelledNotificationParam { + request_id: request_id, + reason: Some("test cancellation".to_string()), + }; + + let notification_context = NotificationContext { + peer: peer.clone(), + meta: Default::default(), + extensions: Default::default(), + }; + + server + .on_cancelled(cancel_params, notification_context) + .await; + + // Wait for completion + let result = timeout(Duration::from_secs(5), shell_task).await; + let elapsed = start_time.elapsed(); + + assert!(result.is_ok(), "Shell task should complete within timeout"); + assert!( + elapsed < Duration::from_secs(5), + "Command with child processes should be cancelled quickly, took {:?}", + elapsed + ); + + cleanup_test_service(running_service, peer); + }); + } + + #[test] + #[serial] + fn test_cancel_nonexistent_process() { + run_shell_test(|| async { + let server = create_test_server(); + let running_service = serve_directly(server.clone(), create_test_transport(), None); + let peer = running_service.peer().clone(); + + // Try to cancel a process that doesn't exist + let cancel_params = CancelledNotificationParam { + request_id: NumberOrString::Number(999), + reason: Some("test cancellation".to_string()), + }; + + let notification_context = NotificationContext { + peer: peer.clone(), + meta: Default::default(), + extensions: Default::default(), + }; + + // This should not panic or cause issues + server + .on_cancelled(cancel_params, notification_context) + .await; + + // Verify no processes are tracked + let processes = server.running_processes.read().await; + assert!(processes.is_empty(), "No processes should be tracked"); + + cleanup_test_service(running_service, peer); + }); + } + + #[test] + #[serial] + #[cfg(unix)] + fn test_successful_shell_command_completion() { + run_shell_test(|| async { + let server = create_test_server(); + let running_service = serve_directly(server.clone(), create_test_transport(), None); + let peer = running_service.peer().clone(); + + let context = RequestContext { + ct: Default::default(), + id: NumberOrString::Number(789), + meta: Default::default(), + extensions: Default::default(), + peer: peer.clone(), + }; + + // Run a quick command that should complete successfully + let result = server + .shell( + Parameters(ShellParams { + command: "echo 'Hello, World!'".to_string(), + }), + context, + ) + .await; + + assert!( + result.is_ok(), + "Simple shell command should succeed: {:?}", + result + ); + + // Verify no processes are left tracked after completion + let processes = server.running_processes.read().await; + assert!( + !processes.contains_key("789"), + "Process should be cleaned up after completion" + ); + + cleanup_test_service(running_service, peer); + }); + } } diff --git a/crates/goose-mcp/src/developer/shell.rs b/crates/goose-mcp/src/developer/shell.rs index b45ec577a3d2..699c9699aa16 100644 --- a/crates/goose-mcp/src/developer/shell.rs +++ b/crates/goose-mcp/src/developer/shell.rs @@ -1,4 +1,8 @@ -use std::env; +use std::{env, process::Stdio}; + +#[cfg(unix)] +#[allow(unused_imports)] // False positive: trait is used for process_group method +use std::os::unix::process::CommandExt; #[derive(Debug, Clone)] pub struct ShellConfig { @@ -105,3 +109,70 @@ pub fn normalize_line_endings(text: &str) -> String { text.replace("\r\n", "\n") } } + +/// Configure a shell command with process group support for proper child process tracking. +/// +/// On Unix systems, creates a new process group so child processes can be killed together. +/// On Windows, the default behavior already supports process tree termination. +pub fn configure_shell_command( + shell_config: &ShellConfig, + command: &str, +) -> tokio::process::Command { + let mut command_builder = tokio::process::Command::new(&shell_config.executable); + command_builder + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .stdin(Stdio::null()) + .kill_on_drop(true) + .env("GOOSE_TERMINAL", "1") + .args(&shell_config.args) + .arg(command); + + // On Unix systems, create a new process group so we can kill child processes + #[cfg(unix)] + { + command_builder.process_group(0); + } + + command_builder +} + +/// Kill a process and all its child processes using platform-specific approaches. +/// +/// On Unix systems, kills the entire process group. +/// On Windows, kills the process tree. +pub async fn kill_process_group( + child: &mut tokio::process::Child, + pid: Option, +) -> Result<(), Box> { + #[cfg(unix)] + { + if let Some(pid) = pid { + // Try SIGTERM first + let _sigterm_result = unsafe { libc::kill(-(pid as i32), libc::SIGTERM) }; + + // Wait a brief moment for graceful shutdown + tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await; + + // Force kill with SIGKILL + let _sigkill_result = unsafe { libc::kill(-(pid as i32), libc::SIGKILL) }; + } + + // Last fallback, return the result of tokio's kill + child.kill().await.map_err(|e| e.into()) + } + + #[cfg(windows)] + { + if let Some(pid) = pid { + // Use taskkill to kill the process tree on Windows + let _kill_result = tokio::process::Command::new("taskkill") + .args(&["/F", "/T", "/PID", &pid.to_string()]) + .output() + .await; + } + + // Return the result of tokio's kill + child.kill().await.map_err(|e| e.into()) + } +}