diff --git a/crates/goose/src/providers/formats/openai.rs b/crates/goose/src/providers/formats/openai.rs
index 465b9d356c3c..154975c85dc1 100644
--- a/crates/goose/src/providers/formats/openai.rs
+++ b/crates/goose/src/providers/formats/openai.rs
@@ -2,6 +2,7 @@ use crate::conversation::message::{Message, MessageContent, ProviderMetadata};
 use crate::mcp_utils::extract_text_from_resource;
 use crate::model::ModelConfig;
 use crate::providers::base::{ProviderUsage, Usage};
+use crate::providers::errors::ProviderError;
 use crate::providers::utils::{
     convert_image, detect_image_path, extract_reasoning_effort, is_valid_function_name,
     load_image_file, safely_parse_json, sanitize_function_name, ImageFormat,
@@ -528,6 +529,32 @@ fn strip_data_prefix(line: &str) -> Option<&str> {
     line.strip_prefix("data: ").map(|s| s.trim())
 }
 
+fn parse_streaming_chunk(line: &str) -> Result<StreamingChunk, ProviderError> {
+    let value: Value = serde_json::from_str(line).map_err(|e| {
+        ProviderError::RequestFailed(format!("Failed to parse streaming chunk: {e}: {line:?}"))
+    })?;
+
+    if let Some(error) = value.get("error") {
+        let message = error
+            .get("message")
+            .and_then(|m| m.as_str())
+            .unwrap_or("Unknown server error");
+        return Err(ProviderError::ServerError(message.to_string()));
+    }
+
+    if value.get("object").and_then(|o| o.as_str()) == Some("error") {
+        let message = value
+            .get("message")
+            .and_then(|m| m.as_str())
+            .unwrap_or("Unknown server error");
+        return Err(ProviderError::ServerError(message.to_string()));
+    }
+
+    serde_json::from_value(value).map_err(|e| {
+        ProviderError::RequestFailed(format!("Failed to parse streaming chunk: {e}: {line:?}"))
+    })
+}
+
 pub fn response_to_streaming_message<S>(
     mut stream: S,
 ) -> impl Stream<Item = Result<(Message, Option<ProviderUsage>)>> + 'static
@@ -551,9 +578,9 @@ where
                 continue
             }
 
-            let chunk: StreamingChunk = serde_json::from_str(line
-                .ok_or_else(|| anyhow!("unexpected stream format"))?)
-                .map_err(|e| anyhow!("Failed to parse streaming chunk: {}: {:?}", e, &line))?;
+            let chunk: StreamingChunk = parse_streaming_chunk(
+                line.ok_or_else(|| anyhow!("unexpected stream format"))?
+            )?;
 
             if !chunk.choices.is_empty() {
                 if let Some(details) = &chunk.choices[0].delta.reasoning_details {
@@ -591,8 +618,7 @@ where
             let response_str = response_chunk?;
 
             if let Some(line) = strip_data_prefix(&response_str) {
-                let tool_chunk: StreamingChunk = serde_json::from_str(line)
-                    .map_err(|e| anyhow!("Failed to parse streaming chunk: {}: {:?}", e, &line))?;
+                let tool_chunk: StreamingChunk = parse_streaming_chunk(line)?;
 
                 if let Some(chunk_usage) = extract_usage_with_output_tokens(&tool_chunk) {
                     usage = Some(chunk_usage);
@@ -820,6 +846,7 @@ mod tests {
     use rmcp::model::CallToolResult;
     use rmcp::object;
     use serde_json::json;
+    use test_case::test_case;
     use tokio::pin;
     use tokio_stream::{self, StreamExt};
 
@@ -1889,4 +1916,42 @@ data: [DONE]"#;
 
         Ok(())
     }
+
+    #[test_case(
+        "data: {\"error\":{\"message\":\"Internal server error\",\"type\":\"server_error\",\"code\":500}}\ndata: [DONE]",
+        "Internal server error";
+        "openai error format"
+    )]
+    #[test_case(
+        "data: {\"object\":\"error\",\"message\":\"CUDA out of memory\",\"code\":500}\ndata: [DONE]",
+        "CUDA out of memory";
+        "vllm error format"
+    )]
+    #[test_case(
+        "data: {\"error\":{\"message\":\"Rate limit exceeded\",\"type\":\"rate_limit_error\"}}",
+        "Rate limit exceeded";
+        "error as first chunk"
+    )]
+    #[tokio::test]
+    async fn test_mid_stream_server_error(response_lines: &str, expected_msg: &str) {
+        let lines: Vec<String> = response_lines.lines().map(|s| s.to_string()).collect();
+        let response_stream = tokio_stream::iter(lines.into_iter().map(Ok));
+        let mut messages = std::pin::pin!(response_to_streaming_message(response_stream));
+        let mut found_error = false;
+        while let Some(result) = messages.next().await {
+            if let Err(e) = result {
+                let err_str = e.to_string();
+                assert!(
+                    err_str.contains(expected_msg),
+                    "unexpected error text: {err_str}"
+                );
+                found_error = true;
+                break;
+            }
+        }
+        assert!(
+            found_error,
+            "expected an error but stream completed successfully"
+        );
+    }
 }
diff --git a/crates/goose/src/providers/openai_compatible.rs b/crates/goose/src/providers/openai_compatible.rs
index e996c82d695f..313dfdfbf6c3 100644
--- a/crates/goose/src/providers/openai_compatible.rs
+++ b/crates/goose/src/providers/openai_compatible.rs
@@ -248,7 +248,8 @@ pub fn stream_openai_compat(
         pin!(message_stream);
         while let Some(message) = message_stream.next().await {
             let (message, usage) = message.map_err(|e|
-                ProviderError::RequestFailed(format!("Stream decode error: {}", e))
+                e.downcast::<ProviderError>()
+                    .unwrap_or_else(|e| ProviderError::RequestFailed(format!("Stream decode error: {e}")))
             )?;
             log.write(&message, usage.as_ref().map(|f| f.usage).as_ref())?;
             yield (message, usage);