Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/py/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,8 @@ impl Message {
role,
content,
num_images,
tool_calls: None,
tool_call_id: None,
}
}
}
Expand Down
40 changes: 18 additions & 22 deletions src/server/claude_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1527,27 +1527,24 @@ pub async fn messages(
.map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
.unwrap_or(false);

let final_tool_calls = if strict_mode {
if !invalid.is_empty() {
if !invalid.is_empty() {
if strict_mode {
crate::log_warn!(
"[Seq {}] Strict mode enabled, dropping invalid calls",
seq_id
);
}
validated_calls
} else {
if !invalid.is_empty() {
} else {
crate::log_warn!(
"[Seq {}] Strict mode disabled, keeping invalid calls",
"[Seq {}] Strict mode disabled, but still dropping invalid calls to avoid malformed tool payloads",
seq_id
);
log_tool_calls("Invalid", seq_id, &invalid);
if let Some(ref l) = stream_logger {
l.log_tool_calls("Invalid", &invalid);
}
}
pending_tool_calls
};
log_tool_calls("Invalid", seq_id, &invalid);
if let Some(ref l) = stream_logger {
l.log_tool_calls("Invalid", &invalid);
}
}
let final_tool_calls = validated_calls;

if final_tool_calls.is_empty() {
(Vec::new(), false)
Expand Down Expand Up @@ -1852,17 +1849,16 @@ pub async fn messages(
.map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
.unwrap_or(false);

let valid_calls = if strict_mode {
if !invalid_calls.is_empty() {
if !invalid_calls.is_empty() {
if strict_mode {
crate::log_warn!("Strict mode enabled, dropping invalid calls");
} else {
crate::log_warn!(
"Strict mode disabled, but still dropping invalid calls to avoid malformed tool payloads"
);
}
validated_calls
} else {
if !invalid_calls.is_empty() {
crate::log_warn!("Strict mode disabled, keeping invalid calls");
}
parsed_calls
};
}
let valid_calls = validated_calls;

if !valid_calls.is_empty() {
log_tool_calls("Valid", output.seq_id, &valid_calls);
Expand Down
148 changes: 108 additions & 40 deletions src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -721,50 +721,45 @@ pub fn convert_chat_message(
let mut prompt = String::new();
let mut images = Vec::new();

// Handle tool call messages specially
if role == "tool" {
if let Some(tool_call_id) = &msg.tool_call_id {
if let Some(content) = &msg.content {
let mut tool_text = String::new();
match content {
MessageContentType::PureText(text) => {
tool_text.push_str(text);
}
MessageContentType::Single(item) => {
if let MessageContent::Text { text } = item {
tool_text.push_str(text);
}
}
MessageContentType::Multi(items) => {
for item in items {
if let MessageContent::Text { text } = item {
tool_text.push_str(text);
tool_text.push(' ');
}
}
}
}
let tool_text_trimmed = tool_text.trim();
if !tool_text_trimmed.is_empty() {
prompt = format!("[Tool Result for {}]: {}", tool_call_id, tool_text_trimmed);
}
// Keep assistant tool-call turns structured so chat templates can render proper
// function-calling transcripts (same as vLLM/OpenAI style history).
if role == "assistant" {
if let Some(tool_calls) = &msg.tool_calls {
let mut content = String::new();
if let Some(existing) = &msg.content {
content = extract_text_content(existing).trim().to_owned();
}
let template_calls = tool_calls
.iter()
.map(to_template_tool_call)
.collect::<Vec<_>>();
return Ok(Message {
role,
content,
num_images: 0,
tool_calls: Some(template_calls),
tool_call_id: None,
});
}
return Ok(Message::new(role, prompt.trim().to_owned(), 0));
}

// // Handle assistant messages with tool calls
// if msg.tool_calls.is_some() {
// if let Some(tool_calls) = &msg.tool_calls {
// for tc in tool_calls {
// prompt.push_str(&format!(
// "<tool_call>\n{{\"name\": \"{}\", \"arguments\": {}}}\n</tool_call>\n",
// tc.function.name, tc.function.arguments
// ));
// }
// }
// return Ok(Message::new(role, prompt.trim().to_owned(), 0));
// }
// Handle tool result messages specially
if role == "tool" {
let content = msg
.content
.as_ref()
.map(extract_text_content)
.unwrap_or_default()
.trim()
.to_owned();
return Ok(Message {
role,
content,
num_images: 0,
tool_calls: None,
tool_call_id: msg.tool_call_id.clone(),
});
}

// Normal message handling
if let Some(content) = &msg.content {
Expand Down Expand Up @@ -795,6 +790,43 @@ pub fn convert_chat_message(
Ok(Message::new(role, prompt.trim().to_owned(), images.len()))
}

fn extract_text_content(content: &MessageContentType) -> String {
match content {
MessageContentType::PureText(text) => text.clone(),
MessageContentType::Single(item) => match item {
MessageContent::Text { text } => text.clone(),
_ => String::new(),
},
MessageContentType::Multi(items) => items
.iter()
.filter_map(|item| match item {
MessageContent::Text { text } => Some(text.as_str()),
_ => None,
})
.collect::<Vec<_>>()
.join(" "),
}
}

/// Converts a tool call into the JSON value handed to the chat template.
///
/// The raw `arguments` string is parsed as JSON. If it is absent, fails to
/// parse, or parses to anything other than a JSON object, an empty object is
/// used instead, so `function.arguments` in the result is always an object.
fn to_template_tool_call(call: &crate::tools::ToolCall) -> serde_json::Value {
    let parsed = match call.function.arguments.as_deref() {
        Some(raw) => serde_json::from_str::<serde_json::Value>(raw).ok(),
        None => None,
    };

    // Only JSON objects are accepted as arguments; everything else becomes `{}`.
    let arguments = match parsed {
        Some(value) if value.is_object() => value,
        _ => serde_json::json!({}),
    };

    serde_json::json!({
        "id": call.id.clone(),
        "type": call.tool_type.clone(),
        "function": {
            "name": call.function.name.clone(),
            "arguments": arguments
        }
    })
}

fn append_message_item(
item: &MessageContent,
prompt: &mut String,
Expand Down Expand Up @@ -1083,6 +1115,42 @@ mod tests {
assert_eq!(tool_result.tool_call_id, Some("call_123".to_string()));
}

/// An assistant turn that carries tool calls must come out of
/// `convert_chat_message` with the calls still structured: empty text content,
/// one call, and its arguments parsed into a JSON object.
#[test]
fn preserves_assistant_tool_calls_in_template_message() {
    let assistant_msg = ChatMessage {
        role: String::from("assistant"),
        content: None,
        tool_calls: Some(vec![crate::tools::new_tool_call(
            "call_1",
            "Read",
            r#"{"file_path":"ReadMe.md"}"#,
        )]),
        tool_call_id: None,
    };
    let mut processor = None;
    let mut images = Vec::new();

    let rendered = convert_chat_message(&assistant_msg, &mut processor, &mut images).unwrap();

    assert_eq!(rendered.role, "assistant");
    assert_eq!(rendered.content, "");
    assert!(rendered.tool_calls.is_some());

    let calls = rendered.tool_calls.unwrap();
    assert_eq!(calls.len(), 1);

    // Inspect the single call's `function` object.
    let function = &calls[0]["function"];
    assert_eq!(function["name"], "Read");
    assert!(function["arguments"].is_object());
    assert_eq!(function["arguments"]["file_path"], "ReadMe.md");
}

/// A tool-result message must keep its role, its raw text content, and the
/// `tool_call_id` it answers after going through `convert_chat_message`.
#[test]
fn preserves_tool_result_metadata_in_template_message() {
    let call_id = "call_1";
    let payload = "{\"ok\":true}";

    let tool_msg = ChatMessage::tool_result(call_id, payload);
    let mut processor = None;
    let mut images = Vec::new();
    let rendered = convert_chat_message(&tool_msg, &mut processor, &mut images).unwrap();

    assert_eq!(rendered.role, "tool");
    assert_eq!(rendered.content, payload);
    assert_eq!(rendered.tool_call_id, Some(call_id.to_string()));
}

#[test]
fn test_tokenize_request_text_parsing() {
let json = r#"{"prompt": "Hello, world!"}"#;
Expand Down
70 changes: 54 additions & 16 deletions src/server/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,13 +354,14 @@ impl StreamToolParser {
}
}

// Don't detect tool calls inside reasoning or code blocks
if self.in_reasoning() || self.in_code_block {
return StreamResult::Content(token_text.to_string());
}

match self.state.clone() {
ParserState::Normal => {
// Don't detect tool-call starts inside reasoning or code blocks.
// Once buffering starts we must continue buffering even if arguments
// contain code fences/backticks.
if self.in_reasoning() || self.in_code_block {
return StreamResult::Content(token_text.to_string());
}
// Check for start trigger
if self.is_start_token(token_id, token_text) {
self.state = ParserState::Buffering;
Expand Down Expand Up @@ -437,21 +438,25 @@ impl StreamToolParser {
}

/// Check if token/text matches start trigger
fn is_start_token(&self, id: u32, text: &str) -> bool {
match self.accumulated_output[..self.accumulated_output.len() - text.len()]
.chars()
.last()
{
// Empty buffer or newline are valid "start of line" checks for tool calls
None | Some('\n') => {}
_ => return false,
};
fn is_start_token(&self, id: u32, _text: &str) -> bool {
// Token ID match (if available)
if self.config.has_start_tokens() {
return self.config.start_token_ids.contains(&id);
}
// Text match
text.contains(&self.config.start_token_str)

// Text-only mode: detect on the current line, allowing split tags while
// avoiding overly eager triggers like a lone "<".
let current_line = self.accumulated_output.rsplit('\n').next().unwrap_or("");
let candidate = current_line.trim_start_matches(|c| c == ' ' || c == '\t' || c == '\r');

if candidate.starts_with(&self.config.start_token_str) {
return true;
}

let min_prefix_len = Self::safe_partial_prefix_len(&self.config.start_token_str);
!candidate.is_empty()
&& candidate.len() >= min_prefix_len
&& self.config.start_token_str.starts_with(candidate)
}

/// Check if token/text matches end trigger
Expand Down Expand Up @@ -604,6 +609,15 @@ impl StreamToolParser {
output
}

/// Returns the minimum length (in bytes) a partial start tag must reach.
///
/// * If the tag contains an underscore, the result is the byte index of the
///   first `_`, but never less than 2 (e.g. `"<tool_call>"` gives 5, i.e.
///   `"<tool"`).
/// * Otherwise the result is the byte index of the first `>` (or 6 when there
///   is none), clamped to the range `2..=6`.
fn safe_partial_prefix_len(start_tag: &str) -> usize {
    match start_tag.find('_') {
        Some(underscore) => underscore.max(2),
        // Tags without an underscore fall back to the position of `>`.
        None => start_tag.find('>').unwrap_or(6).clamp(2, 6),
    }
}

// --- Chunk creation helpers (for use by server.rs) ---

/// Create a content chunk for streaming
Expand Down Expand Up @@ -795,4 +809,28 @@ mod tests {
_ => panic!("Expected Content without token ID match"),
}
}

/// Once the parser has started buffering a tool call, a code fence inside the
/// arguments must not drop it back to plain-content mode.
#[tokio::test]
async fn test_parser_keeps_buffering_when_args_include_code_fence() {
    let model = ModelType::Qwen3;
    let mut parser = StreamToolParser::new_with_config(
        &model,
        "qwen3".to_string(),
        ToolConfig::for_model_type(&model),
        vec![crate::tools::function_tool("test", "desc").build()],
        None,
    );

    // The start tag switches the parser into buffering.
    let on_start = parser.process_token(151657, "<tool_call>").await;
    if !matches!(on_start, StreamResult::Buffering) {
        panic!("Expected Buffering on start tag");
    }

    // A markdown fence arriving mid-arguments must still be buffered.
    let on_fence = parser.process_token(0, "\n```markdown\n").await;
    if !matches!(on_fence, StreamResult::Buffering) {
        panic!("Expected Buffering while inside tool call arguments");
    }
}
}
Loading