diff --git a/crates/goose/src/agents/reply_parts.rs b/crates/goose/src/agents/reply_parts.rs index cd00a07d07c4..bb9e47f01a10 100644 --- a/crates/goose/src/agents/reply_parts.rs +++ b/crates/goose/src/agents/reply_parts.rs @@ -235,7 +235,9 @@ impl Agent { }; Ok(Box::pin(try_stream! { - while let Some(Ok((mut message, usage))) = stream.next().await { + while let Some(result) = stream.next().await { + let (mut message, usage) = result?; + // Store the model information in the global store if let Some(usage) = usage.as_ref() { crate::providers::base::set_current_model(&usage.model); @@ -483,4 +485,44 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_stream_error_propagation() { + use futures::StreamExt; + + type StreamItem = Result<(Option, Option), ProviderError>; + let stream = futures::stream::iter(vec![ + Ok((Some(Message::assistant().with_text("chunk1")), None)), + Ok((Some(Message::assistant().with_text("chunk2")), None)), + Err(ProviderError::RequestFailed( + "simulated stream error".to_string(), + )), + ] as Vec); + + let mut pinned = Box::pin(stream); + let mut results = Vec::new(); + let mut error_seen = false; + + while let Some(result) = pinned.next().await { + match result { + Ok((message, _usage)) => { + if let Some(msg) = message { + results.push(msg.as_concat_text()); + } + } + Err(_e) => { + error_seen = true; + break; + } + } + } + + assert_eq!(results.len(), 2); + assert_eq!(results[0], "chunk1"); + assert_eq!(results[1], "chunk2"); + assert!( + error_seen, + "Error should have been propagated, not silently ignored" + ); + } } diff --git a/crates/goose/src/providers/formats/openai.rs b/crates/goose/src/providers/formats/openai.rs index 937d9658ffde..1db03270fedf 100644 --- a/crates/goose/src/providers/formats/openai.rs +++ b/crates/goose/src/providers/formats/openai.rs @@ -18,10 +18,11 @@ use serde_json::{json, Value}; use std::borrow::Cow; use std::ops::Deref; -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Default)] struct DeltaToolCallFunction { name: Option, - arguments: String, // chunk of encoded JSON, + #[serde(default)] + arguments: String, } #[derive(Serialize, Deserialize, Debug)]