diff --git a/crates/goose/src/providers/gemini_cli.rs b/crates/goose/src/providers/gemini_cli.rs index 746e1d6b73f8..ecfb0385ccb8 100644 --- a/crates/goose/src/providers/gemini_cli.rs +++ b/crates/goose/src/providers/gemini_cli.rs @@ -1,6 +1,6 @@ use anyhow::Result; use async_trait::async_trait; -use serde_json::{json, Value}; +use serde_json::Value; use std::path::PathBuf; use std::process::Stdio; use std::sync::{Arc, OnceLock}; @@ -13,7 +13,7 @@ use super::base::{ }; use super::cli_common::{error_from_event, extract_usage_tokens}; use super::errors::ProviderError; -use super::utils::{filter_extensions_from_system_prompt, RequestLog}; +use super::utils::filter_extensions_from_system_prompt; use crate::config::base::GeminiCliCommand; use crate::config::search_path::SearchPaths; use crate::config::Config; @@ -21,6 +21,7 @@ use crate::conversation::message::{Message, MessageContent}; use crate::model::ModelConfig; use crate::providers::base::ConfigKey; use crate::subprocess::configure_subprocess; +use async_stream::try_stream; use futures::future::BoxFuture; use rmcp::model::Role; use rmcp::model::Tool; @@ -63,10 +64,6 @@ impl GeminiCliProvider { self.cli_session_id.get().map(|s| s.as_str()) } - fn set_session_id(&self, sid: String) { - let _ = self.cli_session_id.set(sid); - } - fn last_user_message_text(messages: &[Message]) -> String { messages .iter() @@ -155,144 +152,6 @@ impl GeminiCliProvider { Ok((child, BufReader::new(stdout))) } - - async fn execute_command( - &self, - system: &str, - messages: &[Message], - _tools: &[Tool], - model_name: &str, - ) -> Result, ProviderError> { - let (mut child, mut reader) = self.spawn_command(system, messages, model_name)?; - - // Drain stderr concurrently to avoid pipe deadlock - let stderr_task = tokio::spawn(async move { - let mut buf = String::new(); - if let Some(mut stderr) = child.stderr.take() { - let _ = stderr.read_to_string(&mut buf).await; - } - (child, buf) - }); - - let mut events = Vec::new(); - let mut line = String::new(); - - loop { - line.clear(); - match reader.read_line(&mut line).await { - Ok(0) => break, - Ok(_) => { - let trimmed = line.trim(); - if trimmed.is_empty() { - continue; - } - - match serde_json::from_str::(trimmed) { - Ok(parsed) => { - if parsed.get("type").and_then(|t| t.as_str()) == Some("init") { - if let Some(sid) = parsed.get("session_id").and_then(|s| s.as_str()) - { - self.set_session_id(sid.to_string()); - } - } - events.push(parsed); - } - Err(_) => { - tracing::warn!(line = trimmed, "Non-JSON line in stream-json output"); - } - } - } - Err(e) => { - return Err(ProviderError::RequestFailed(format!( - "Failed to read output: {e}" - ))); - } - } - } - - let (mut child, stderr_text) = stderr_task - .await - .map_err(|e| ProviderError::RequestFailed(format!("Failed to read stderr: {e}")))?; - - let exit_status = child.wait().await.map_err(|e| { - ProviderError::RequestFailed(format!("Failed to wait for command: {e}")) - })?; - - if !exit_status.success() { - let stderr_snippet = stderr_text.trim(); - let detail = if stderr_snippet.is_empty() { - format!("exit code {:?}", exit_status.code()) - } else { - format!("exit code {:?}: {stderr_snippet}", exit_status.code()) - }; - return Err(ProviderError::RequestFailed(format!( - "Gemini CLI command failed ({detail})" - ))); - } - - tracing::debug!( - "Gemini CLI executed successfully, got {} events", - events.len() - ); - - Ok(events) - } - - fn parse_stream_json_response(events: &[Value]) -> Result<(Message, Usage), ProviderError> { - let mut all_text_content = Vec::new(); - let mut all_thinking_content = Vec::new(); - let mut usage = Usage::default(); - - for parsed in events { - match parsed.get("type").and_then(|t| t.as_str()) { - Some("thinking") => { - if let Some(content) = parsed.get("content").and_then(|c| c.as_str()) { - if !content.is_empty() { - all_thinking_content.push(content.to_string()); - } - } - } - Some("message") => { - if parsed.get("role").and_then(|r| r.as_str()) == Some("assistant") { - if let Some(content) = parsed.get("content").and_then(|c| c.as_str()) { - if !content.is_empty() { - all_text_content.push(content.to_string()); - } - } - } - } - Some("result") => { - if let Some(stats) = parsed.get("stats") { - usage = extract_usage_tokens(stats); - } - } - Some("error") => { - return Err(error_from_event("Gemini CLI", parsed)); - } - _ => {} - } - } - - let combined_text = all_text_content.join(""); - if combined_text.is_empty() { - return Err(ProviderError::RequestFailed( - "No text content found in response".to_string(), - )); - } - - let mut content = Vec::new(); - - let combined_thinking = all_thinking_content.join(""); - if !combined_thinking.is_empty() { - content.push(MessageContent::thinking(combined_thinking, String::new())); - } - - content.push(MessageContent::text(combined_text)); - - let message = Message::new(Role::Assistant, chrono::Utc::now().timestamp(), content); - - Ok((message, usage)) - } } impl ProviderDef for GeminiCliProvider { @@ -337,7 +196,7 @@ impl Provider for GeminiCliProvider { } #[tracing::instrument( - skip(self, model_config, system, messages, tools), + skip(self, model_config, system, messages, _tools), fields(model_config, input, output, input_tokens, output_tokens, total_tokens) )] async fn stream( @@ -346,7 +205,7 @@ impl Provider for GeminiCliProvider { _session_id: &str, // CLI has no external session-id flag to propagate. system: &str, messages: &[Message], - tools: &[Tool], + _tools: &[Tool], ) -> Result { if super::cli_common::is_session_description_request(system) { let (message, provider_usage) = super::cli_common::generate_simple_session_description( @@ -356,40 +215,113 @@ impl Provider for GeminiCliProvider { return Ok(stream_from_single_message(message, provider_usage)); } - let payload = json!({ - "command": self.command, - "model": model_config.model_name, - "system": system, - "messages": messages.len() - }); + let (mut child, mut reader) = + self.spawn_command(system, messages, &model_config.model_name)?; + let session_id_lock = Arc::clone(&self.cli_session_id); + let model_name = model_config.model_name.clone(); + let message_id = uuid::Uuid::new_v4().to_string(); - let mut log = RequestLog::start(model_config, &payload).map_err(|e| { - ProviderError::RequestFailed(format!("Failed to start request log: {e}")) - })?; + let stderr = child.stderr.take(); + let stderr_drain = tokio::spawn(async move { + let mut buf = String::new(); + if let Some(mut stderr) = stderr { + let _ = AsyncReadExt::read_to_string(&mut stderr, &mut buf).await; + } + buf + }); - let events = self - .execute_command(system, messages, tools, &model_config.model_name) - .await?; - let (message, usage) = Self::parse_stream_json_response(&events)?; + Ok(Box::pin(try_stream! { + let mut line = String::new(); + let mut accumulated_usage = Usage::default(); + let stream_timestamp = chrono::Utc::now().timestamp(); + + loop { + line.clear(); + match reader.read_line(&mut line).await { + Ok(0) => break, + Ok(_) => { + let trimmed = line.trim(); + if trimmed.is_empty() { + continue; + } - let response = json!({ - "events": events.len(), - "usage": usage - }); + if let Ok(parsed) = serde_json::from_str::(trimmed) { + match parsed.get("type").and_then(|t| t.as_str()) { + Some("init") => { + if let Some(sid) = + parsed.get("session_id").and_then(|s| s.as_str()) + { + let _ = session_id_lock.set(sid.to_string()); + } + } + Some("message") => { + let is_assistant = parsed.get("role").and_then(|r| r.as_str()) + == Some("assistant"); + let content = parsed + .get("content") + .and_then(|c| c.as_str()) + .unwrap_or(""); + if is_assistant && !content.is_empty() { + let mut partial = Message::new( + Role::Assistant, + stream_timestamp, + vec![MessageContent::text(content)], + ); + partial.id = Some(message_id.clone()); + yield (Some(partial), None); + } + } + Some("result") => { + if let Some(stats) = parsed.get("stats") { + accumulated_usage = extract_usage_tokens(stats); + } + break; + } + Some("error") => { + let _ = child.wait().await; + Err(error_from_event("Gemini CLI", &parsed))?; + } + _ => {} + } + } else { + tracing::warn!(line = trimmed, "Non-JSON line in stream-json output"); + } + } + Err(e) => { + let _ = child.wait().await; + Err(ProviderError::RequestFailed(format!( + "Failed to read streaming output: {e}" + )))?; + } + } + } - log.write(&response, Some(&usage)).map_err(|e| { - ProviderError::RequestFailed(format!("Failed to write request log: {e}")) - })?; + let stderr_text = stderr_drain.await.unwrap_or_default(); + let exit_status = child.wait().await.map_err(|e| { + ProviderError::RequestFailed(format!("Failed to wait for command: {e}")) + })?; + + if !exit_status.success() { + let stderr_snippet = stderr_text.trim(); + let detail = if stderr_snippet.is_empty() { + format!("exit code {:?}", exit_status.code()) + } else { + format!("exit code {:?}: {stderr_snippet}", exit_status.code()) + }; + Err(ProviderError::RequestFailed(format!( + "Gemini CLI command failed ({detail})" + )))?; + } - let provider_usage = ProviderUsage::new(model_config.model_name.clone(), usage); - Ok(stream_from_single_message(message, provider_usage)) + let provider_usage = ProviderUsage::new(model_name, accumulated_usage); + yield (None, Some(provider_usage)); + })) } } #[cfg(test)] mod tests { use super::*; - use serde_json::json; fn make_provider() -> GeminiCliProvider { GeminiCliProvider { @@ -400,70 +332,6 @@ mod tests { } } - #[test] - fn test_parse_stream_json_response() { - let events = vec![ - json!({"type":"init","session_id":"abc","model":"gemini-2.5-pro"}), - json!({"type":"message","role":"user","content":"Hi"}), - json!({"type":"message","role":"assistant","content":"Hello ","delta":true}), - json!({"type":"message","role":"assistant","content":"there!","delta":true}), - json!({"type":"result","status":"success","stats":{"input_tokens":20,"output_tokens":5,"total_tokens":25}}), - ]; - let (message, usage) = GeminiCliProvider::parse_stream_json_response(&events).unwrap(); - assert_eq!(message.role, Role::Assistant); - assert_eq!(message.as_concat_text(), "Hello there!"); - assert_eq!(usage.input_tokens, Some(20)); - assert_eq!(usage.output_tokens, Some(5)); - - let error_events = vec![ - json!({"type":"init","session_id":"abc"}), - json!({"type":"error","error":"Rate limit exceeded"}), - ]; - let err = GeminiCliProvider::parse_stream_json_response(&error_events).unwrap_err(); - assert!(err.to_string().contains("Rate limit exceeded")); - - let empty: Vec = vec![]; - assert!(GeminiCliProvider::parse_stream_json_response(&empty).is_err()); - } - - #[test] - fn test_parse_thinking_blocks() { - let events = vec![ - json!({"type":"init","session_id":"abc","model":"gemini-2.5-pro"}), - json!({"type":"thinking","content":"Let me reason about this...","delta":true}), - json!({"type":"thinking","content":" Step 1: analyze the problem.","delta":true}), - json!({"type":"message","role":"assistant","content":"Here is the answer.","delta":true}), - json!({"type":"result","status":"success","stats":{"input_tokens":30,"output_tokens":15,"total_tokens":45}}), - ]; - let (message, usage) = GeminiCliProvider::parse_stream_json_response(&events).unwrap(); - assert_eq!(message.role, Role::Assistant); - - // Should have thinking content followed by text content - assert_eq!(message.content.len(), 2); - let thinking = message.content[0] - .as_thinking() - .expect("first content should be thinking"); - assert_eq!( - thinking.thinking, - "Let me reason about this... Step 1: analyze the problem." - ); - assert_eq!(message.as_concat_text(), "Here is the answer."); - assert_eq!(usage.input_tokens, Some(30)); - assert_eq!(usage.output_tokens, Some(15)); - } - - #[test] - fn test_parse_no_thinking_blocks() { - // When there's no thinking, message should only have text content - let events = vec![ - json!({"type":"message","role":"assistant","content":"Direct answer.","delta":true}), - json!({"type":"result","status":"success","stats":{"input_tokens":10,"output_tokens":5,"total_tokens":15}}), - ]; - let (message, _usage) = GeminiCliProvider::parse_stream_json_response(&events).unwrap(); - assert_eq!(message.content.len(), 1); - assert_eq!(message.as_concat_text(), "Direct answer."); - } - #[test] fn test_build_prompt_first_and_resume() { let provider = make_provider(); @@ -477,7 +345,7 @@ mod tests { assert!(prompt.contains("You are helpful.")); assert!(prompt.contains("Hello")); - provider.set_session_id("session-123".to_string()); + let _ = provider.cli_session_id.set("session-123".to_string()); let messages = vec![ Message::new(Role::User, 0, vec![MessageContent::text("Hello")]), Message::new(Role::Assistant, 0, vec![MessageContent::text("Hi!")]),