diff --git a/crates/goose/src/providers/formats/openai.rs b/crates/goose/src/providers/formats/openai.rs index c6147e853ccd..5bba92e274f3 100644 --- a/crates/goose/src/providers/formats/openai.rs +++ b/crates/goose/src/providers/formats/openai.rs @@ -64,13 +64,22 @@ struct StreamingChoice { finish_reason: Option, } +#[derive(Serialize, Deserialize, Debug)] +struct StreamingError { + message: Option, + r#type: Option, + code: Option, +} + #[derive(Serialize, Deserialize, Debug)] struct StreamingChunk { + #[serde(default)] choices: Vec, created: Option, id: Option, usage: Option, model: Option, + error: Option, } pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec { @@ -532,6 +541,26 @@ fn strip_data_prefix(line: &str) -> Option<&str> { line.strip_prefix("data: ").map(|s| s.trim()) } +fn check_streaming_error(chunk: &StreamingChunk) -> anyhow::Result<()> { + if let Some(ref err) = chunk.error { + let msg = err.message.as_deref().unwrap_or("Unknown error"); + let code = err + .code + .as_ref() + .map(|c| c.to_string()) + .unwrap_or_else(|| "unknown".to_string()); + let err_type = err.r#type.as_deref().unwrap_or("server_error"); + Err(anyhow!( + "Server error during streaming (code: {}, type: {}): {}", + code, + err_type, + msg + )) + } else { + Ok(()) + } +} + pub fn response_to_streaming_message( mut stream: S, ) -> impl Stream, Option)>> + 'static @@ -559,6 +588,8 @@ where .ok_or_else(|| anyhow!("unexpected stream format"))?) 
.map_err(|e| anyhow!("Failed to parse streaming chunk: {}: {:?}", e, &line))?; + check_streaming_error(&chunk)?; + if !chunk.choices.is_empty() { if let Some(details) = &chunk.choices[0].delta.reasoning_details { accumulated_reasoning.extend(details.iter().cloned()); @@ -598,6 +629,8 @@ where let tool_chunk: StreamingChunk = serde_json::from_str(line) .map_err(|e| anyhow!("Failed to parse streaming chunk: {}: {:?}", e, &line))?; + check_streaming_error(&tool_chunk)?; + if let Some(chunk_usage) = extract_usage_with_output_tokens(&tool_chunk) { usage = Some(chunk_usage); } @@ -1879,7 +1912,6 @@ data: [DONE]"#; #[test] fn test_response_to_message_with_reasoning_content() -> anyhow::Result<()> { - // Test capturing reasoning_content from DeepSeek reasoning models let response = json!({ "choices": [{ "role": "assistant", @@ -1898,14 +1930,12 @@ data: [DONE]"#; let message = response_to_message(&response)?; assert_eq!(message.content.len(), 2); - // First should be reasoning content if let MessageContent::Reasoning(reasoning) = &message.content[0] { assert_eq!(reasoning.text, "Let me think about this step by step..."); } else { panic!("Expected Reasoning content"); } - // Second should be text content if let MessageContent::Text(text) = &message.content[1] { assert_eq!(text.text, "The answer is 9.11 is greater than 9.8"); } else { @@ -1917,12 +1947,10 @@ data: [DONE]"#; #[test] fn test_format_messages_with_reasoning_content() -> anyhow::Result<()> { - // Test that reasoning_content is properly included in formatted messages let mut message = Message::assistant() .with_content(MessageContent::reasoning("Thinking through the problem...")) .with_text("The result is 42"); - // Add a tool call to test that reasoning_content works with tool calls message = message.with_tool_request( "tool1", Ok(rmcp::model::CallToolRequestParams { @@ -1938,20 +1966,108 @@ data: [DONE]"#; assert_eq!(spec.len(), 1); assert_eq!(spec[0]["role"], "assistant"); - // Should have reasoning_content 
field assert!(spec[0].get("reasoning_content").is_some()); assert_eq!( spec[0]["reasoning_content"], "Thinking through the problem..." ); - // Should have content assert_eq!(spec[0]["content"], "The result is 42"); - // Should have tool_calls assert!(spec[0]["tool_calls"].is_array()); assert_eq!(spec[0]["tool_calls"][0]["function"]["name"], "test_tool"); Ok(()) } + + #[tokio::test] + async fn test_streaming_error_chunk_returns_server_error() { + let response_lines = r#"
+data: {"error":{"code":500,"message":"Invalid diff: now finding less tool calls!","type":"server_error"}}
+"#;
+
+        let lines: Vec<String> = response_lines.lines().map(|s| s.to_string()).collect(); + let response_stream = tokio_stream::iter(lines.into_iter().map(Ok)); + let messages = response_to_streaming_message(response_stream); + pin!(messages); + + let mut found_error = false; + while let Some(result) = messages.next().await { + if let Err(e) = result { + let err_str = e.to_string(); + assert!( + err_str.contains("Server error during streaming"), + "Expected server error message, got: {}", + err_str + ); + assert!( + err_str.contains("Invalid diff: now finding less tool calls!"), + "Expected original error message preserved, got: {}", + err_str + ); + found_error = true; + break; + } + } + assert!(found_error, "Expected an error from streaming error chunk"); + } + + #[tokio::test] + async fn test_streaming_error_chunk_during_tool_calls() { + let response_lines = r#"
+data: {"id":"chatcmpl-123","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_1","type":"function","function":{"name":"developer__shell","arguments":""}}]},"finish_reason":null}],"usage":null}
+data: {"error":{"code":500,"message":"Internal server error","type":"server_error"}}
+"#;
+
+        let lines: Vec<String> = response_lines.lines().map(|s| s.to_string()).collect(); + let response_stream = tokio_stream::iter(lines.into_iter().map(Ok)); + let messages =
response_to_streaming_message(response_stream); + pin!(messages); + + let mut found_error = false; + while let Some(result) = messages.next().await { + if let Err(e) = result { + assert!( + e.to_string().contains("Server error during streaming"), + "Expected server error, got: {}", + e + ); + found_error = true; + break; + } + } + assert!( + found_error, + "Expected error when server sends error mid-tool-call" + ); + } + + #[tokio::test] + async fn test_streaming_error_chunk_with_no_choices_no_crash() { + let response_lines = r#"
+data: {"error":{"message":"rate limit exceeded","type":"rate_limit_error","code":429}}
+"#;
+
+        let lines: Vec<String> = response_lines.lines().map(|s| s.to_string()).collect(); + let response_stream = tokio_stream::iter(lines.into_iter().map(Ok)); + let messages = response_to_streaming_message(response_stream); + pin!(messages); + + let mut found_error = false; + while let Some(result) = messages.next().await { + if let Err(e) = result { + assert!( + e.to_string().contains("rate limit exceeded"), + "Expected rate limit error, got: {}", + e + ); + found_error = true; + break; + } + } + assert!( + found_error, + "Expected error from error-only streaming chunk" + ); + } }