diff --git a/Cargo.lock b/Cargo.lock index 61beaa6f1db1..ec858362917f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3474,6 +3474,7 @@ dependencies = [ "umya-spreadsheet", "url", "utoipa", + "uuid", "webbrowser 0.8.15", "which", "xcap", diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 82f5328bf4f2..20fda42e738f 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -33,6 +33,7 @@ use goose::agents::{Agent, SessionConfig}; use goose::config::Config; use goose::providers::pricing::initialize_pricing_cache; use goose::session; +use goose_mcp::FilePidTracker; use input::InputResult; use mcp_core::handler::ToolError; use rmcp::model::PromptMessage; @@ -1316,6 +1317,14 @@ impl Session { "The existing call to {} was interrupted. How would you like to proceed?", last_tool_name ); + + // If this was a shell command that might have left processes running, clean them up + if last_tool_name == "developer__shell" { + tokio::spawn(async { + cleanup_shell_processes().await; + }); + } + self.push_message(Message::assistant().with_text(&prompt)); // No need for description update here @@ -1624,6 +1633,33 @@ impl Session { } } +/// Cleanup function to kill tracked subprocess PIDs when shell commands are interrupted. +async fn cleanup_shell_processes() { + let file_tracker = FilePidTracker::new(); + let tracked_pids = file_tracker.get_all_pids(); + + if tracked_pids.is_empty() { + return; + } + if cfg!(windows) { + // On Windows, we can use taskkill to terminate processes by PID + for pid in tracked_pids { + let _ = tokio::process::Command::new("taskkill") + .args(&["/F", "/PID", &pid.to_string()]) + .output() + .await; + } + } else { + // On Unix-like systems, we can use kill to terminate processes by PID + for pid in tracked_pids { + let _ = tokio::process::Command::new("kill") + .args(&["-TERM", &pid.to_string()]) + .output() + .await; + } + } +} + fn get_reasoner() -> Result, anyhow::Error> { use goose::model::ModelConfig; use goose::providers::create; diff --git a/crates/goose-mcp/Cargo.toml b/crates/goose-mcp/Cargo.toml index e469b952ea3e..53c7ccb7b140 100644 --- a/crates/goose-mcp/Cargo.toml +++ b/crates/goose-mcp/Cargo.toml @@ -63,6 +63,7 @@ hyper = "1" serde_with = "3" which = "6.0" glob = "0.3" +uuid = { version = "1.0", features = ["v4"] } [dev-dependencies] diff --git a/crates/goose-mcp/src/developer/mod.rs b/crates/goose-mcp/src/developer/mod.rs index 817f372ffe20..178a6ae7e8b9 100644 --- a/crates/goose-mcp/src/developer/mod.rs +++ b/crates/goose-mcp/src/developer/mod.rs @@ -15,6 +15,7 @@ use std::{ io::{Cursor, Read}, path::{Path, PathBuf}, pin::Pin, + sync::{Arc, Mutex}, }; use tokio::{ io::{AsyncBufReadExt, BufReader}, @@ -23,6 +24,7 @@ use tokio::{ }; use tokio_stream::{wrappers::SplitStream, StreamExt as _}; use url::Url; +use uuid::Uuid; use include_dir::{include_dir, Dir}; use mcp_core::{ @@ -41,13 +43,13 @@ use rmcp::object; use self::editor_models::{create_editor_model, EditorModel}; use self::shell::{expand_path, get_shell_config, is_absolute_path, normalize_line_endings}; +use crate::file_pid_tracker::FilePidTracker; + +use ignore::gitignore::{Gitignore, GitignoreBuilder}; use indoc::indoc; use std::process::Stdio; -use std::sync::{Arc, Mutex}; use xcap::{Monitor, Window}; -use ignore::gitignore::{Gitignore, GitignoreBuilder}; - #[derive(Debug, Serialize, Deserialize)] pub struct PromptTemplate { pub id: String, @@ -631,64 +633,130 @@ impl DeveloperRouter { // Get platform-specific shell configuration let shell_config = get_shell_config(); - // Execute the command using platform-specific shell + // Execute the command using shell with better process cleanup + let wrapped_command = if cfg!(windows) { + command.to_string() + } else { + format!("setsid bash -c '{}'", command) + }; + let mut child = Command::new(&shell_config.executable) .stdout(Stdio::piped()) .stderr(Stdio::piped()) .stdin(Stdio::null()) .kill_on_drop(true) .args(&shell_config.args) - .arg(command) + .arg(&wrapped_command) .spawn() .map_err(|e| ToolError::ExecutionError(e.to_string()))?; - let stdout = BufReader::new(child.stdout.take().unwrap()); - let stderr = BufReader::new(child.stderr.take().unwrap()); + // Store the child PID for cleanup - generate a unique execution ID + let execution_id = format!("exec_{}", Uuid::new_v4().simple()); + + let child_pid = child.id(); + + // Store the PID globally for cleanup if cancellation occurs + if let Some(pid) = child_pid { + let file_tracker = FilePidTracker::new(); + file_tracker.register_process(execution_id.clone(), pid, command.to_string()); + } + + let stdout = child.stdout.take().unwrap(); + let stderr = child.stderr.take().unwrap(); + + let mut stdout_reader = BufReader::new(stdout); + let mut stderr_reader = BufReader::new(stderr); let output_task = tokio::spawn(async move { let mut combined_output = String::new(); - // We have the individual two streams above, now merge them into one unified stream of - // an enum. ref https://blog.yoshuawuyts.com/futures-concurrency-3 - let stdout = SplitStream::new(stdout.split(b'\n')).map(|v| ("stdout", v)); - let stderr = SplitStream::new(stderr.split(b'\n')).map(|v| ("stderr", v)); - let mut merged = stdout.merge(stderr); - - while let Some((key, line)) = merged.next().await { - let mut line = line?; - // Re-add this as clients expect it - line.push(b'\n'); - // Here we always convert to UTF-8 so agents don't have to deal with corrupted output - let line = String::from_utf8_lossy(&line); - - combined_output.push_str(&line); - - notifier - .try_send(JsonRpcMessage::Notification(JsonRpcNotification { - jsonrpc: JsonRpcVersion2_0, - notification: Notification { - method: "notifications/message".to_string(), - params: object!({ - "level": "info", - "data": { - "type": "shell", - "stream": key, - "output": line, + let mut stdout_buf = Vec::new(); + let mut stderr_buf = Vec::new(); + + let mut stdout_done = false; + let mut stderr_done = false; + + loop { + tokio::select! { + n = stdout_reader.read_until(b'\n', &mut stdout_buf), if !stdout_done => { + if n? == 0 { + stdout_done = true; + } else { + let line = String::from_utf8_lossy(&stdout_buf); + + notifier.try_send(JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: JsonRpcVersion2_0, + notification: Notification { + method: "notifications/message".to_string(), + params: object!({ + "level": "info", + "data": { + "type": "shell", + "stream": "stdout", + "output": line.to_string(), + } + }), + extensions: Default::default(), + } + })).ok(); + + combined_output.push_str(&line); + stdout_buf.clear(); + } + } + + n = stderr_reader.read_until(b'\n', &mut stderr_buf), if !stderr_done => { + if n? == 0 { + stderr_done = true; + } else { + let line = String::from_utf8_lossy(&stderr_buf); + + notifier.try_send(JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: JsonRpcVersion2_0, + notification: Notification { + method: "notifications/message".to_string(), + params: object!({ + "level": "info", + "data": { + "type": "shell", + "stream": "stderr", + "output": line.to_string(), + } + }), + extensions: Default::default(), } - }), - extensions: Default::default(), - }, - })) - .ok(); + })).ok(); + + combined_output.push_str(&line); + stderr_buf.clear(); + } + } + + else => break, + } + + if stdout_done && stderr_done { + break; + } } Ok::<_, std::io::Error>(combined_output) }); // Wait for the command to complete and get output - child - .wait() - .await - .map_err(|e| ToolError::ExecutionError(e.to_string()))?; + let exit_status_result = child.wait().await; + + match exit_status_result { + Ok(exit_status) => { + if exit_status.success() { + // Always use file-based tracking for consistency + let file_tracker = FilePidTracker::new(); + file_tracker.unregister_process(&execution_id); + } + } + Err(e) => { + return Err(ToolError::ExecutionError(e.to_string())); + } + } let output_str = match output_task.await { Ok(result) => result.map_err(|e| ToolError::ExecutionError(e.to_string()))?, @@ -1714,6 +1782,7 @@ mod tests { json!({ "command": "Get-ChildItem" }), + dummy_sender(), ) .await; assert!(result.is_ok()); diff --git a/crates/goose-mcp/src/file_pid_tracker.rs b/crates/goose-mcp/src/file_pid_tracker.rs new file mode 100644 index 000000000000..89be6dd91e89 --- /dev/null +++ b/crates/goose-mcp/src/file_pid_tracker.rs @@ -0,0 +1,128 @@ +use etcetera::{choose_app_strategy, AppStrategy}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::fs; +use std::path::PathBuf; + +/// Process information stored in the PID tracking file +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ProcessInfo { + pub pid: u32, + pub command: String, + pub timestamp: u64, +} + +/// File-based PID tracker that persists process information across different +/// parts of the application (CLI, session, MCP server) +#[derive(Debug)] +pub struct FilePidTracker { + file_path: PathBuf, +} + +impl FilePidTracker { + pub fn new() -> Self { + // Use the same app strategy as the rest of the application + let file_path = choose_app_strategy(crate::APP_STRATEGY.clone()) + .map(|strategy| strategy.in_data_dir("tracked_pids.json")) + .unwrap_or_else(|_| { + PathBuf::from( + shellexpand::tilde("~/.local/share/goose/tracked_pids.json").to_string(), + ) + }); + + // Create the directory if it doesn't exist + if let Some(parent) = file_path.parent() { + let _ = fs::create_dir_all(parent); + } + + Self { file_path } + } + + /// Read PIDs from the JSON file + fn read_pids(&self) -> HashMap { + if !self.file_path.exists() { + return HashMap::new(); + } + + match fs::read_to_string(&self.file_path) { + Ok(content) => match serde_json::from_str::>(&content) { + Ok(pids) => pids, + Err(_) => HashMap::new(), + }, + Err(_) => HashMap::new(), + } + } + + /// Write PIDs to the JSON file + fn write_pids(&self, pids: &HashMap) { + match serde_json::to_string_pretty(pids) { + Ok(content) => { + let _ = fs::write(&self.file_path, content); + } + Err(_) => {} + } + } + + /// Register a process PID with execution ID and command + pub fn register_process(&self, execution_id: String, pid: u32, command: String) { + let mut pids = self.read_pids(); + let process_info = ProcessInfo { + pid, + command, + timestamp: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(), + }; + pids.insert(execution_id.clone(), process_info); + self.write_pids(&pids); + } + + /// Unregister a process PID by execution ID + pub fn unregister_process(&self, execution_id: &str) -> Option { + let mut pids = self.read_pids(); + let removed = pids.remove(execution_id); + self.write_pids(&pids); + if let Some(ref process_info) = removed { + Some(process_info.pid) + } else { + None + } + } + + /// Get all currently tracked PIDs + pub fn get_all_pids(&self) -> Vec { + let pids = self.read_pids(); + pids.values().map(|info| info.pid).collect() + } + + /// Clear all tracked PIDs + pub fn clear_all(&self) { + let empty_pids: HashMap = HashMap::new(); + self.write_pids(&empty_pids); + } + + /// Clean up old PIDs (older than 1 hour) to prevent the file from growing indefinitely + pub fn cleanup_old_pids(&self) { + let mut pids = self.read_pids(); + let current_time = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(); + + let one_hour = 3600; // 1 hour in seconds + let initial_count = pids.len(); + + pids.retain(|_, info| current_time - info.timestamp < one_hour); + + if pids.len() != initial_count { + self.write_pids(&pids); + } + } +} + +impl Default for FilePidTracker { + fn default() -> Self { + Self::new() + } +} diff --git a/crates/goose-mcp/src/lib.rs b/crates/goose-mcp/src/lib.rs index c112c8fee3e9..50725f8e1e28 100644 --- a/crates/goose-mcp/src/lib.rs +++ b/crates/goose-mcp/src/lib.rs @@ -9,12 +9,14 @@ pub static APP_STRATEGY: Lazy = Lazy::new(|| AppStrategyArgs { pub mod computercontroller; mod developer; +pub mod file_pid_tracker; pub mod google_drive; mod memory; mod tutorial; pub use computercontroller::ComputerControllerRouter; pub use developer::DeveloperRouter; +pub use file_pid_tracker::FilePidTracker; pub use google_drive::GoogleDriveRouter; pub use memory::MemoryRouter; pub use tutorial::TutorialRouter;