crates/goose/src/providers/formats/openai.rs (75 changes: 70 additions & 5 deletions)
@@ -2,6 +2,7 @@ use crate::conversation::message::{Message, MessageContent, ProviderMetadata};
use crate::mcp_utils::extract_text_from_resource;
use crate::model::ModelConfig;
use crate::providers::base::{ProviderUsage, Usage};
use crate::providers::errors::ProviderError;
use crate::providers::utils::{
convert_image, detect_image_path, extract_reasoning_effort, is_valid_function_name,
load_image_file, safely_parse_json, sanitize_function_name, ImageFormat,
@@ -528,6 +529,32 @@ fn strip_data_prefix(line: &str) -> Option<&str> {
line.strip_prefix("data: ").map(|s| s.trim())
}

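/// Parse a single SSE `data:` payload into a `StreamingChunk`.
///
/// Before deserializing into the chunk type, server-reported failures are surfaced as
/// `ProviderError::ServerError`: both the OpenAI-style `{"error": {...}}` envelope and the
/// vLLM-style `{"object": "error", ...}` payload are recognized.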
fn parse_streaming_chunk(line: &str) -> Result<StreamingChunk, ProviderError> {
let value: Value = serde_json::from_str(line).map_err(|e| {
ProviderError::RequestFailed(format!("Failed to parse streaming chunk: {e}: {line:?}"))
})?;

if let Some(error) = value.get("error") {
let message = error
.get("message")
.and_then(|m| m.as_str())
.unwrap_or("Unknown server error");
return Err(ProviderError::ServerError(message.to_string()));
}

if value.get("object").and_then(|o| o.as_str()) == Some("error") {
let message = value
.get("message")
.and_then(|m| m.as_str())
.unwrap_or("Unknown server error");
return Err(ProviderError::ServerError(message.to_string()));
}

serde_json::from_value(value).map_err(|e| {
ProviderError::RequestFailed(format!("Failed to parse streaming chunk: {e}: {line:?}"))
})
}

pub fn response_to_streaming_message<S>(
mut stream: S,
) -> impl Stream<Item = anyhow::Result<(Option<Message>, Option<ProviderUsage>)>> + 'static
@@ -551,9 +578,9 @@ where
continue
}

let chunk: StreamingChunk = serde_json::from_str(line
.ok_or_else(|| anyhow!("unexpected stream format"))?)
.map_err(|e| anyhow!("Failed to parse streaming chunk: {}: {:?}", e, &line))?;
let chunk: StreamingChunk = parse_streaming_chunk(
line.ok_or_else(|| anyhow!("unexpected stream format"))?
)?;

if !chunk.choices.is_empty() {
if let Some(details) = &chunk.choices[0].delta.reasoning_details {
@@ -591,8 +618,7 @@ where
let response_str = response_chunk?;
if let Some(line) = strip_data_prefix(&response_str) {

let tool_chunk: StreamingChunk = serde_json::from_str(line)
.map_err(|e| anyhow!("Failed to parse streaming chunk: {}: {:?}", e, &line))?;
let tool_chunk: StreamingChunk = parse_streaming_chunk(line)?;

if let Some(chunk_usage) = extract_usage_with_output_tokens(&tool_chunk) {
usage = Some(chunk_usage);
@@ -820,6 +846,7 @@ mod tests {
use rmcp::model::CallToolResult;
use rmcp::object;
use serde_json::json;
use test_case::test_case;
use tokio::pin;
use tokio_stream::{self, StreamExt};

@@ -1889,4 +1916,42 @@
data: [DONE]"#;

Ok(())
}

#[test_case(
"data: {\"error\":{\"message\":\"Internal server error\",\"type\":\"server_error\",\"code\":500}}\ndata: [DONE]",
"Internal server error";
"openai error format"
)]
#[test_case(
"data: {\"object\":\"error\",\"message\":\"CUDA out of memory\",\"code\":500}\ndata: [DONE]",
"CUDA out of memory";
"vllm error format"
)]
#[test_case(
"data: {\"error\":{\"message\":\"Rate limit exceeded\",\"type\":\"rate_limit_error\"}}",
"Rate limit exceeded";
"error as first chunk"
)]
#[tokio::test]
async fn test_mid_stream_server_error(response_lines: &str, expected_msg: &str) {
let lines: Vec<String> = response_lines.lines().map(|s| s.to_string()).collect();
let response_stream = tokio_stream::iter(lines.into_iter().map(Ok));
let mut messages = std::pin::pin!(response_to_streaming_message(response_stream));
let mut found_error = false;
while let Some(result) = messages.next().await {
if let Err(e) = result {
let err_str = e.to_string();
assert!(
err_str.contains(expected_msg),
"unexpected error text: {err_str}"
);
found_error = true;
break;
}
}
assert!(
found_error,
"expected an error but stream completed successfully"
);
}
}
crates/goose/src/providers/openai_compatible.rs (3 changes: 2 additions & 1 deletion)
@@ -248,7 +248,8 @@ pub fn stream_openai_compat(
pin!(message_stream);
while let Some(message) = message_stream.next().await {
let (message, usage) = message.map_err(|e|
ProviderError::RequestFailed(format!("Stream decode error: {}", e))
e.downcast::<ProviderError>()
.unwrap_or_else(|e| ProviderError::RequestFailed(format!("Stream decode error: {e}")))
)?;
log.write(&message, usage.as_ref().map(|f| f.usage).as_ref())?;
yield (message, usage);
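The openai_compatible.rs change above is what lets a mid-stream server error reach the caller as a typed ProviderError::ServerError instead of being flattened into a generic string: the stream yields anyhow::Error values, and the consumer now tries to downcast back to ProviderError before falling back to RequestFailed. A minimal, self-contained sketch of that downcast-then-fallback pattern, using a simplified stand-in for the real ProviderError enum (the actual type in providers::errors has more variants) and assuming anyhow as a dependency:

use std::fmt;

// Simplified stand-in for goose's ProviderError (sketch only; the real enum has more variants).
#[derive(Debug)]
enum ProviderError {
    ServerError(String),
    RequestFailed(String),
}

impl fmt::Display for ProviderError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ProviderError::ServerError(m) => write!(f, "server error: {m}"),
            ProviderError::RequestFailed(m) => write!(f, "request failed: {m}"),
        }
    }
}

impl std::error::Error for ProviderError {}

fn main() {
    // An anyhow::Error wrapping a typed ProviderError, as the streaming decoder now
    // produces when parse_streaming_chunk returns Err(...) inside the stream.
    let wrapped = anyhow::Error::new(ProviderError::ServerError("CUDA out of memory".into()));

    // The openai_compatible.rs pattern: recover the typed error when possible and only
    // fall back to a generic RequestFailed when the downcast fails.
    let recovered = wrapped
        .downcast::<ProviderError>()
        .unwrap_or_else(|e| ProviderError::RequestFailed(format!("Stream decode error: {e}")));

    assert!(matches!(recovered, ProviderError::ServerError(_)));
    println!("{recovered}");
}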