Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 124 additions & 8 deletions crates/goose/src/providers/formats/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,22 @@ struct StreamingChoice {
finish_reason: Option<String>,
}

// Error payload that some OpenAI-compatible backends embed inside an SSE
// `data:` chunk instead of a normal choices delta. Every field is optional
// because providers differ in which ones they populate.
#[derive(Serialize, Deserialize, Debug)]
struct StreamingError {
// Human-readable description of the failure.
message: Option<String>,
// Error category string (e.g. "server_error"); raw identifier `r#type`
// because `type` is a Rust keyword.
r#type: Option<String>,
// Error/status code; kept as `Value` since servers send either a JSON
// number (e.g. 500) or a string.
code: Option<Value>,
}

// One parsed SSE `data:` chunk from an OpenAI-style streaming response.
#[derive(Serialize, Deserialize, Debug)]
struct StreamingChunk {
// Delta choices; defaults to empty because error-only or usage-only
// chunks omit the field entirely.
#[serde(default)]
choices: Vec<StreamingChoice>,
// Unix timestamp of chunk creation, when the server provides it.
created: Option<i64>,
// Completion id (e.g. "chatcmpl-..."), when present.
id: Option<String>,
// Token-usage object; typically only present on the final chunk.
usage: Option<Value>,
// Model name echoed back by the server.
model: Option<String>,
// Server-reported error embedded in the stream; checked by
// check_streaming_error so failures abort the stream.
error: Option<StreamingError>,
}

pub fn format_messages(messages: &[Message], image_format: &ImageFormat) -> Vec<Value> {
Expand Down Expand Up @@ -532,6 +541,26 @@ fn strip_data_prefix(line: &str) -> Option<&str> {
line.strip_prefix("data: ").map(|s| s.trim())
}

fn check_streaming_error(chunk: &StreamingChunk) -> anyhow::Result<()> {
if let Some(ref err) = chunk.error {
let msg = err.message.as_deref().unwrap_or("Unknown error");
let code = err
.code
.as_ref()
.map(|c| c.to_string())
.unwrap_or_else(|| "unknown".to_string());
let err_type = err.r#type.as_deref().unwrap_or("server_error");
Err(anyhow!(
"Server error during streaming (code: {}, type: {}): {}",
code,
err_type,
msg
))
} else {
Ok(())
}
}

pub fn response_to_streaming_message<S>(
mut stream: S,
) -> impl Stream<Item = anyhow::Result<(Option<Message>, Option<ProviderUsage>)>> + 'static
Expand Down Expand Up @@ -559,6 +588,8 @@ where
.ok_or_else(|| anyhow!("unexpected stream format"))?)
.map_err(|e| anyhow!("Failed to parse streaming chunk: {}: {:?}", e, &line))?;

check_streaming_error(&chunk)?;

