Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
305 changes: 305 additions & 0 deletions crates/goose/src/providers/formats/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use anyhow::{anyhow, Error};
use async_stream::try_stream;
use chrono;
use futures::Stream;
use regex::Regex;
use rmcp::model::{
object, AnnotateAble, CallToolRequestParams, Content, ErrorCode, ErrorData, RawContent,
ResourceContents, Role, Tool,
Expand All @@ -17,6 +18,62 @@ use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::borrow::Cow;
use std::ops::Deref;
use uuid::Uuid;

/// Parse XML-style tool calls from content (Qwen3-coder fallback format when many tools are provided)
/// Format: <function=name><parameter=key>value</parameter>...</function>
/// Returns a tuple of (prefix_text, tool_calls) where prefix_text is any text before the first function tag
fn parse_xml_tool_calls(content: &str) -> (Option<String>, Vec<MessageContent>) {
let mut tool_calls = Vec::new();

let function_re = Regex::new(r"<function=([^>]+)>([\s\S]*?)</function>").unwrap();
let param_re = Regex::new(r"<parameter=([^>]+)>([\s\S]*?)</parameter>").unwrap();

let prefix = content
.find("<function=")
.and_then(|idx| content.get(..idx))
.map(|s| s.trim())
.filter(|s| !s.is_empty())
.map(|s| s.to_string());

for func_cap in function_re.captures_iter(content) {
let function_name = func_cap[1].trim().to_string();
let function_body = &func_cap[2];

let mut arguments = serde_json::Map::new();
for param_cap in param_re.captures_iter(function_body) {
let param_name = param_cap[1].trim().to_string();
let param_value = param_cap[2].trim().to_string();
arguments.insert(param_name, serde_json::Value::String(param_value));
}

let id = Uuid::new_v4().to_string();

if is_valid_function_name(&function_name) {
tool_calls.push(MessageContent::tool_request(
id,
Ok(CallToolRequestParams {
meta: None,
task: None,
name: function_name.into(),
arguments: Some(object(serde_json::Value::Object(arguments))),
}),
));
} else {
let error = ErrorData {
code: ErrorCode::INVALID_REQUEST,
message: Cow::from(format!(
"The provided function name '{}' had invalid characters, it must match this regex [a-zA-Z0-9_-]+",
function_name
)),
data: None,
};
tool_calls.push(MessageContent::tool_request(id, Err(error)));
}
}

(prefix, tool_calls)
}

#[derive(Serialize, Deserialize, Debug, Default)]
struct DeltaToolCallFunction {
Expand Down Expand Up @@ -368,6 +425,28 @@ pub fn response_to_message(response: &Value) -> anyhow::Result<Message> {
}
}

// Fallback: If no JSON tool_calls found, check for XML-style tool calls in content
// This handles models like Qwen3-coder that output XML when given many tools
let has_tool_requests = content
.iter()
.any(|c| matches!(c, MessageContent::ToolRequest(_)));

if !has_tool_requests {
if let Some(text) = original.get("content").and_then(|c| c.as_str()) {
if text.contains("<function=") {
let (prefix, xml_tool_calls) = parse_xml_tool_calls(text);
if !xml_tool_calls.is_empty() {
// Replace text content with parsed tool calls
content.clear();
if let Some(prefix_text) = prefix {
content.push(MessageContent::text(prefix_text));
}
content.extend(xml_tool_calls);
}
}
}
}

Ok(Message::new(
Role::Assistant,
chrono::Utc::now().timestamp(),
Expand Down Expand Up @@ -471,6 +550,9 @@ where
use futures::StreamExt;

let mut accumulated_reasoning: Vec<Value> = Vec::new();
// Track accumulated text content for XML tool call detection
let mut accumulated_text: String = String::new();
let mut last_chunk_id: Option<String> = None;

'outer: while let Some(response) = stream.next().await {
if response.as_ref().is_ok_and(|s| s == "data: [DONE]") {
Expand All @@ -493,6 +575,11 @@ where
}
}

// Track chunk ID for message construction
if chunk.id.is_some() {
last_chunk_id = chunk.id.clone();
}

let mut usage = extract_usage_with_output_tokens(&chunk);

if chunk.choices.is_empty() {
Expand Down Expand Up @@ -620,6 +707,38 @@ where
)
} else if chunk.choices[0].delta.content.is_some() {
let text = chunk.choices[0].delta.content.as_ref().unwrap();

accumulated_text.push_str(text);

let is_final = chunk.choices[0].finish_reason.is_some();

if is_final && accumulated_text.contains("<function=") {
let (prefix, xml_tool_calls) = parse_xml_tool_calls(&accumulated_text);

if !xml_tool_calls.is_empty() {
let mut contents = Vec::new();
if let Some(prefix_text) = prefix {
contents.push(MessageContent::text(prefix_text));
}
contents.extend(xml_tool_calls);

let mut msg = Message::new(
Role::Assistant,
chrono::Utc::now().timestamp(),
contents,
);

if let Some(id) = last_chunk_id.clone() {
msg = msg.with_id(id);
}

yield (Some(msg), usage);
accumulated_text.clear();
continue;
}
}

// Normal text streaming - yield the text chunk
let mut msg = Message::new(
Role::Assistant,
chrono::Utc::now().timestamp(),
Expand Down Expand Up @@ -1661,4 +1780,190 @@ data: [DONE]

Ok(())
}

// Tests for XML tool call parsing (Qwen3-coder fallback format)

#[test]
fn test_parse_xml_tool_calls_single() {
    // One function block, three single-line parameters, no leading text.
    let xml = r#"<function=developer__text_editor>
<parameter=command>write</parameter>
<parameter=path>/tmp/test.txt</parameter>
<parameter=file_text>hello world</parameter>
</function>"#;

    let (leading_text, parsed) = parse_xml_tool_calls(xml);

    assert!(leading_text.is_none(), "Should have no prefix");
    assert_eq!(parsed.len(), 1, "Should have 1 tool call");

    let MessageContent::ToolRequest(request) = &parsed[0] else {
        panic!("Expected ToolRequest content");
    };

    let call = request.tool_call.as_ref().unwrap();
    assert_eq!(call.name, "developer__text_editor");

    // Each parameter must come through as the string that was written in the XML.
    let args = call.arguments.as_ref().unwrap();
    assert_eq!(args.get("command").unwrap(), "write");
    assert_eq!(args.get("path").unwrap(), "/tmp/test.txt");
    assert_eq!(args.get("file_text").unwrap(), "hello world");
}

#[test]
fn test_parse_xml_tool_calls_with_prefix() {
    // Free-form text ahead of the function block must be returned, trimmed, as the prefix.
    let xml = r#"I'll create the file for you.

<function=developer__text_editor>
<parameter=command>write</parameter>
<parameter=path>/tmp/test.txt</parameter>
</function>"#;

    let (leading_text, parsed) = parse_xml_tool_calls(xml);

    let expected_prefix = Some("I'll create the file for you.".to_string());
    assert_eq!(leading_text, expected_prefix, "Should have prefix text");
    assert_eq!(parsed.len(), 1, "Should have 1 tool call");
}

#[test]
fn test_parse_xml_tool_calls_multiple() {
    // Two back-to-back function blocks should yield two requests in document order.
    let xml = r#"<function=developer__shell>
<parameter=command>ls -la</parameter>
</function>
<function=developer__text_editor>
<parameter=command>view</parameter>
<parameter=path>/tmp/test.txt</parameter>
</function>"#;

    let (leading_text, parsed) = parse_xml_tool_calls(xml);

    assert!(leading_text.is_none());
    assert_eq!(parsed.len(), 2, "Should have 2 tool calls");

    let MessageContent::ToolRequest(first) = &parsed[0] else {
        panic!("Expected ToolRequest content");
    };
    assert_eq!(first.tool_call.as_ref().unwrap().name, "developer__shell");

    let MessageContent::ToolRequest(second) = &parsed[1] else {
        panic!("Expected ToolRequest content");
    };
    assert_eq!(
        second.tool_call.as_ref().unwrap().name,
        "developer__text_editor"
    );
}

#[test]
fn test_parse_xml_tool_calls_no_match() {
    // Plain prose with no function marker yields neither a prefix nor any tool calls.
    let (leading_text, parsed) =
        parse_xml_tool_calls("This is just regular text without any tool calls.");

    assert!(leading_text.is_none());
    assert!(parsed.is_empty(), "Should have no tool calls");
}

#[test]
fn test_parse_xml_tool_calls_qwen_format() {
    // The exact shape observed from Qwen3-coder via Ollama: values on their own lines
    // and a stray closing </tool_call> after the function block.
    let xml = r#"I'll create a file at /tmp/hello.txt with the content "hello".

<function=developer__text_editor>
<parameter=command>
write
</parameter>
<parameter=path>
/tmp/hello.txt
</parameter>
<parameter=file_text>
hello
</parameter>
</function>
</tool_call>"#;

    let (leading_text, parsed) = parse_xml_tool_calls(xml);

    assert!(leading_text.is_some(), "Should have prefix");
    assert_eq!(parsed.len(), 1, "Should have 1 tool call");

    let MessageContent::ToolRequest(request) = &parsed[0] else {
        panic!("Expected ToolRequest content");
    };

    let call = request.tool_call.as_ref().unwrap();
    assert_eq!(call.name, "developer__text_editor");

    // Surrounding newlines around each value must have been trimmed away.
    let args = call.arguments.as_ref().unwrap();
    assert_eq!(args.get("command").unwrap(), "write");
    assert_eq!(args.get("path").unwrap(), "/tmp/hello.txt");
    assert_eq!(args.get("file_text").unwrap(), "hello");
}

#[test]
fn test_response_to_message_xml_fallback() -> anyhow::Result<()> {
    // With no JSON tool_calls in the response, XML in the content should be parsed instead.
    let payload = json!({
        "choices": [{
            "message": {
                "role": "assistant",
                "content": "<function=developer__shell>\n<parameter=command>ls</parameter>\n</function>"
            }
        }]
    });

    let parsed = response_to_message(&payload)?;

    assert_eq!(parsed.content.len(), 1);
    let MessageContent::ToolRequest(request) = &parsed.content[0] else {
        panic!("Expected ToolRequest content from XML parsing");
    };
    assert_eq!(request.tool_call.as_ref().unwrap().name, "developer__shell");

    Ok(())
}

#[test]
fn test_response_to_message_prefers_json_over_xml() -> anyhow::Result<()> {
    // When JSON tool_calls are present they win; XML in the content must not be parsed.
    let payload = json!({
        "choices": [{
            "message": {
                "role": "assistant",
                "content": "<function=wrong_tool>\n<parameter=x>y</parameter>\n</function>",
                "tool_calls": [{
                    "id": "call_123",
                    "function": {
                        "name": "correct_tool",
                        "arguments": "{\"a\": \"b\"}"
                    }
                }]
            }
        }]
    });

    let parsed = response_to_message(&payload)?;

    // Only tool requests are inspected here; there must be exactly one, from the JSON field.
    let requests: Vec<_> = parsed
        .content
        .iter()
        .filter_map(|item| match item {
            MessageContent::ToolRequest(request) => Some(request),
            _ => None,
        })
        .collect();

    assert_eq!(requests.len(), 1);
    assert_eq!(requests[0].tool_call.as_ref().unwrap().name, "correct_tool");

    Ok(())
}
}
Loading