if !chunk.choices.is_empty() {
if let Some(details) = &chunk.choices[0].delta.reasoning_details {
accumulated_reasoning.extend(details.iter().cloned());
Expand Down Expand Up @@ -598,6 +629,8 @@ where
let tool_chunk: StreamingChunk = serde_json::from_str(line)
.map_err(|e| anyhow!("Failed to parse streaming chunk: {}: {:?}", e, &line))?;

check_streaming_error(&tool_chunk)?;

if let Some(chunk_usage) = extract_usage_with_output_tokens(&tool_chunk) {
usage = Some(chunk_usage);
}
Expand Down Expand Up @@ -1879,7 +1912,6 @@ data: [DONE]"#;

#[test]
fn test_response_to_message_with_reasoning_content() -> anyhow::Result<()> {
// Test capturing reasoning_content from DeepSeek reasoning models
let response = json!({
"choices": [{
"role": "assistant",
Expand All @@ -1898,14 +1930,12 @@ data: [DONE]"#;
let message = response_to_message(&response)?;
assert_eq!(message.content.len(), 2);

// First should be reasoning content
if let MessageContent::Reasoning(reasoning) = &message.content[0] {
assert_eq!(reasoning.text, "Let me think about this step by step...");
} else {
panic!("Expected Reasoning content");
}

// Second should be text content
if let MessageContent::Text(text) = &message.content[1] {
assert_eq!(text.text, "The answer is 9.11 is greater than 9.8");
} else {
Expand All @@ -1917,12 +1947,10 @@ data: [DONE]"#;

#[test]
fn test_format_messages_with_reasoning_content() -> anyhow::Result<()> {
// Test that reasoning_content is properly included in formatted messages
let mut message = Message::assistant()
.with_content(MessageContent::reasoning("Thinking through the problem..."))
.with_text("The result is 42");

// Add a tool call to test that reasoning_content works with tool calls
message = message.with_tool_request(
"tool1",
Ok(rmcp::model::CallToolRequestParams {
Expand All @@ -1938,20 +1966,108 @@ data: [DONE]"#;
assert_eq!(spec.len(), 1);
assert_eq!(spec[0]["role"], "assistant");

// Should have reasoning_content field
assert!(spec[0].get("reasoning_content").is_some());
assert_eq!(
spec[0]["reasoning_content"],
"Thinking through the problem..."
);

// Should have content
assert_eq!(spec[0]["content"], "The result is 42");

// Should have tool_calls
assert!(spec[0]["tool_calls"].is_array());
assert_eq!(spec[0]["tool_calls"][0]["function"]["name"], "test_tool");

Ok(())
}

#[tokio::test]
async fn test_streaming_error_chunk_returns_server_error() {
    // A stream consisting solely of a server-side error chunk.
    let response_lines = r#"
data: {"error":{"code":500,"message":"Invalid diff: now finding less tool calls!","type":"server_error"}}
"#;

    let input: Vec<String> = response_lines.lines().map(str::to_string).collect();
    let response_stream = tokio_stream::iter(input.into_iter().map(Ok));
    let messages = response_to_streaming_message(response_stream);
    pin!(messages);

    // The stream must yield an Err item carrying the server's message.
    let mut found_error = false;
    while let Some(result) = messages.next().await {
        let Err(e) = result else { continue };
        let err_str = e.to_string();
        assert!(
            err_str.contains("Server error during streaming"),
            "Expected server error message, got: {}",
            err_str
        );
        assert!(
            err_str.contains("Invalid diff: now finding less tool calls!"),
            "Expected original error message preserved, got: {}",
            err_str
        );
        found_error = true;
        break;
    }
    assert!(found_error, "Expected an error from streaming error chunk");
}

#[tokio::test]
async fn test_streaming_error_chunk_during_tool_calls() {
    // A tool-call delta followed by a server error mid-stream.
    let response_lines = r#"
data: {"id":"chatcmpl-123","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_1","type":"function","function":{"name":"developer__shell","arguments":""}}]},"finish_reason":null}],"usage":null}
data: {"error":{"code":500,"message":"Internal server error","type":"server_error"}}
"#;

    let input: Vec<String> = response_lines.lines().map(str::to_string).collect();
    let response_stream = tokio_stream::iter(input.into_iter().map(Ok));
    let messages = response_to_streaming_message(response_stream);
    pin!(messages);

    // Even while accumulating tool-call fragments, the error chunk must
    // surface as an Err rather than being swallowed.
    let mut found_error = false;
    while let Some(result) = messages.next().await {
        let Err(e) = result else { continue };
        assert!(
            e.to_string().contains("Server error during streaming"),
            "Expected server error, got: {}",
            e
        );
        found_error = true;
        break;
    }
    assert!(
        found_error,
        "Expected error when server sends error mid-tool-call"
    );
}

#[tokio::test]
async fn test_streaming_error_chunk_with_no_choices_no_crash() {
    // An error-only chunk with no "choices" key: parsing must not panic
    // (choices defaults to empty) and the error message must come through.
    let response_lines = r#"
data: {"error":{"message":"rate limit exceeded","type":"rate_limit_error","code":429}}
"#;

    let input: Vec<String> = response_lines.lines().map(str::to_string).collect();
    let response_stream = tokio_stream::iter(input.into_iter().map(Ok));
    let messages = response_to_streaming_message(response_stream);
    pin!(messages);

    let mut found_error = false;
    while let Some(result) = messages.next().await {
        let Err(e) = result else { continue };
        assert!(
            e.to_string().contains("rate limit exceeded"),
            "Expected rate limit error, got: {}",
            e
        );
        found_error = true;
        break;
    }
    assert!(
        found_error,
        "Expected error from error-only streaming chunk"
    );
}
}