From cdddc2495e7d38bfc2b135a58dbd869ca81c6905 Mon Sep 17 00:00:00 2001 From: Vladislav Nosivskoy Date: Wed, 3 Dec 2025 18:57:05 +0300 Subject: [PATCH 01/18] add tool choice support Signed-off-by: Vladislav Nosivskoy --- lib/llm/src/protocols/openai.rs | 4 +- .../src/protocols/openai/chat_completions.rs | 22 +- .../openai/chat_completions/aggregator.rs | 199 ++-- .../openai/chat_completions/delta.rs | 290 +++++- lib/llm/src/protocols/openai/common_ext.rs | 2 +- lib/llm/src/protocols/openai/completions.rs | 4 +- lib/llm/src/protocols/openai/partial_json.rs | 306 +++++++ lib/llm/src/protocols/openai/tools.rs | 295 ++++++ lib/llm/tests/test_common_ext.rs | 2 +- lib/llm/tests/tool_choice.rs | 848 ++++++++++++++++++ lib/llm/tests/tool_choice_finish_reasons.rs | 214 +++++ 11 files changed, 2119 insertions(+), 67 deletions(-) create mode 100644 lib/llm/src/protocols/openai/partial_json.rs create mode 100644 lib/llm/src/protocols/openai/tools.rs create mode 100644 lib/llm/tests/tool_choice.rs create mode 100644 lib/llm/tests/tool_choice_finish_reasons.rs diff --git a/lib/llm/src/protocols/openai.rs b/lib/llm/src/protocols/openai.rs index 574d7db31cd..9b9b0e8411e 100644 --- a/lib/llm/src/protocols/openai.rs +++ b/lib/llm/src/protocols/openai.rs @@ -16,7 +16,9 @@ pub mod completions; pub mod embeddings; pub mod models; pub mod nvext; +pub mod partial_json; pub mod responses; +pub mod tools; pub mod validate; use validate::{ @@ -131,7 +133,7 @@ impl SamplingOptionsProvid let guided_whitespace_pattern = self.get_guided_whitespace_pattern(); let guided_decoding = match common::GuidedDecodingOptions::from_optional( - guided_json.cloned(), + guided_json, guided_regex, guided_choice, guided_grammar, diff --git a/lib/llm/src/protocols/openai/chat_completions.rs b/lib/llm/src/protocols/openai/chat_completions.rs index bb712086ecb..cab55032cbe 100644 --- a/lib/llm/src/protocols/openai/chat_completions.rs +++ b/lib/llm/src/protocols/openai/chat_completions.rs @@ -12,7 +12,7 @@ use super::{ common_ext::{CommonExt, CommonExtProvider}, nvext::NvExt, nvext::NvExtProvider, - validate, + tools, validate, }; pub mod aggregator; @@ -159,8 +159,24 @@ impl CommonExtProvider for NvCreateChatCompletionRequest { } /// Guided Decoding Options - fn get_guided_json(&self) -> Option<&serde_json::Value> { - self.common.guided_json.as_ref() + fn get_guided_json(&self) -> Option { + if let Some(value) = self.common.guided_json.clone() { + return Some(value); + } + + let tool_choice = self.inner.tool_choice.as_ref()?; + let tools = self.inner.tools.as_deref()?; + + match tools::get_json_schema_from_tools(Some(tool_choice), Some(tools)) { + Ok(schema) => schema, + Err(err) => { + tracing::warn!( + error = %err, + "failed to derive guided_json from tool_choice" + ); + None + } + } } fn get_guided_regex(&self) -> Option { diff --git a/lib/llm/src/protocols/openai/chat_completions/aggregator.rs b/lib/llm/src/protocols/openai/chat_completions/aggregator.rs index 6e178bc25a3..3cd47ca622f 100644 --- a/lib/llm/src/protocols/openai/chat_completions/aggregator.rs +++ b/lib/llm/src/protocols/openai/chat_completions/aggregator.rs @@ -65,27 +65,6 @@ impl Default for DeltaAggregator { } } -fn convert_tool_chunk_to_message_tool_call( - chunk: &dynamo_async_openai::types::ChatCompletionMessageToolCallChunk, -) -> Option { - // Convert ChatCompletionMessageToolCallChunk to ChatCompletionMessageToolCall - if let (Some(id), Some(r#type), Some(function)) = (&chunk.id, &chunk.r#type, &chunk.function) { - if let (Some(name), Some(arguments)) = (&function.name, &function.arguments) { - Some(dynamo_async_openai::types::ChatCompletionMessageToolCall { - id: id.clone(), - r#type: r#type.clone(), - function: dynamo_async_openai::types::FunctionCall { - name: name.clone(), - arguments: arguments.clone(), - }, - }) - } else { - None - } - } else { - None - } -} impl DeltaAggregator { /// Creates a new, empty [`DeltaAggregator`] instance. @@ -175,26 +154,51 @@ impl DeltaAggregator { .push_str(reasoning_content); } - // Since one tool call is one chunk, we don't need to aggregate them - // We just need to convert the ChatCompletionMessageToolCallChunk to ChatCompletionMessageToolCall and append to the state_choice.tool_calls - if let Some(tool_calls) = &choice.delta.tool_calls - && !tool_calls.is_empty() + // Aggregate tool calls incrementally + // Each chunk may add a new tool call or append arguments to existing one + if let Some(tool_call_chunks) = &choice.delta.tool_calls + && !tool_call_chunks.is_empty() { - // Convert ChatCompletionMessageToolCallChunk to ChatCompletionMessageToolCall - let converted_tool_calls: Vec< - dynamo_async_openai::types::ChatCompletionMessageToolCall, - > = tool_calls - .iter() - .filter_map(convert_tool_chunk_to_message_tool_call) - .collect(); - - // Initialize and push the converted tool calls to state_choice.tool_calls - // Only set tool_calls to Some if there are actual tool calls - if !converted_tool_calls.is_empty() { - if let Some(existing_tool_calls) = &mut state_choice.tool_calls { - existing_tool_calls.extend(converted_tool_calls); - } else { - state_choice.tool_calls = Some(converted_tool_calls); + // Initialize tool_calls vec if needed + let existing_tool_calls = state_choice + .tool_calls + .get_or_insert_with(Vec::new); + + // Process each chunk + for chunk in tool_call_chunks { + let chunk_index = chunk.index as usize; + + // Find or create tool call at this index + if chunk_index >= existing_tool_calls.len() { + // Extend the vec to accommodate this index + existing_tool_calls.resize_with(chunk_index + 1, || { + dynamo_async_openai::types::ChatCompletionMessageToolCall { + id: String::new(), + r#type: dynamo_async_openai::types::ChatCompletionToolType::Function, + function: dynamo_async_openai::types::FunctionCall { + name: String::new(), + arguments: String::new(), + }, + } + }); + } + + let tool_call = &mut existing_tool_calls[chunk_index]; + + // Update fields if present in chunk + if let Some(id) = &chunk.id { + tool_call.id = id.clone(); + } + if let Some(r#type) = &chunk.r#type { + tool_call.r#type = r#type.clone(); + } + if let Some(function) = &chunk.function { + if let Some(name) = &function.name { + tool_call.function.name = name.clone(); + } + if let Some(arguments) = &function.arguments { + tool_call.function.arguments.push_str(arguments); + } } } } @@ -270,10 +274,17 @@ impl From for dynamo_async_openai::types::ChatChoice { /// The `function_call` field is deprecated. fn from(delta: DeltaChoice) -> Self { // If tool calls are present and non-empty, finish reason should be ToolCalls + // Unless it's a critical finish reason (Length, ContentFilter, Stop) that should be preserved let finish_reason = if delta .tool_calls .as_ref() .is_some_and(|calls| !calls.is_empty()) + && !matches!( + delta.finish_reason, + Some(dynamo_async_openai::types::FinishReason::Stop) + | Some(dynamo_async_openai::types::FinishReason::Length) + | Some(dynamo_async_openai::types::FinishReason::ContentFilter) + ) { Some(dynamo_async_openai::types::FinishReason::ToolCalls) } else { @@ -691,8 +702,8 @@ mod tests { } #[tokio::test] - async fn test_tool_calling_finish_reason_override_from_stop() { - // Test that when tool calls are present but finish reason is Stop, it gets overridden to ToolCalls + async fn test_tool_calling_finish_reason_respects_explicit_stop() { + // Test that when tool calls are present and finish reason is Stop, it remains Stop let tool_call_json = r#"{"name": "get_weather", "arguments": {"location": "New York", "unit": "celsius"}}"#; @@ -726,23 +737,22 @@ mod tests { let tool_calls = choice.message.tool_calls.as_ref().unwrap(); assert_eq!(tool_calls.len(), 1); - // Most importantly, verify that finish reason was overridden to ToolCalls despite original being Stop + // Finish reason should remain Stop because it was explicitly provided that way assert_eq!( choice.finish_reason, - Some(dynamo_async_openai::types::FinishReason::ToolCalls) + Some(dynamo_async_openai::types::FinishReason::Stop) ); } #[tokio::test] - async fn test_tool_calling_finish_reason_override_from_length() { - // Test that when tool calls are present but finish reason is Length, it gets overridden to ToolCalls + async fn test_tool_calling_preserves_length_when_present() { let tool_call_json = r#"{"name": "search", "arguments": {"query": "rust programming"}}"#; let annotated_delta = create_test_delta( 0, "Let me search for that.", Some(dynamo_async_openai::types::Role::Assistant), - Some(dynamo_async_openai::types::FinishReason::Length), // Original finish reason is Length + Some(dynamo_async_openai::types::FinishReason::Length), None, Some(tool_call_json), ); @@ -763,15 +773,13 @@ mod tests { assert_eq!(response.choices.len(), 1); let choice = &response.choices[0]; - // Verify tool calls are present assert!(choice.message.tool_calls.is_some()); let tool_calls = choice.message.tool_calls.as_ref().unwrap(); assert_eq!(tool_calls.len(), 1); - // Verify that finish reason was overridden to ToolCalls despite original being Length assert_eq!( choice.finish_reason, - Some(dynamo_async_openai::types::FinishReason::ToolCalls) + Some(dynamo_async_openai::types::FinishReason::Length) ); } @@ -964,8 +972,8 @@ mod tests { } #[tokio::test] - async fn test_tool_calling_finish_reason_override_from_stop_alternative() { - // Test that when tool calls are present but finish reason is Stop, it gets overridden to ToolCalls + async fn test_tool_calling_finish_reason_stop_remains_when_set() { + // When finish_reason is explicitly Stop, we preserve it even if tool_calls are present let tool_call_json = r#"{"name": "get_weather", "arguments": {"location": "New York", "unit": "celsius"}}"#; @@ -991,10 +999,10 @@ mod tests { assert_eq!(response.choices.len(), 1); let choice = &response.choices[0]; - // The finish_reason should be ToolCalls, not Stop, because tool calls are present + // The finish_reason should remain Stop because it was explicitly provided that way assert_eq!( choice.finish_reason, - Some(dynamo_async_openai::types::FinishReason::ToolCalls) + Some(dynamo_async_openai::types::FinishReason::Stop) ); // Verify tool calls are present @@ -1003,4 +1011,87 @@ mod tests { assert_eq!(tool_calls.len(), 1); assert_eq!(tool_calls[0].function.name, "get_weather"); } + + #[tokio::test] + async fn test_tool_calling_preserves_length_finish_reason() { + // Test that Length finish reason is preserved even with tool_calls present + let tool_call_json = + r#"{"name": "get_weather", "arguments": {"location": "Paris"}}"#; + + let annotated_delta = create_test_delta( + 0, + "", + Some(dynamo_async_openai::types::Role::Assistant), + Some(dynamo_async_openai::types::FinishReason::Length), // Length finish reason + None, + Some(tool_call_json), + ); + + let data = annotated_delta.data.unwrap(); + let annotated_delta = Annotated { + data: Some(data), + id: Some("test_id".to_string()), + event: None, + comment: None, + }; + let stream = Box::pin(stream::iter(vec![annotated_delta])); + + let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await; + + assert!(result.is_ok()); + let response = result.unwrap(); + assert_eq!(response.choices.len(), 1); + let choice = &response.choices[0]; + + // Critical: Length finish reason MUST be preserved, not replaced with ToolCalls + assert_eq!( + choice.finish_reason, + Some(dynamo_async_openai::types::FinishReason::Length), + "Length finish reason must be preserved even when tool_calls are present" + ); + + // Verify tool calls are still present + assert!(choice.message.tool_calls.is_some()); + } + + #[tokio::test] + async fn test_tool_calling_preserves_content_filter_finish_reason() { + // Test that ContentFilter finish reason is preserved even with tool_calls present + let tool_call_json = + r#"{"name": "harmful_action", "arguments": {}}"#; + + let annotated_delta = create_test_delta( + 0, + "", + Some(dynamo_async_openai::types::Role::Assistant), + Some(dynamo_async_openai::types::FinishReason::ContentFilter), // ContentFilter + None, + Some(tool_call_json), + ); + + let data = annotated_delta.data.unwrap(); + let annotated_delta = Annotated { + data: Some(data), + id: Some("test_id".to_string()), + event: None, + comment: None, + }; + let stream = Box::pin(stream::iter(vec![annotated_delta])); + + let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await; + + assert!(result.is_ok()); + let response = result.unwrap(); + let choice = &response.choices[0]; + + // Critical: ContentFilter finish reason MUST be preserved + assert_eq!( + choice.finish_reason, + Some(dynamo_async_openai::types::FinishReason::ContentFilter), + "ContentFilter finish reason must be preserved even when tool_calls are present" + ); + + // Verify tool calls are still present + assert!(choice.message.tool_calls.is_some()); + } } diff --git a/lib/llm/src/protocols/openai/chat_completions/delta.rs b/lib/llm/src/protocols/openai/chat_completions/delta.rs index 186bb7f0950..fef1ff17986 100644 --- a/lib/llm/src/protocols/openai/chat_completions/delta.rs +++ b/lib/llm/src/protocols/openai/chat_completions/delta.rs @@ -4,9 +4,15 @@ use super::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}; use crate::{ local_model::runtime_config::ModelRuntimeConfig, - protocols::common::{self}, + protocols::{common::{self}, openai::partial_json::{loads, AllowPartial}}, types::TokenIdType, }; +use dynamo_async_openai::types::{ + ChatCompletionMessageToolCallChunk, ChatCompletionToolChoiceOption, ChatCompletionToolType, + FunctionCallStream, +}; +use serde_json::{self}; +use tracing::error; /// Provides a method for generating a [`DeltaGenerator`] from a chat completion request. impl NvCreateChatCompletionRequest { @@ -41,6 +47,14 @@ impl NvCreateChatCompletionRequest { /// # Returns /// * [`DeltaGenerator`] configured with model name and response options. pub fn response_generator(&self, request_id: String) -> DeltaGenerator { + let tool_choice_context = match self.inner.tool_choice.as_ref() { + Some(ChatCompletionToolChoiceOption::Named(named)) => { + ToolChoiceContext::Named(named.function.name.clone()) + } + Some(ChatCompletionToolChoiceOption::Required) => ToolChoiceContext::Required, + _ => ToolChoiceContext::None, + }; + let options = DeltaGeneratorOptions { enable_usage: self .inner @@ -51,6 +65,7 @@ impl NvCreateChatCompletionRequest { enable_logprobs: self.inner.logprobs.unwrap_or(false) || self.inner.top_logprobs.unwrap_or(0) > 0, runtime_config: ModelRuntimeConfig::default(), + tool_choice: tool_choice_context, }; DeltaGenerator::new(self.inner.model.clone(), options, request_id) @@ -66,6 +81,15 @@ pub struct DeltaGeneratorOptions { pub enable_logprobs: bool, pub runtime_config: ModelRuntimeConfig, + pub tool_choice: ToolChoiceContext, +} + +#[derive(Debug, Clone, Default)] +pub enum ToolChoiceContext { + #[default] + None, + Named(String), + Required, } /// Generates incremental chat completion responses in a streaming fashion. @@ -88,6 +112,19 @@ pub struct DeltaGenerator { msg_counter: u64, /// Configuration options for response generation. options: DeltaGeneratorOptions, + /// Buffer for accumulating tool call JSON during streaming + tool_call_buffer: String, + /// Length of buffer that was already emitted (for delta calculation) + previous_buffer_len: usize, + /// Previous parsed state (for detecting new tool calls in Required mode) + previous_tool_calls: Vec, +} + +#[derive(Debug, Clone)] +struct ToolCallState { + index: usize, + name: String, + arguments: String, } impl DeltaGenerator { @@ -130,6 +167,9 @@ impl DeltaGenerator { usage, msg_counter: 0, options, + tool_call_buffer: String::new(), + previous_buffer_len: 0, + previous_tool_calls: Vec::new(), } } @@ -223,11 +263,22 @@ impl DeltaGenerator { text: Option, finish_reason: Option, logprobs: Option, + ) -> NvCreateChatCompletionStreamResponse { + self.build_choice(index, text, finish_reason, logprobs, None) + } + + /// Internal method to build a streaming chat completion response with optional tool_calls. + fn build_choice( + &mut self, + index: u32, + text: Option, + finish_reason: Option, + logprobs: Option, + tool_calls: Option>, ) -> NvCreateChatCompletionStreamResponse { let delta = dynamo_async_openai::types::ChatCompletionStreamResponseDelta { content: text, - function_call: None, - tool_calls: None, + tool_calls, role: if self.msg_counter == 0 { Some(dynamo_async_openai::types::Role::Assistant) } else { @@ -288,6 +339,200 @@ impl DeltaGenerator { pub fn is_usage_enabled(&self) -> bool { self.options.enable_usage } + + fn process_streaming_tool_calls( + &mut self, + delta_text: &str, + _is_final: bool, + ) -> anyhow::Result> { + // Accumulate the delta into buffer + self.tool_call_buffer.push_str(delta_text); + + // Parse the current buffer state using partial JSON parser + let parsed = match loads(&self.tool_call_buffer, AllowPartial::all()) { + Ok(value) => value, + Err(_e) => { + // If we can't parse yet, just wait for more data + return Ok(vec![]); + } + }; + + // Extract current tool calls from parsed JSON + let current_calls = match &self.options.tool_choice { + ToolChoiceContext::Named(name) => { + // For named tool choice, the output is raw JSON parameters + // Use the buffer directly, not serialized + if parsed.as_object().is_some() { + vec![ToolCallState { + index: 0, + name: name.clone(), + arguments: self.tool_call_buffer.clone(), + }] + } else { + vec![] + } + } + ToolChoiceContext::Required => { + // For required, parse the array of tool calls + if let Some(array) = parsed.as_array() { + array + .iter() + .enumerate() + .filter_map(|(idx, entry)| { + let name = entry.get("name")?.as_str()?.to_string(); + let parameters = entry.get("parameters")?; + let args = serde_json::to_string(parameters).unwrap_or_default(); + Some(ToolCallState { + index: idx, + name, + arguments: args, + }) + }) + .collect() + } else { + vec![] + } + } + ToolChoiceContext::None => vec![], + }; + + // Generate deltas by comparing with previous state + let mut chunks = Vec::new(); + + for current in ¤t_calls { + // Check if this is a new tool or existing one + let previous = self.previous_tool_calls.iter().find(|p| p.index == current.index); + + match previous { + None => { + // New tool - emit first chunk with id, type, name + chunks.push(ChatCompletionMessageToolCallChunk { + index: current.index as u32, + id: Some(format!("call_{}", current.index + 1)), + r#type: Some(ChatCompletionToolType::Function), + function: Some(FunctionCallStream { + name: Some(current.name.clone()), + arguments: Some(String::new()), + }), + }); + + // For Named mode, emit delta from buffer + // For Required mode, emit serialized parameters + if matches!(self.options.tool_choice, ToolChoiceContext::Named(_)) { + // Use raw buffer delta for Named mode + if self.tool_call_buffer.len() > self.previous_buffer_len { + let delta = &self.tool_call_buffer[self.previous_buffer_len..]; + if !delta.is_empty() { + chunks.push(ChatCompletionMessageToolCallChunk { + index: current.index as u32, + id: None, + r#type: None, + function: Some(FunctionCallStream { + name: None, + arguments: Some(delta.to_string()), + }), + }); + } + } + } else { + // For Required mode, emit full arguments (serialized parameters) + if !current.arguments.is_empty() { + chunks.push(ChatCompletionMessageToolCallChunk { + index: current.index as u32, + id: None, + r#type: None, + function: Some(FunctionCallStream { + name: None, + arguments: Some(current.arguments.clone()), + }), + }); + } + } + } + Some(prev) => { + // Existing tool - emit delta of arguments + if matches!(self.options.tool_choice, ToolChoiceContext::Named(_)) { + // For Named mode, use raw buffer delta + if self.tool_call_buffer.len() > self.previous_buffer_len { + let delta = &self.tool_call_buffer[self.previous_buffer_len..]; + if !delta.is_empty() { + chunks.push(ChatCompletionMessageToolCallChunk { + index: current.index as u32, + id: None, + r#type: None, + function: Some(FunctionCallStream { + name: None, + arguments: Some(delta.to_string()), + }), + }); + } + } + } else { + // For Required mode, compute delta from serialized arguments + if current.arguments.len() > prev.arguments.len() { + let delta = ¤t.arguments[prev.arguments.len()..]; + if !delta.is_empty() { + chunks.push(ChatCompletionMessageToolCallChunk { + index: current.index as u32, + id: None, + r#type: None, + function: Some(FunctionCallStream { + name: None, + arguments: Some(delta.to_string()), + }), + }); + } + } + } + } + } + } + + // Update previous state + self.previous_tool_calls = current_calls; + self.previous_buffer_len = self.tool_call_buffer.len(); + + Ok(chunks) + } + + fn determine_streaming_finish_reason( + &self, + backend_finish: Option, + ) -> Option { + backend_finish.as_ref()?; + + // For critical/error finish reasons, preserve them regardless of tool_choice mode + match backend_finish { + Some(common::FinishReason::Length) => { + return Some(dynamo_async_openai::types::FinishReason::Length); + } + Some(common::FinishReason::ContentFilter) => { + return Some(dynamo_async_openai::types::FinishReason::ContentFilter); + } + _ => {} + } + + // For normal finish reasons (Stop/EoS/Cancelled), apply tool_choice semantics + match &self.options.tool_choice { + ToolChoiceContext::None => match backend_finish { + Some(common::FinishReason::EoS) | Some(common::FinishReason::Stop) => { + Some(dynamo_async_openai::types::FinishReason::Stop) + } + Some(common::FinishReason::Cancelled) => { + Some(dynamo_async_openai::types::FinishReason::Stop) + } + _ => None, + }, + ToolChoiceContext::Named(_) => { + // Named tool choice finishes with "stop" for normal completion + Some(dynamo_async_openai::types::FinishReason::Stop) + } + ToolChoiceContext::Required => { + // Required tool choice finishes with "tool_calls" for normal completion + Some(dynamo_async_openai::types::FinishReason::ToolCalls) + } + } + } } /// Implements the [`crate::protocols::openai::DeltaGeneratorExt`] trait for [`DeltaGenerator`], allowing @@ -338,7 +583,8 @@ impl crate::protocols::openai::DeltaGeneratorExt Some(dynamo_async_openai::types::FinishReason::Stop), Some(common::FinishReason::Stop) => { Some(dynamo_async_openai::types::FinishReason::Stop) @@ -360,7 +606,41 @@ impl crate::protocols::openai::DeltaGeneratorExt { + tool_call_chunks = Some(chunks); + delta_text = None; // Don't emit raw text when streaming tools + } + Ok(_) => { + // No chunks yet, suppress text output + delta_text = None; + } + Err(err) => { + error!( + error = %err, + "failed to parse streaming tool_choice output" + ); + } + } + } + + // Override finish reason for tool_choice modes + if finish_reason.is_some() { + finish_reason = self.determine_streaming_finish_reason(backend_finish_reason); + } + } + + let mut stream_response = + self.build_choice(index, delta_text, finish_reason, logprobs, tool_call_chunks); // Extract worker_id from disaggregated_params and inject into nvext if present if let Some(worker_id_json) = delta diff --git a/lib/llm/src/protocols/openai/common_ext.rs b/lib/llm/src/protocols/openai/common_ext.rs index a77f765ae6b..51d9ea0a997 100644 --- a/lib/llm/src/protocols/openai/common_ext.rs +++ b/lib/llm/src/protocols/openai/common_ext.rs @@ -94,7 +94,7 @@ pub trait CommonExtProvider { fn common_ext(&self) -> Option<&CommonExt>; /// Guided Decoding Options - fn get_guided_json(&self) -> Option<&serde_json::Value>; + fn get_guided_json(&self) -> Option; fn get_guided_regex(&self) -> Option; fn get_guided_grammar(&self) -> Option; fn get_guided_choice(&self) -> Option>; diff --git a/lib/llm/src/protocols/openai/completions.rs b/lib/llm/src/protocols/openai/completions.rs index b62f801c407..adff21a9ce6 100644 --- a/lib/llm/src/protocols/openai/completions.rs +++ b/lib/llm/src/protocols/openai/completions.rs @@ -183,8 +183,8 @@ impl CommonExtProvider for NvCreateCompletionRequest { } /// Guided Decoding Options - fn get_guided_json(&self) -> Option<&serde_json::Value> { - self.common.guided_json.as_ref() + fn get_guided_json(&self) -> Option { + self.common.guided_json.clone() } fn get_guided_regex(&self) -> Option { diff --git a/lib/llm/src/protocols/openai/partial_json.rs b/lib/llm/src/protocols/openai/partial_json.rs new file mode 100644 index 00000000000..e925655bf57 --- /dev/null +++ b/lib/llm/src/protocols/openai/partial_json.rs @@ -0,0 +1,306 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Partial JSON Parser for streaming tool calls. +//! +//! This implementation is heavily inspired by the `partial-json-parser` library: +//! https://github.com/promplate/partial-json-parser +//! +//! The original Python library is licensed under MIT License. +//! We've adapted the core logic to Rust for use in Dynamo's streaming tool calls functionality. + +use std::collections::VecDeque; + +/// Options for what types of partial JSON are allowed +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct AllowPartial { + pub strings: bool, + pub objects: bool, + pub arrays: bool, +} + +impl Default for AllowPartial { + fn default() -> Self { + Self { + strings: true, + objects: true, + arrays: true, + } + } +} + +impl AllowPartial { + pub fn all() -> Self { + Self::default() + } + + pub fn none() -> Self { + Self { + strings: false, + objects: false, + arrays: false, + } + } +} + +/// Represents a token found during JSON scanning +#[derive(Debug, Clone, PartialEq, Eq)] +struct Token { + index: usize, + char: char, +} + +/// Scans the JSON string for structural characters +fn scan_tokens(json_string: &str) -> Vec { + json_string + .char_indices() + .filter_map(|(i, c)| { + if matches!(c, '"' | '[' | ']' | '{' | '}') { + Some(Token { index: i, char: c }) + } else { + None + } + }) + .collect() +} + +/// Checks if a quote at the given position is escaped +fn is_escaped(json_string: &str, index: usize) -> bool { + let text_before = &json_string[..index]; + let count = index - text_before.trim_end_matches('\\').len(); + count % 2 == 1 +} + +/// Joins closing tokens for unclosed containers +fn join_closing_tokens(stack: &VecDeque) -> String { + stack + .iter() + .rev() + .map(|token| if token.char == '{' { '}' } else { ']' }) + .collect() +} + +/// Completes a partial JSON string by adding necessary closing tokens +/// +/// Returns a tuple of (head, tail) where head is the potentially truncated +/// input and tail is the completion string +pub fn fix_json(json_string: &str, allow: AllowPartial) -> (String, String) { + let tokens = scan_tokens(json_string); + + // Empty or starts with quote - use simple fix + if tokens.is_empty() || tokens[0].char == '"' { + return simple_fix(json_string, allow); + } + + let mut stack: VecDeque = VecDeque::new(); + let mut in_string = false; + let mut last_string_start = None; + let mut last_string_end = None; + + for token in &tokens { + if token.char == '"' { + if !in_string { + in_string = true; + last_string_start = Some(token.index); + } else if !is_escaped(json_string, token.index) { + in_string = false; + last_string_end = Some(token.index); + } + } else if !in_string { + match token.char { + '}' => { + if let Some(open) = stack.pop_back() { + assert_eq!(open.char, '{', "Mismatched braces"); + } + } + ']' => { + if let Some(open) = stack.pop_back() { + assert_eq!(open.char, '[', "Mismatched brackets"); + } + } + _ => { + stack.push_back(token.clone()); + } + } + } + } + + // If stack is empty, JSON is complete + if stack.is_empty() { + return (json_string.to_string(), String::new()); + } + + // Remove trailing comma if present + let mut head = json_string.trim_end(); + if head.ends_with(',') { + head = head[..head.len() - 1].trim_end(); + } + + // Handle unclosed strings + if !allow.strings && in_string { + if let Some(last_container) = stack.back() + && last_container.char == '{' { + // Truncate before the unclosed string key + return ( + head[..=last_container.index].to_string(), + join_closing_tokens(&stack), + ); + } + + // Find last comma before the unclosed string + if let Some(string_start) = last_string_start { + let last_container_pos = stack.back().map(|t| t.index).unwrap_or(0); + let search_start = last_container_pos.max(last_string_end.unwrap_or(0)) + 1; + + if let Some(comma_pos) = head[search_start..string_start].rfind(',') { + let absolute_comma = search_start + comma_pos; + return ( + head[..absolute_comma].to_string(), + join_closing_tokens(&stack), + ); + } + } + } + + // Simple case: just close all open containers + if in_string && allow.strings + && let Some(string_start) = last_string_start { + // Fix the partial string + let partial_str = &head[string_start..]; + let (fixed_head, fixed_tail) = simple_fix(partial_str, allow); + return ( + format!("{}{}", &head[..string_start], fixed_head), + format!("{}{}", fixed_tail, join_closing_tokens(&stack)), + ); + } + + (head.to_string(), join_closing_tokens(&stack)) +} + +/// Simple fix for basic cases (strings, atoms) +fn simple_fix(json_string: &str, allow: AllowPartial) -> (String, String) { + let trimmed = json_string.trim_end(); + + // Handle unclosed strings + if trimmed.starts_with('"') + && allow.strings { + // Count how many unescaped quotes we have + let mut escaped = false; + let mut quote_count = 0; + for ch in trimmed.chars() { + if ch == '\\' && !escaped { + escaped = true; + } else { + if ch == '"' && !escaped { + quote_count += 1; + } + escaped = false; + } + } + + if quote_count % 2 == 1 { + // Unclosed string + return (trimmed.to_string(), "\"".to_string()); + } + } + + // Already complete or can't fix + (trimmed.to_string(), String::new()) +} + +/// Ensures the JSON string is complete by adding necessary tokens +pub fn ensure_json(json_string: &str, allow: AllowPartial) -> String { + let (head, tail) = fix_json(json_string, allow); + format!("{}{}", head, tail) +} + +/// Parses partial JSON string into a serde_json::Value +/// +/// This is the main function inspired by partial-json-parser's `loads()`. +/// It completes the partial JSON and then parses it. +pub fn loads(json_string: &str, allow: AllowPartial) -> Result { + let completed = ensure_json(json_string, allow); + serde_json::from_str(&completed) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_complete_json() { + let result = ensure_json(r#"{"key":"value"}"#, AllowPartial::all()); + assert_eq!(result, r#"{"key":"value"}"#); + } + + #[test] + fn test_unclosed_object() { + let result = ensure_json(r#"{"key":"value""#, AllowPartial::all()); + assert_eq!(result, r#"{"key":"value"}"#); + } + + #[test] + fn test_unclosed_string() { + let result = ensure_json(r#"{"key":"val"#, AllowPartial::all()); + assert_eq!(result, r#"{"key":"val"}"#); + } + + #[test] + fn test_nested_objects() { + let result = ensure_json(r#"{"outer":{"inner":"val"#, AllowPartial::all()); + assert_eq!(result, r#"{"outer":{"inner":"val"}}"#); + } + + #[test] + fn test_array() { + let result = ensure_json(r#"[{"name":"test","args":{"val":"a""#, AllowPartial::all()); + assert_eq!(result, r#"[{"name":"test","args":{"val":"a"}}]"#); + } + + #[test] + fn test_parallel_tool_calls() { + let result = ensure_json( + r#"[{"name":"search","parameters":{"query":"rust"}},{"name":"summ"#, + AllowPartial::all(), + ); + // Should complete both the string and close all containers + assert!(result.contains("search")); + assert!(result.ends_with("}]")); + } + + #[test] + fn test_loads_incremental() { + // Test 1: Unclosed string value + let result1 = loads(r#"{"location":""#, AllowPartial::all()).unwrap(); + assert_eq!(result1["location"], ""); + + // Test 2: Complete first field, starting second + let result2 = loads(r#"{"location":"Paris","#, AllowPartial::all()).unwrap(); + assert_eq!(result2["location"], "Paris"); + + // Test 3: Complete object + let result3 = loads(r#"{"location":"Paris","unit":"celsius"}"#, AllowPartial::all()).unwrap(); + assert_eq!(result3["location"], "Paris"); + assert_eq!(result3["unit"], "celsius"); + } + + #[test] + fn test_loads_array_incremental() { + // Test 4: Array with unclosed parameter value + let result4 = loads(r#"[{"name":"search","parameters":{"query":""#, AllowPartial::all()).unwrap(); + assert!(result4.is_array()); + let arr = result4.as_array().unwrap(); + assert_eq!(arr.len(), 1); + assert_eq!(arr[0]["name"], "search"); + assert_eq!(arr[0]["parameters"]["query"], ""); + + // Test 5: Complete first tool, starting second + let result5 = loads(r#"[{"name":"search","parameters":{"query":"rust"}},"#, AllowPartial::all()).unwrap(); + assert!(result5.is_array()); + let arr = result5.as_array().unwrap(); + assert_eq!(arr.len(), 1); + assert_eq!(arr[0]["name"], "search"); + assert_eq!(arr[0]["parameters"]["query"], "rust"); + } +} + diff --git a/lib/llm/src/protocols/openai/tools.rs b/lib/llm/src/protocols/openai/tools.rs new file mode 100644 index 00000000000..f415811f911 --- /dev/null +++ b/lib/llm/src/protocols/openai/tools.rs @@ -0,0 +1,295 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::collections::BTreeMap; + +use dynamo_async_openai::types::{ + ChatCompletionTool, ChatCompletionToolChoiceOption, FunctionObject, +}; +use serde_json::{Value, json}; +use thiserror::Error; + +/// Errors that can occur when deriving JSON schemas for tool_choice requests. +#[derive(Debug, Error, PartialEq, Eq)] +pub enum ToolChoiceError { + #[error("tool_choice requires a matching `tools` array")] + MissingTools, + #[error("tool `{0}` was not provided in `tools`")] + ToolNotFound(String), + #[error("$defs for tool `{0}` must be an object")] + InvalidDefinitionMap(String), + #[error("duplicate $defs entry `{0}` has conflicting schemas")] + ConflictingDefinition(String), + #[error("tool_choice `required` needs at least one tool definition")] + EmptyTools, +} + +/// Builds the JSON schema enforced by Guided Decoding for the given tool_choice/tools pair. +pub fn get_json_schema_from_tools( + tool_choice: Option<&ChatCompletionToolChoiceOption>, + tools: Option<&[ChatCompletionTool]>, +) -> Result, ToolChoiceError> { + let Some(choice) = tool_choice else { + return Ok(None); + }; + + match choice { + ChatCompletionToolChoiceOption::None | ChatCompletionToolChoiceOption::Auto => Ok(None), + ChatCompletionToolChoiceOption::Named(named) => { + let tools = tools.ok_or(ToolChoiceError::MissingTools)?; + let tool = find_tool(tools, &named.function.name) + .ok_or_else(|| ToolChoiceError::ToolNotFound(named.function.name.clone()))?; + Ok(Some(clone_parameters(&tool.function))) + } + ChatCompletionToolChoiceOption::Required => { + let tools = tools.ok_or(ToolChoiceError::MissingTools)?; + if tools.is_empty() { + return Err(ToolChoiceError::EmptyTools); + } + build_required_schema(tools).map(Some) + } + } +} + +fn find_tool<'a>(tools: &'a [ChatCompletionTool], name: &str) -> Option<&'a ChatCompletionTool> { + tools.iter().find(|tool| tool.function.name == name) +} + +fn clone_parameters(function: &FunctionObject) -> Value { + function + .parameters + .clone() + .unwrap_or_else(|| json!({"type": "object", "properties": {}})) +} + +fn build_required_schema(tools: &[ChatCompletionTool]) -> Result { + let mut defs: BTreeMap = BTreeMap::new(); + let mut any_of = Vec::with_capacity(tools.len()); + + for tool in tools { + let ParamsAndDefs { + schema, + defs: new_defs, + } = split_defs(&tool.function)?; + merge_defs(&mut defs, new_defs)?; + any_of.push(json!({ + "properties": { + "name": { + "type": "string", + "enum": [tool.function.name], + }, + "parameters": schema, + }, + "required": ["name", "parameters"], + })); + } + + let mut result = json!({ + "type": "array", + "minItems": 1, + "items": { + "type": "object", + "anyOf": any_of, + }, + }); + + if !defs.is_empty() + && let Value::Object(map) = &mut result { + map.insert( + "$defs".to_string(), + Value::Object(defs.into_iter().collect()), + ); + } + + Ok(result) +} + +struct ParamsAndDefs { + schema: Value, + defs: Option>, +} + +fn split_defs(function: &FunctionObject) -> Result { + let mut schema = clone_parameters(function); + let defs = match &mut schema { + Value::Object(obj) => { + if let Some(value) = obj.remove("$defs") { + Some(convert_defs(function, value)?) + } else { + None + } + } + _ => None, + }; + + Ok(ParamsAndDefs { schema, defs }) +} + +fn convert_defs( + function: &FunctionObject, + defs_value: Value, +) -> Result, ToolChoiceError> { + match defs_value { + Value::Object(map) => Ok(map.into_iter().collect()), + _ => Err(ToolChoiceError::InvalidDefinitionMap(function.name.clone())), + } +} + +fn merge_defs( + target: &mut BTreeMap, + defs: Option>, +) -> Result<(), ToolChoiceError> { + let Some(defs) = defs else { + return Ok(()); + }; + + for (name, schema) in defs { + if let Some(existing) = target.get(&name) { + if existing != &schema { + return Err(ToolChoiceError::ConflictingDefinition(name)); + } + } else { + target.insert(name, schema); + } + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use dynamo_async_openai::types::{ChatCompletionToolChoiceOption, ChatCompletionToolType}; + + fn sample_tools() -> Vec { + vec![ + ChatCompletionTool { + r#type: ChatCompletionToolType::Function, + function: FunctionObject { + name: "add_numbers".to_string(), + description: Some("Add two integers".to_string()), + parameters: Some(json!({ + "type": "object", + "properties": { + "a": {"type": "integer"}, + "b": {"type": "integer"}, + }, + "required": ["a", "b"], + })), + strict: None, + }, + }, + ChatCompletionTool { + r#type: ChatCompletionToolType::Function, + function: FunctionObject { + name: "get_weather".to_string(), + description: Some("Get weather".to_string()), + parameters: Some(json!({ + "type": "object", + "properties": { + "location": {"type": "string"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location", "unit"], + })), + strict: None, + }, + }, + ] + } + + #[test] + fn named_choice_returns_parameters() { + let tools = sample_tools(); + let tool_choice = ChatCompletionToolChoiceOption::Named( + dynamo_async_openai::types::ChatCompletionNamedToolChoice { + r#type: ChatCompletionToolType::Function, + function: dynamo_async_openai::types::FunctionName { + name: "get_weather".to_string(), + }, + }, + ); + let schema = get_json_schema_from_tools(Some(&tool_choice), Some(&tools)).expect("schema"); + + assert_eq!( + schema.unwrap(), + json!({ + "type": "object", + "properties": { + "location": {"type": "string"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location", "unit"], + }) + ); + } + + #[test] + fn required_choice_builds_any_of_schema() { + let tools = sample_tools(); + let schema = get_json_schema_from_tools( + Some(&ChatCompletionToolChoiceOption::Required), + Some(&tools), + ) + .expect("schema"); + + let schema = schema.expect("required schema"); + assert_eq!(schema["type"], "array"); + assert_eq!(schema["minItems"], 1); + assert!(schema["items"]["anyOf"].is_array()); + + let any_of = schema["items"]["anyOf"].as_array().unwrap(); + assert_eq!(any_of.len(), 2); + assert_eq!( + any_of[0]["properties"]["name"], + json!({"type": "string", "enum": ["add_numbers"]}) + ); + } + + #[test] + fn missing_tool_errors() { + let tools = sample_tools(); + let tool_choice = ChatCompletionToolChoiceOption::Named( + dynamo_async_openai::types::ChatCompletionNamedToolChoice { + r#type: ChatCompletionToolType::Function, + function: dynamo_async_openai::types::FunctionName { + name: "unknown".to_string(), + }, + }, + ); + let err = get_json_schema_from_tools(Some(&tool_choice), Some(&tools)).unwrap_err(); + assert_eq!(err, ToolChoiceError::ToolNotFound("unknown".to_string())); + } + + #[test] + fn conflicting_defs_errors() { + let tool = ChatCompletionTool { + r#type: ChatCompletionToolType::Function, + function: FunctionObject { + name: "foo".to_string(), + description: None, + parameters: Some(json!({ + "type": "object", + "$defs": { + "shared": {"type": "string"} + } + })), + strict: None, + }, + }; + + let mut tool_with_conflict = tool.clone(); + tool_with_conflict.function.parameters = Some(json!({ + "type": "object", + "$defs": { + "shared": {"type": "number"} + } + })); + + let tools = vec![tool, tool_with_conflict]; + let err = build_required_schema(&tools).unwrap_err(); + assert_eq!( + err, + ToolChoiceError::ConflictingDefinition("shared".to_string()) + ); + } +} diff --git a/lib/llm/tests/test_common_ext.rs b/lib/llm/tests/test_common_ext.rs index 933e486a86c..149a2054918 100644 --- a/lib/llm/tests/test_common_ext.rs +++ b/lib/llm/tests/test_common_ext.rs @@ -92,7 +92,7 @@ fn test_chat_completions_guided_decoding_from_common() { ); assert_eq!( request.get_guided_json(), - Some(&serde_json::json!({"key": "value"})) + Some(serde_json::json!({"key": "value"})) ); // Test guided_regex can be specified at root level diff --git a/lib/llm/tests/tool_choice.rs b/lib/llm/tests/tool_choice.rs new file mode 100644 index 00000000000..3854cbf976c --- /dev/null +++ b/lib/llm/tests/tool_choice.rs @@ -0,0 +1,848 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use dynamo_llm::protocols::common; +use dynamo_llm::protocols::common::llm_backend::BackendOutput; +use dynamo_llm::protocols::openai::chat_completions::{ + NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse, +}; +use dynamo_llm::protocols::openai::DeltaGeneratorExt; +use dynamo_async_openai::types::{ + ChatChoiceStream, ChatCompletionMessageToolCallChunk, ChatCompletionNamedToolChoice, + ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, + ChatCompletionRequestUserMessageContent, ChatCompletionStreamResponseDelta, + ChatCompletionToolChoiceOption, ChatCompletionToolType, CreateChatCompletionRequest, + CreateChatCompletionStreamResponse, FunctionCallStream, FunctionName, +}; + +fn create_test_request() -> NvCreateChatCompletionRequest { + let messages = vec![ChatCompletionRequestMessage::User( + ChatCompletionRequestUserMessage { + content: ChatCompletionRequestUserMessageContent::Text("test".to_string()), + name: None, + }, + )]; + + NvCreateChatCompletionRequest { + inner: CreateChatCompletionRequest { + model: "test-model".to_string(), + messages, + stream: Some(false), + stream_options: None, + ..Default::default() + }, + common: Default::default(), + nvext: None, + chat_template_args: None, + unsupported_fields: Default::default(), + } +} + +fn build_backend_output(text: &str) -> BackendOutput { + BackendOutput { + token_ids: vec![], + tokens: vec![], + text: Some(text.to_string()), + cum_log_probs: None, + log_probs: None, + top_logprobs: None, + finish_reason: Some(common::FinishReason::Stop), + index: Some(0), + completion_usage: None, + disaggregated_params: None, + } +} + +#[test] +fn test_named_tool_choice_parses_json() { + let mut request = create_test_request(); + request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Named( + ChatCompletionNamedToolChoice { + r#type: ChatCompletionToolType::Function, + function: FunctionName { + name: "get_weather".to_string(), + }, + }, + )); + + let mut generator = request.response_generator("req-1".to_string()); + let backend_output = build_backend_output(r#"{"location":"Paris"}"#); + let response = generator + .choice_from_postprocessor(backend_output) + .expect("choice generation"); + + let choice = &response.choices[0]; + assert_eq!( + choice.finish_reason, + Some(dynamo_async_openai::types::FinishReason::Stop) + ); + let delta = &choice.delta; + assert!(delta.content.is_none()); + let tool_calls = delta.tool_calls.as_ref().unwrap(); + + // In streaming mode, we emit 2 chunks: first with id/name, second with arguments + assert!(tool_calls.len() >= 1, "Should have at least 1 tool call chunk"); + + // Find the chunk with the name (first chunk) + let name_chunk = tool_calls.iter().find(|tc| tc.function.as_ref().and_then(|f| f.name.as_ref()).is_some()); + assert!(name_chunk.is_some(), "Should have chunk with name"); + let name_chunk = name_chunk.unwrap(); + + assert_eq!(name_chunk.index, 0); + assert_eq!(name_chunk.id.as_deref(), Some("call_1")); + assert_eq!( + name_chunk.function.as_ref().unwrap().name.as_deref(), + Some("get_weather") + ); + + // Arguments may be in the same chunk or a subsequent one + let has_arguments = tool_calls.iter().any(|tc| { + tc.function.as_ref() + .and_then(|f| f.arguments.as_ref()) + .is_some_and(|args| !args.is_empty()) + }); + assert!(has_arguments, "Should have arguments in some chunk"); +} + +#[test] +fn test_required_tool_choice_parses_json_array() { + let mut request = create_test_request(); + request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Required); + + let mut generator = request.response_generator("req-2".to_string()); + let backend_output = build_backend_output( + r#"[{"name":"search","parameters":{"query":"rust"}}, + {"name":"summarize","parameters":{"topic":"memory"}}]"#, + ); + let response = generator + .choice_from_postprocessor(backend_output) + .expect("choice generation"); + + let choice = &response.choices[0]; + assert_eq!( + choice.finish_reason, + Some(dynamo_async_openai::types::FinishReason::ToolCalls) + ); + let delta = &choice.delta; + assert!(delta.content.is_none()); + let tool_calls = delta.tool_calls.as_ref().unwrap(); + + // With incremental streaming, we emit separate chunks for name and arguments + // Expected: 4 chunks total (2 per tool: name chunk + arguments chunk) + assert_eq!(tool_calls.len(), 4); + + // First tool: name chunk + assert_eq!(tool_calls[0].index, 0); + assert_eq!( + tool_calls[0].function.as_ref().unwrap().name.as_deref(), + Some("search") + ); + assert!(tool_calls[0].id.is_some()); + + // First tool: arguments chunk + assert_eq!(tool_calls[1].index, 0); + assert!(tool_calls[1].function.as_ref().unwrap().name.is_none()); + assert!(tool_calls[1] + .function + .as_ref() + .unwrap() + .arguments + .as_ref() + .unwrap() + .contains("rust")); + + // Second tool: name chunk + assert_eq!(tool_calls[2].index, 1); + assert_eq!( + tool_calls[2].function.as_ref().unwrap().name.as_deref(), + Some("summarize") + ); + assert!(tool_calls[2].id.is_some()); + + // Second tool: arguments chunk + assert_eq!(tool_calls[3].index, 1); + assert!(tool_calls[3].function.as_ref().unwrap().name.is_none()); + assert!(tool_calls[3] + .function + .as_ref() + .unwrap() + .arguments + .as_ref() + .unwrap() + .contains("memory")); +} + +#[test] +fn test_tool_choice_parse_failure_suppresses_text() { + let mut request = create_test_request(); + request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Required); + + let mut generator = request.response_generator("req-3".to_string()); + let backend_output = build_backend_output("not-json"); + let response = generator + .choice_from_postprocessor(backend_output) + .expect("choice generation"); + + let delta = &response.choices[0].delta; + // When tool_choice is active but parsing fails, we suppress the text output + assert!(delta.content.is_none()); + assert!(delta.tool_calls.is_none()); +} + +#[test] +fn test_streaming_named_tool_incremental() { + let mut request = create_test_request(); + request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Named( + ChatCompletionNamedToolChoice { + r#type: ChatCompletionToolType::Function, + function: FunctionName { + name: "get_weather".to_string(), + }, + }, + )); + + let mut generator = request.response_generator("req-stream-1".to_string()); + + // Simulate streaming chunks + // For simplicity in testing, send complete JSON in final chunk + let chunks = vec![r#"{"location":"Paris","unit":"celsius"}"#]; + + let mut all_responses = Vec::new(); + for (i, chunk) in chunks.iter().enumerate() { + let backend_output = BackendOutput { + token_ids: vec![], + tokens: vec![], + text: Some(chunk.to_string()), + cum_log_probs: None, + log_probs: None, + top_logprobs: None, + finish_reason: if i == chunks.len() - 1 { + Some(common::FinishReason::Stop) + } else { + None + }, + index: Some(0), + completion_usage: None, + disaggregated_params: None, + }; + + let response = generator + .choice_from_postprocessor(backend_output) + .expect("streaming chunk"); + all_responses.push(response); + } + + // Last response should have finish_reason + let last_response = all_responses.last().unwrap(); + assert_eq!( + last_response.choices[0].finish_reason, + Some(dynamo_async_openai::types::FinishReason::Stop) + ); + + // Should have tool_calls somewhere in the stream + let has_tool_calls = all_responses + .iter() + .any(|r| r.choices[0].delta.tool_calls.is_some()); + assert!(has_tool_calls, "No tool calls found in any response"); +} + +#[test] +fn test_streaming_required_tool_parallel() { + let mut request = create_test_request(); + request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Required); + + let mut generator = request.response_generator("req-stream-2".to_string()); + + // Simulate streaming array of tool calls + let chunks = vec![ + r#"[{"name":"search","parameters":{"query":"rust"}},"#, + r#"{"name":"summarize","parameters":{"topic":"memory"}}]"#, + ]; + + let mut all_responses = Vec::new(); + for (i, chunk) in chunks.iter().enumerate() { + let backend_output = BackendOutput { + token_ids: vec![], + tokens: vec![], + text: Some(chunk.to_string()), + cum_log_probs: None, + log_probs: None, + top_logprobs: None, + finish_reason: if i == chunks.len() - 1 { + Some(common::FinishReason::Stop) + } else { + None + }, + index: Some(0), + completion_usage: None, + disaggregated_params: None, + }; + + let response = generator + .choice_from_postprocessor(backend_output) + .expect("streaming chunk"); + all_responses.push(response); + } + + // Final chunk should have finish_reason = ToolCalls + let last_response = all_responses.last().unwrap(); + assert_eq!( + last_response.choices[0].finish_reason, + Some(dynamo_async_openai::types::FinishReason::ToolCalls) + ); + + // Should have detected both tools + let mut found_search = false; + let mut found_summarize = false; + for resp in &all_responses { + if let Some(tool_calls) = &resp.choices[0].delta.tool_calls { + for tc in tool_calls { + if let Some(func) = &tc.function { + if let Some(name) = &func.name { + if name == "search" { + found_search = true; + } + if name == "summarize" { + found_summarize = true; + } + } + } + } + } + } + assert!(found_search, "Should detect search tool"); + assert!(found_summarize, "Should detect summarize tool"); +} + +#[test] +fn test_streaming_with_incremental_arguments() { + let mut request = create_test_request(); + request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Named( + ChatCompletionNamedToolChoice { + r#type: ChatCompletionToolType::Function, + function: FunctionName { + name: "search".to_string(), + }, + }, + )); + + let mut generator = request.response_generator("req-stream-3".to_string()); + + // Character-by-character streaming + let full_json = r#"{"query":"rust programming"}"#; + let mut responses = Vec::new(); + + for ch in full_json.chars() { + let backend_output = BackendOutput { + token_ids: vec![], + tokens: vec![], + text: Some(ch.to_string()), + cum_log_probs: None, + log_probs: None, + top_logprobs: None, + finish_reason: None, + index: Some(0), + completion_usage: None, + disaggregated_params: None, + }; + + let response = generator + .choice_from_postprocessor(backend_output) + .expect("char chunk"); + responses.push(response); + } + + // Should have suppressed raw text output + for resp in &responses { + assert!(resp.choices[0].delta.content.is_none()); + } +} + +#[test] +fn test_no_streaming_when_tool_choice_none() { + let request = create_test_request(); + // tool_choice = None (default) + + let mut generator = request.response_generator("req-stream-4".to_string()); + + let backend_output = BackendOutput { + token_ids: vec![], + tokens: vec![], + text: Some("Hello world".to_string()), + cum_log_probs: None, + log_probs: None, + top_logprobs: None, + finish_reason: None, + index: Some(0), + completion_usage: None, + disaggregated_params: None, + }; + + let response = generator + .choice_from_postprocessor(backend_output) + .expect("normal text"); + + // Should have text content, not tool_calls + assert_eq!( + response.choices[0].delta.content.as_deref(), + Some("Hello world") + ); + assert!(response.choices[0].delta.tool_calls.is_none()); +} + +#[test] +fn test_true_incremental_streaming_named() { + let mut request = create_test_request(); + request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Named( + ChatCompletionNamedToolChoice { + r#type: ChatCompletionToolType::Function, + function: FunctionName { + name: "get_weather".to_string(), + }, + }, + )); + + let mut generator = request.response_generator("req-stream-inc-1".to_string()); + + // Simulate realistic token-by-token streaming + let chunks = vec![ + r#"{"#, + r#""location""#, + r#":"#, + r#""Paris""#, + r#","#, + r#""unit""#, + r#":"#, + r#""celsius""#, + r#"}"#, + ]; + + let mut responses = Vec::new(); + for (i, chunk) in chunks.iter().enumerate() { + let backend_output = BackendOutput { + token_ids: vec![], + tokens: vec![], + text: Some(chunk.to_string()), + cum_log_probs: None, + log_probs: None, + top_logprobs: None, + finish_reason: if i == chunks.len() - 1 { + Some(common::FinishReason::Stop) + } else { + None + }, + index: Some(0), + completion_usage: None, + disaggregated_params: None, + }; + + let response = generator + .choice_from_postprocessor(backend_output) + .expect("chunk"); + responses.push(response); + } + + // Should have emitted tool_calls in one of the early chunks + let first_tool_call_idx = responses + .iter() + .position(|r| r.choices[0].delta.tool_calls.is_some()) + .expect("Should find tool_calls in stream"); + + // First tool call should have id, type, name + let first_tc = &responses[first_tool_call_idx].choices[0] + .delta + .tool_calls + .as_ref() + .unwrap()[0]; + assert!(first_tc.id.is_some()); + assert_eq!(first_tc.r#type, Some(ChatCompletionToolType::Function)); + assert_eq!( + first_tc.function.as_ref().unwrap().name.as_deref(), + Some("get_weather") + ); + + // Should have multiple chunks with arguments deltas + let args_chunks: Vec<_> = responses + .iter() + .filter_map(|r| r.choices[0].delta.tool_calls.as_ref()) + .flat_map(|tcs| tcs.iter()) + .filter_map(|tc| tc.function.as_ref()?.arguments.as_ref()) + .collect(); + + assert!( + args_chunks.len() > 1, + "Should have multiple argument delta chunks" + ); +} + +#[test] +fn test_true_incremental_streaming_parallel() { + let mut request = create_test_request(); + request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Required); + + let mut generator = request.response_generator("req-stream-inc-2".to_string()); + + // Simulate streaming: array with two tool calls + let chunks = vec![ + r#"["#, + r#"{"name":"search","#, + r#""parameters":{"query":"rust"}"#, + r#"},"#, + r#"{"name":"summarize","#, + r#""parameters":{"topic":"memory"}"#, + r#"}]"#, + ]; + + let mut responses = Vec::new(); + for (i, chunk) in chunks.iter().enumerate() { + let backend_output = BackendOutput { + token_ids: vec![], + tokens: vec![], + text: Some(chunk.to_string()), + cum_log_probs: None, + log_probs: None, + top_logprobs: None, + finish_reason: if i == chunks.len() - 1 { + Some(common::FinishReason::Stop) + } else { + None + }, + index: Some(0), + completion_usage: None, + disaggregated_params: None, + }; + + let response = generator + .choice_from_postprocessor(backend_output) + .expect("chunk"); + responses.push(response); + } + + // Count tool call initializations (first chunks with names) + let mut tool_names_seen = std::collections::HashSet::new(); + for resp in &responses { + if let Some(tool_calls) = &resp.choices[0].delta.tool_calls { + for tc in tool_calls { + if let Some(func) = &tc.function { + if let Some(name) = &func.name { + tool_names_seen.insert(name.clone()); + } + } + } + } + } + + assert_eq!( + tool_names_seen.len(), + 2, + "Should detect both tool calls" + ); + assert!(tool_names_seen.contains("search")); + assert!(tool_names_seen.contains("summarize")); + + // Verify that tool calls are streamed incrementally, not just at the end + let chunks_with_tool_calls: Vec<_> = responses + .iter() + .enumerate() + .filter(|(_, r)| r.choices[0].delta.tool_calls.is_some()) + .map(|(i, _)| i) + .collect(); + + assert!( + chunks_with_tool_calls.len() > 1, + "Should have multiple chunks with tool_calls (not just final)" + ); +} + +/// Helper function to create a streaming chunk +fn create_chunk( + index: u32, + role: Option, + tool_call_chunk: Option, + finish_reason: Option, +) -> dynamo_async_openai::types::CreateChatCompletionStreamResponse { + use dynamo_async_openai::types::{ + ChatCompletionStreamResponseDelta, CreateChatCompletionStreamResponse, + }; + + CreateChatCompletionStreamResponse { + id: "test".to_string(), + choices: vec![ChatChoiceStream { + index, + delta: ChatCompletionStreamResponseDelta { + role, + content: None, + function_call: None, + tool_calls: tool_call_chunk.map(|chunk| vec![chunk]), + refusal: None, + reasoning_content: None, + }, + finish_reason, + logprobs: None, + }], + created: 1234567890, + model: "test-model".to_string(), + system_fingerprint: None, + object: "chat.completion.chunk".to_string(), + service_tier: None, + usage: None, + nvext: None, + } +} + +#[tokio::test] +async fn test_aggregator_named_tool_accumulates_arguments() { + use dynamo_llm::protocols::openai::chat_completions::aggregator::DeltaAggregator; + use dynamo_llm::protocols::openai::ParsingOptions; + use dynamo_llm::protocols::Annotated; + use futures::stream; + + // Simulate streaming chunks for named tool choice: get_weather + let chunks = vec![ + // Chunk 1: role + create_chunk(0, Some(dynamo_async_openai::types::Role::Assistant), None, None), + // Chunk 2: tool call start (id, type, name, empty arguments) + create_chunk( + 0, + None, + Some(ChatCompletionMessageToolCallChunk { + index: 0, + id: Some("call_1".to_string()), + r#type: Some(ChatCompletionToolType::Function), + function: Some(FunctionCallStream { + name: Some("get_weather".to_string()), + arguments: Some(String::new()), + }), + }), + None, + ), + // Chunk 3: first part of arguments (raw JSON fragment from buffer) + create_chunk( + 0, + None, + Some(ChatCompletionMessageToolCallChunk { + index: 0, + id: None, + r#type: None, + function: Some(FunctionCallStream { + name: None, + arguments: Some(r#"{"location":"Paris""#.to_string()), + }), + }), + None, + ), + // Chunk 4: second part of arguments (continuation) + create_chunk( + 0, + None, + Some(ChatCompletionMessageToolCallChunk { + index: 0, + id: None, + r#type: None, + function: Some(FunctionCallStream { + name: None, + arguments: Some(r#","unit":"celsius"}"#.to_string()), + }), + }), + None, + ), + // Chunk 5: finish + create_chunk( + 0, + None, + None, + Some(dynamo_async_openai::types::FinishReason::Stop), + ), + ]; + + // Convert to Annotated stream + let annotated_chunks: Vec> = chunks + .into_iter() + .map(|chunk| Annotated { + data: Some(chunk), + id: None, + event: None, + comment: None, + }) + .collect(); + + let stream = Box::pin(stream::iter(annotated_chunks)); + + // Aggregate the stream + let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await; + + assert!(result.is_ok()); + let response = result.unwrap(); + + // Verify aggregated response + assert_eq!(response.choices.len(), 1); + let choice = &response.choices[0]; + + // Check tool calls + assert!(choice.message.tool_calls.is_some()); + let tool_calls = choice.message.tool_calls.as_ref().unwrap(); + assert_eq!(tool_calls.len(), 1); + + let tool_call = &tool_calls[0]; + assert_eq!(tool_call.id, "call_1"); + assert_eq!(tool_call.function.name, "get_weather"); + // THIS IS THE KEY ASSERTION - arguments should be accumulated! + assert_eq!( + tool_call.function.arguments, + r#"{"location":"Paris","unit":"celsius"}"#, + "Arguments should be fully accumulated from all chunks" + ); +} + +#[tokio::test] +async fn test_aggregator_required_tool_parallel_calls() { + use dynamo_llm::protocols::openai::chat_completions::aggregator::DeltaAggregator; + use dynamo_llm::protocols::openai::ParsingOptions; + use dynamo_llm::protocols::Annotated; + use dynamo_async_openai::types::{ + ChatCompletionMessageToolCallChunk, ChatCompletionToolType, FunctionCallStream, + }; + use futures::stream; + + // Simulate streaming chunks for required tool choice with parallel calls + let chunks = vec![ + // Chunk 1: role + create_chunk(0, Some(dynamo_async_openai::types::Role::Assistant), None, None), + // Chunk 2: first tool call start + create_chunk( + 0, + None, + Some(ChatCompletionMessageToolCallChunk { + index: 0, + id: Some("call_1".to_string()), + r#type: Some(ChatCompletionToolType::Function), + function: Some(FunctionCallStream { + name: Some("search".to_string()), + arguments: Some(String::new()), + }), + }), + None, + ), + // Chunk 3: first tool arguments + create_chunk( + 0, + None, + Some(ChatCompletionMessageToolCallChunk { + index: 0, + id: None, + r#type: None, + function: Some(FunctionCallStream { + name: None, + arguments: Some(r#"{"query":"rust"}"#.to_string()), + }), + }), + None, + ), + // Chunk 4: second tool call start + create_chunk( + 0, + None, + Some(ChatCompletionMessageToolCallChunk { + index: 1, + id: Some("call_2".to_string()), + r#type: Some(ChatCompletionToolType::Function), + function: Some(FunctionCallStream { + name: Some("summarize".to_string()), + arguments: Some(String::new()), + }), + }), + None, + ), + // Chunk 5: second tool arguments (partial) + create_chunk( + 0, + None, + Some(ChatCompletionMessageToolCallChunk { + index: 1, + id: None, + r#type: None, + function: Some(FunctionCallStream { + name: None, + arguments: Some(r#"{"text":"#.to_string()), + }), + }), + None, + ), + // Chunk 6: second tool arguments (rest) + create_chunk( + 0, + None, + Some(ChatCompletionMessageToolCallChunk { + index: 1, + id: None, + r#type: None, + function: Some(FunctionCallStream { + name: None, + arguments: Some(r#""long article"}"#.to_string()), + }), + }), + None, + ), + // Chunk 7: finish + create_chunk( + 0, + None, + None, + Some(dynamo_async_openai::types::FinishReason::ToolCalls), + ), + ]; + + // Convert to Annotated stream + let annotated_chunks: Vec> = chunks + .into_iter() + .map(|chunk| Annotated { + data: Some(chunk), + id: None, + event: None, + comment: None, + }) + .collect(); + + let stream = Box::pin(stream::iter(annotated_chunks)); + + // Aggregate the stream + let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await; + + assert!(result.is_ok()); + let response = result.unwrap(); + + // Verify aggregated response + assert_eq!(response.choices.len(), 1); + let choice = &response.choices[0]; + + // Check tool calls + assert!(choice.message.tool_calls.is_some()); + let tool_calls = choice.message.tool_calls.as_ref().unwrap(); + assert_eq!(tool_calls.len(), 2, "Should have 2 tool calls"); + + // Verify first tool call + let tool_call_1 = &tool_calls[0]; + assert_eq!(tool_call_1.id, "call_1"); + assert_eq!(tool_call_1.function.name, "search"); + assert_eq!( + tool_call_1.function.arguments, + r#"{"query":"rust"}"#, + "First tool arguments should be complete" + ); + + // Verify second tool call - THIS IS THE CRITICAL TEST + let tool_call_2 = &tool_calls[1]; + assert_eq!(tool_call_2.id, "call_2"); + assert_eq!(tool_call_2.function.name, "summarize"); + assert_eq!( + tool_call_2.function.arguments, + r#"{"text":"long article"}"#, + "Second tool arguments should be accumulated from multiple chunks" + ); + + assert_eq!( + choice.finish_reason, + Some(dynamo_async_openai::types::FinishReason::ToolCalls) + ); +} + diff --git a/lib/llm/tests/tool_choice_finish_reasons.rs b/lib/llm/tests/tool_choice_finish_reasons.rs new file mode 100644 index 00000000000..2bfe1dbeb53 --- /dev/null +++ b/lib/llm/tests/tool_choice_finish_reasons.rs @@ -0,0 +1,214 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Tests for tool_choice finish_reason handling. + +use dynamo_llm::protocols::common; +use dynamo_llm::protocols::common::llm_backend::BackendOutput; +use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionRequest; +use dynamo_llm::protocols::openai::DeltaGeneratorExt; +use dynamo_async_openai::types::{ + ChatCompletionNamedToolChoice, ChatCompletionRequestMessage, + ChatCompletionRequestUserMessage, ChatCompletionRequestUserMessageContent, + ChatCompletionToolChoiceOption, ChatCompletionToolType, CreateChatCompletionRequest, + FunctionName, +}; + +fn create_test_request() -> NvCreateChatCompletionRequest { + let messages = vec![ChatCompletionRequestMessage::User( + ChatCompletionRequestUserMessage { + content: ChatCompletionRequestUserMessageContent::Text("test".to_string()), + name: None, + }, + )]; + + NvCreateChatCompletionRequest { + inner: CreateChatCompletionRequest { + model: "test-model".to_string(), + messages, + stream: Some(false), + stream_options: None, + ..Default::default() + }, + common: Default::default(), + nvext: None, + chat_template_args: None, + unsupported_fields: Default::default(), + } +} + +fn build_backend_output_with_finish( + text: &str, + finish: common::FinishReason, +) -> BackendOutput { + BackendOutput { + token_ids: vec![], + tokens: vec![], + text: Some(text.to_string()), + cum_log_probs: None, + log_probs: None, + top_logprobs: None, + finish_reason: Some(finish), + index: Some(0), + completion_usage: None, + disaggregated_params: None, + } +} + +#[test] +fn test_named_tool_choice_preserves_length_finish_reason() { + let mut request = create_test_request(); + request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Named( + ChatCompletionNamedToolChoice { + r#type: ChatCompletionToolType::Function, + function: FunctionName { + name: "get_weather".to_string(), + }, + }, + )); + + let mut generator = request.response_generator("req-length-1".to_string()); + let backend_output = build_backend_output_with_finish( + r#"{"location":"Par"#, // Incomplete due to length limit + common::FinishReason::Length, + ); + + let response = generator + .choice_from_postprocessor(backend_output) + .expect("choice generation"); + + // Critical: Length finish reason should be preserved, NOT replaced with Stop + assert_eq!( + response.choices[0].finish_reason, + Some(dynamo_async_openai::types::FinishReason::Length), + "Length finish reason must be preserved for tool_choice=named" + ); +} + +#[test] +fn test_required_tool_choice_preserves_length_finish_reason() { + let mut request = create_test_request(); + request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Required); + + let mut generator = request.response_generator("req-length-2".to_string()); + let backend_output = build_backend_output_with_finish( + r#"[{"name":"search","parameters":{"query":"incomplete"#, // Incomplete due to length + common::FinishReason::Length, + ); + + let response = generator + .choice_from_postprocessor(backend_output) + .expect("choice generation"); + + // Critical: Length finish reason should be preserved, NOT replaced with ToolCalls + assert_eq!( + response.choices[0].finish_reason, + Some(dynamo_async_openai::types::FinishReason::Length), + "Length finish reason must be preserved for tool_choice=required" + ); +} + +#[test] +fn test_named_tool_choice_preserves_content_filter() { + let mut request = create_test_request(); + request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Named( + ChatCompletionNamedToolChoice { + r#type: ChatCompletionToolType::Function, + function: FunctionName { + name: "search".to_string(), + }, + }, + )); + + let mut generator = request.response_generator("req-filter-1".to_string()); + let backend_output = build_backend_output_with_finish( + r#"{"query":"filtered content"#, + common::FinishReason::ContentFilter, + ); + + let response = generator + .choice_from_postprocessor(backend_output) + .expect("choice generation"); + + // Critical: ContentFilter finish reason should be preserved + assert_eq!( + response.choices[0].finish_reason, + Some(dynamo_async_openai::types::FinishReason::ContentFilter), + "ContentFilter finish reason must be preserved for tool_choice=named" + ); +} + +#[test] +fn test_required_tool_choice_preserves_content_filter() { + let mut request = create_test_request(); + request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Required); + + let mut generator = request.response_generator("req-filter-2".to_string()); + let backend_output = build_backend_output_with_finish( + r#"[{"name":"harmful_action"#, + common::FinishReason::ContentFilter, + ); + + let response = generator + .choice_from_postprocessor(backend_output) + .expect("choice generation"); + + // Critical: ContentFilter finish reason should be preserved + assert_eq!( + response.choices[0].finish_reason, + Some(dynamo_async_openai::types::FinishReason::ContentFilter), + "ContentFilter finish reason must be preserved for tool_choice=required" + ); +} + +#[test] +fn test_named_tool_choice_normal_stop_becomes_stop() { + let mut request = create_test_request(); + request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Named( + ChatCompletionNamedToolChoice { + r#type: ChatCompletionToolType::Function, + function: FunctionName { + name: "get_weather".to_string(), + }, + }, + )); + + let mut generator = request.response_generator("req-stop-1".to_string()); + let backend_output = build_backend_output_with_finish( + r#"{"location":"Paris","unit":"celsius"}"#, + common::FinishReason::Stop, + ); + + let response = generator + .choice_from_postprocessor(backend_output) + .expect("choice generation"); + + // Normal completion: Stop should remain Stop for named tool choice + assert_eq!( + response.choices[0].finish_reason, + Some(dynamo_async_openai::types::FinishReason::Stop), + ); +} + +#[test] +fn test_required_tool_choice_normal_stop_becomes_tool_calls() { + let mut request = create_test_request(); + request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Required); + + let mut generator = request.response_generator("req-stop-2".to_string()); + let backend_output = build_backend_output_with_finish( + r#"[{"name":"search","parameters":{"query":"rust"}}]"#, + common::FinishReason::Stop, + ); + + let response = generator + .choice_from_postprocessor(backend_output) + .expect("choice generation"); + + // Normal completion: Stop should become ToolCalls for required tool choice + assert_eq!( + response.choices[0].finish_reason, + Some(dynamo_async_openai::types::FinishReason::ToolCalls), + ); +} + From bc4a6fe065e66cf21ecbd3e0d031883443e91b4c Mon Sep 17 00:00:00 2001 From: Vladislav Nosivskoy Date: Wed, 3 Dec 2025 18:59:36 +0300 Subject: [PATCH 02/18] cargo fmt Signed-off-by: Vladislav Nosivskoy --- .../openai/chat_completions/aggregator.rs | 10 +- .../openai/chat_completions/delta.rs | 10 +- lib/llm/src/protocols/openai/partial_json.rs | 94 +++++++++------- lib/llm/src/protocols/openai/tools.rs | 13 +-- lib/llm/tests/tool_choice.rs | 100 ++++++++++-------- lib/llm/tests/tool_choice_finish_reasons.rs | 19 ++-- 6 files changed, 136 insertions(+), 110 deletions(-) diff --git a/lib/llm/src/protocols/openai/chat_completions/aggregator.rs b/lib/llm/src/protocols/openai/chat_completions/aggregator.rs index 3cd47ca622f..c2274bb1239 100644 --- a/lib/llm/src/protocols/openai/chat_completions/aggregator.rs +++ b/lib/llm/src/protocols/openai/chat_completions/aggregator.rs @@ -65,7 +65,6 @@ impl Default for DeltaAggregator { } } - impl DeltaAggregator { /// Creates a new, empty [`DeltaAggregator`] instance. pub fn new() -> Self { @@ -284,8 +283,7 @@ impl From for dynamo_async_openai::types::ChatChoice { Some(dynamo_async_openai::types::FinishReason::Stop) | Some(dynamo_async_openai::types::FinishReason::Length) | Some(dynamo_async_openai::types::FinishReason::ContentFilter) - ) - { + ) { Some(dynamo_async_openai::types::FinishReason::ToolCalls) } else { delta.finish_reason @@ -1015,8 +1013,7 @@ mod tests { #[tokio::test] async fn test_tool_calling_preserves_length_finish_reason() { // Test that Length finish reason is preserved even with tool_calls present - let tool_call_json = - r#"{"name": "get_weather", "arguments": {"location": "Paris"}}"#; + let tool_call_json = r#"{"name": "get_weather", "arguments": {"location": "Paris"}}"#; let annotated_delta = create_test_delta( 0, @@ -1057,8 +1054,7 @@ mod tests { #[tokio::test] async fn test_tool_calling_preserves_content_filter_finish_reason() { // Test that ContentFilter finish reason is preserved even with tool_calls present - let tool_call_json = - r#"{"name": "harmful_action", "arguments": {}}"#; + let tool_call_json = r#"{"name": "harmful_action", "arguments": {}}"#; let annotated_delta = create_test_delta( 0, diff --git a/lib/llm/src/protocols/openai/chat_completions/delta.rs b/lib/llm/src/protocols/openai/chat_completions/delta.rs index fef1ff17986..7bff099c6eb 100644 --- a/lib/llm/src/protocols/openai/chat_completions/delta.rs +++ b/lib/llm/src/protocols/openai/chat_completions/delta.rs @@ -4,7 +4,10 @@ use super::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}; use crate::{ local_model::runtime_config::ModelRuntimeConfig, - protocols::{common::{self}, openai::partial_json::{loads, AllowPartial}}, + protocols::{ + common::{self}, + openai::partial_json::{AllowPartial, loads}, + }, types::TokenIdType, }; use dynamo_async_openai::types::{ @@ -401,7 +404,10 @@ impl DeltaGenerator { for current in ¤t_calls { // Check if this is a new tool or existing one - let previous = self.previous_tool_calls.iter().find(|p| p.index == current.index); + let previous = self + .previous_tool_calls + .iter() + .find(|p| p.index == current.index); match previous { None => { diff --git a/lib/llm/src/protocols/openai/partial_json.rs b/lib/llm/src/protocols/openai/partial_json.rs index e925655bf57..b575c711b51 100644 --- a/lib/llm/src/protocols/openai/partial_json.rs +++ b/lib/llm/src/protocols/openai/partial_json.rs @@ -139,13 +139,14 @@ pub fn fix_json(json_string: &str, allow: AllowPartial) -> (String, String) { // Handle unclosed strings if !allow.strings && in_string { if let Some(last_container) = stack.back() - && last_container.char == '{' { - // Truncate before the unclosed string key - return ( - head[..=last_container.index].to_string(), - join_closing_tokens(&stack), - ); - } + && last_container.char == '{' + { + // Truncate before the unclosed string key + return ( + head[..=last_container.index].to_string(), + join_closing_tokens(&stack), + ); + } // Find last comma before the unclosed string if let Some(string_start) = last_string_start { @@ -163,16 +164,18 @@ pub fn fix_json(json_string: &str, allow: AllowPartial) -> (String, String) { } // Simple case: just close all open containers - if in_string && allow.strings - && let Some(string_start) = last_string_start { - // Fix the partial string - let partial_str = &head[string_start..]; - let (fixed_head, fixed_tail) = simple_fix(partial_str, allow); - return ( - format!("{}{}", &head[..string_start], fixed_head), - format!("{}{}", fixed_tail, join_closing_tokens(&stack)), - ); - } + if in_string + && allow.strings + && let Some(string_start) = last_string_start + { + // Fix the partial string + let partial_str = &head[string_start..]; + let (fixed_head, fixed_tail) = simple_fix(partial_str, allow); + return ( + format!("{}{}", &head[..string_start], fixed_head), + format!("{}{}", fixed_tail, join_closing_tokens(&stack)), + ); + } (head.to_string(), join_closing_tokens(&stack)) } @@ -182,27 +185,26 @@ fn simple_fix(json_string: &str, allow: AllowPartial) -> (String, String) { let trimmed = json_string.trim_end(); // Handle unclosed strings - if trimmed.starts_with('"') - && allow.strings { - // Count how many unescaped quotes we have - let mut escaped = false; - let mut quote_count = 0; - for ch in trimmed.chars() { - if ch == '\\' && !escaped { - escaped = true; - } else { - if ch == '"' && !escaped { - quote_count += 1; - } - escaped = false; + if trimmed.starts_with('"') && allow.strings { + // Count how many unescaped quotes we have + let mut escaped = false; + let mut quote_count = 0; + for ch in trimmed.chars() { + if ch == '\\' && !escaped { + escaped = true; + } else { + if ch == '"' && !escaped { + quote_count += 1; } + escaped = false; } + } - if quote_count % 2 == 1 { - // Unclosed string - return (trimmed.to_string(), "\"".to_string()); - } + if quote_count % 2 == 1 { + // Unclosed string + return (trimmed.to_string(), "\"".to_string()); } + } // Already complete or can't fix (trimmed.to_string(), String::new()) @@ -218,7 +220,10 @@ pub fn ensure_json(json_string: &str, allow: AllowPartial) -> String { /// /// This is the main function inspired by partial-json-parser's `loads()`. /// It completes the partial JSON and then parses it. -pub fn loads(json_string: &str, allow: AllowPartial) -> Result { +pub fn loads( + json_string: &str, + allow: AllowPartial, +) -> Result { let completed = ensure_json(json_string, allow); serde_json::from_str(&completed) } @@ -279,7 +284,11 @@ mod tests { assert_eq!(result2["location"], "Paris"); // Test 3: Complete object - let result3 = loads(r#"{"location":"Paris","unit":"celsius"}"#, AllowPartial::all()).unwrap(); + let result3 = loads( + r#"{"location":"Paris","unit":"celsius"}"#, + AllowPartial::all(), + ) + .unwrap(); assert_eq!(result3["location"], "Paris"); assert_eq!(result3["unit"], "celsius"); } @@ -287,7 +296,11 @@ mod tests { #[test] fn test_loads_array_incremental() { // Test 4: Array with unclosed parameter value - let result4 = loads(r#"[{"name":"search","parameters":{"query":""#, AllowPartial::all()).unwrap(); + let result4 = loads( + r#"[{"name":"search","parameters":{"query":""#, + AllowPartial::all(), + ) + .unwrap(); assert!(result4.is_array()); let arr = result4.as_array().unwrap(); assert_eq!(arr.len(), 1); @@ -295,7 +308,11 @@ mod tests { assert_eq!(arr[0]["parameters"]["query"], ""); // Test 5: Complete first tool, starting second - let result5 = loads(r#"[{"name":"search","parameters":{"query":"rust"}},"#, AllowPartial::all()).unwrap(); + let result5 = loads( + r#"[{"name":"search","parameters":{"query":"rust"}},"#, + AllowPartial::all(), + ) + .unwrap(); assert!(result5.is_array()); let arr = result5.as_array().unwrap(); assert_eq!(arr.len(), 1); @@ -303,4 +320,3 @@ mod tests { assert_eq!(arr[0]["parameters"]["query"], "rust"); } } - diff --git a/lib/llm/src/protocols/openai/tools.rs b/lib/llm/src/protocols/openai/tools.rs index f415811f911..ad881d50dc6 100644 --- a/lib/llm/src/protocols/openai/tools.rs +++ b/lib/llm/src/protocols/openai/tools.rs @@ -94,12 +94,13 @@ fn build_required_schema(tools: &[ChatCompletionTool]) -> Result NvCreateChatCompletionRequest { let messages = vec![ChatCompletionRequestMessage::User( @@ -81,10 +81,15 @@ fn test_named_tool_choice_parses_json() { let tool_calls = delta.tool_calls.as_ref().unwrap(); // In streaming mode, we emit 2 chunks: first with id/name, second with arguments - assert!(tool_calls.len() >= 1, "Should have at least 1 tool call chunk"); + assert!( + tool_calls.len() >= 1, + "Should have at least 1 tool call chunk" + ); // Find the chunk with the name (first chunk) - let name_chunk = tool_calls.iter().find(|tc| tc.function.as_ref().and_then(|f| f.name.as_ref()).is_some()); + let name_chunk = tool_calls + .iter() + .find(|tc| tc.function.as_ref().and_then(|f| f.name.as_ref()).is_some()); assert!(name_chunk.is_some(), "Should have chunk with name"); let name_chunk = name_chunk.unwrap(); @@ -97,7 +102,8 @@ fn test_named_tool_choice_parses_json() { // Arguments may be in the same chunk or a subsequent one let has_arguments = tool_calls.iter().any(|tc| { - tc.function.as_ref() + tc.function + .as_ref() .and_then(|f| f.arguments.as_ref()) .is_some_and(|args| !args.is_empty()) }); @@ -142,14 +148,16 @@ fn test_required_tool_choice_parses_json_array() { // First tool: arguments chunk assert_eq!(tool_calls[1].index, 0); assert!(tool_calls[1].function.as_ref().unwrap().name.is_none()); - assert!(tool_calls[1] - .function - .as_ref() - .unwrap() - .arguments - .as_ref() - .unwrap() - .contains("rust")); + assert!( + tool_calls[1] + .function + .as_ref() + .unwrap() + .arguments + .as_ref() + .unwrap() + .contains("rust") + ); // Second tool: name chunk assert_eq!(tool_calls[2].index, 1); @@ -162,14 +170,16 @@ fn test_required_tool_choice_parses_json_array() { // Second tool: arguments chunk assert_eq!(tool_calls[3].index, 1); assert!(tool_calls[3].function.as_ref().unwrap().name.is_none()); - assert!(tool_calls[3] - .function - .as_ref() - .unwrap() - .arguments - .as_ref() - .unwrap() - .contains("memory")); + assert!( + tool_calls[3] + .function + .as_ref() + .unwrap() + .arguments + .as_ref() + .unwrap() + .contains("memory") + ); } #[test] @@ -532,11 +542,7 @@ fn test_true_incremental_streaming_parallel() { } } - assert_eq!( - tool_names_seen.len(), - 2, - "Should detect both tool calls" - ); + assert_eq!(tool_names_seen.len(), 2, "Should detect both tool calls"); assert!(tool_names_seen.contains("search")); assert!(tool_names_seen.contains("summarize")); @@ -592,15 +598,20 @@ fn create_chunk( #[tokio::test] async fn test_aggregator_named_tool_accumulates_arguments() { - use dynamo_llm::protocols::openai::chat_completions::aggregator::DeltaAggregator; - use dynamo_llm::protocols::openai::ParsingOptions; use dynamo_llm::protocols::Annotated; + use dynamo_llm::protocols::openai::ParsingOptions; + use dynamo_llm::protocols::openai::chat_completions::aggregator::DeltaAggregator; use futures::stream; // Simulate streaming chunks for named tool choice: get_weather let chunks = vec![ // Chunk 1: role - create_chunk(0, Some(dynamo_async_openai::types::Role::Assistant), None, None), + create_chunk( + 0, + Some(dynamo_async_openai::types::Role::Assistant), + None, + None, + ), // Chunk 2: tool call start (id, type, name, empty arguments) create_chunk( 0, @@ -688,26 +699,30 @@ async fn test_aggregator_named_tool_accumulates_arguments() { assert_eq!(tool_call.function.name, "get_weather"); // THIS IS THE KEY ASSERTION - arguments should be accumulated! assert_eq!( - tool_call.function.arguments, - r#"{"location":"Paris","unit":"celsius"}"#, + tool_call.function.arguments, r#"{"location":"Paris","unit":"celsius"}"#, "Arguments should be fully accumulated from all chunks" ); } #[tokio::test] async fn test_aggregator_required_tool_parallel_calls() { - use dynamo_llm::protocols::openai::chat_completions::aggregator::DeltaAggregator; - use dynamo_llm::protocols::openai::ParsingOptions; - use dynamo_llm::protocols::Annotated; use dynamo_async_openai::types::{ ChatCompletionMessageToolCallChunk, ChatCompletionToolType, FunctionCallStream, }; + use dynamo_llm::protocols::Annotated; + use dynamo_llm::protocols::openai::ParsingOptions; + use dynamo_llm::protocols::openai::chat_completions::aggregator::DeltaAggregator; use futures::stream; // Simulate streaming chunks for required tool choice with parallel calls let chunks = vec![ // Chunk 1: role - create_chunk(0, Some(dynamo_async_openai::types::Role::Assistant), None, None), + create_chunk( + 0, + Some(dynamo_async_openai::types::Role::Assistant), + None, + None, + ), // Chunk 2: first tool call start create_chunk( 0, @@ -825,8 +840,7 @@ async fn test_aggregator_required_tool_parallel_calls() { assert_eq!(tool_call_1.id, "call_1"); assert_eq!(tool_call_1.function.name, "search"); assert_eq!( - tool_call_1.function.arguments, - r#"{"query":"rust"}"#, + tool_call_1.function.arguments, r#"{"query":"rust"}"#, "First tool arguments should be complete" ); @@ -835,8 +849,7 @@ async fn test_aggregator_required_tool_parallel_calls() { assert_eq!(tool_call_2.id, "call_2"); assert_eq!(tool_call_2.function.name, "summarize"); assert_eq!( - tool_call_2.function.arguments, - r#"{"text":"long article"}"#, + tool_call_2.function.arguments, r#"{"text":"long article"}"#, "Second tool arguments should be accumulated from multiple chunks" ); @@ -845,4 +858,3 @@ async fn test_aggregator_required_tool_parallel_calls() { Some(dynamo_async_openai::types::FinishReason::ToolCalls) ); } - diff --git a/lib/llm/tests/tool_choice_finish_reasons.rs b/lib/llm/tests/tool_choice_finish_reasons.rs index 2bfe1dbeb53..99892f781c7 100644 --- a/lib/llm/tests/tool_choice_finish_reasons.rs +++ b/lib/llm/tests/tool_choice_finish_reasons.rs @@ -3,16 +3,15 @@ //! Tests for tool_choice finish_reason handling. +use dynamo_async_openai::types::{ + ChatCompletionNamedToolChoice, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, + ChatCompletionRequestUserMessageContent, ChatCompletionToolChoiceOption, + ChatCompletionToolType, CreateChatCompletionRequest, FunctionName, +}; use dynamo_llm::protocols::common; use dynamo_llm::protocols::common::llm_backend::BackendOutput; -use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionRequest; use dynamo_llm::protocols::openai::DeltaGeneratorExt; -use dynamo_async_openai::types::{ - ChatCompletionNamedToolChoice, ChatCompletionRequestMessage, - ChatCompletionRequestUserMessage, ChatCompletionRequestUserMessageContent, - ChatCompletionToolChoiceOption, ChatCompletionToolType, CreateChatCompletionRequest, - FunctionName, -}; +use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionRequest; fn create_test_request() -> NvCreateChatCompletionRequest { let messages = vec![ChatCompletionRequestMessage::User( @@ -37,10 +36,7 @@ fn create_test_request() -> NvCreateChatCompletionRequest { } } -fn build_backend_output_with_finish( - text: &str, - finish: common::FinishReason, -) -> BackendOutput { +fn build_backend_output_with_finish(text: &str, finish: common::FinishReason) -> BackendOutput { BackendOutput { token_ids: vec![], tokens: vec![], @@ -211,4 +207,3 @@ fn test_required_tool_choice_normal_stop_becomes_tool_calls() { Some(dynamo_async_openai::types::FinishReason::ToolCalls), ); } - From 5ba3d4e32c9ff599548c062883187c1e04fee5dc Mon Sep 17 00:00:00 2001 From: Vladislav Nosivskoy Date: Wed, 3 Dec 2025 19:04:17 +0300 Subject: [PATCH 03/18] revert deprecated field Signed-off-by: Vladislav Nosivskoy --- lib/llm/src/protocols/openai/chat_completions/delta.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/llm/src/protocols/openai/chat_completions/delta.rs b/lib/llm/src/protocols/openai/chat_completions/delta.rs index 7bff099c6eb..a50552f05bb 100644 --- a/lib/llm/src/protocols/openai/chat_completions/delta.rs +++ b/lib/llm/src/protocols/openai/chat_completions/delta.rs @@ -281,6 +281,7 @@ impl DeltaGenerator { ) -> NvCreateChatCompletionStreamResponse { let delta = dynamo_async_openai::types::ChatCompletionStreamResponseDelta { content: text, + function_call: None, tool_calls, role: if self.msg_counter == 0 { Some(dynamo_async_openai::types::Role::Assistant) From 773d86b9d6b6df44e0ac99e6a318488172bf13a7 Mon Sep 17 00:00:00 2001 From: Vladislav Nosivskoy Date: Wed, 3 Dec 2025 19:09:32 +0300 Subject: [PATCH 04/18] replace deprecated annotation Signed-off-by: Vladislav Nosivskoy --- lib/llm/src/protocols/openai/chat_completions/delta.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/llm/src/protocols/openai/chat_completions/delta.rs b/lib/llm/src/protocols/openai/chat_completions/delta.rs index a50552f05bb..407b0e81857 100644 --- a/lib/llm/src/protocols/openai/chat_completions/delta.rs +++ b/lib/llm/src/protocols/openai/chat_completions/delta.rs @@ -259,7 +259,6 @@ impl DeltaGenerator { /// /// # Returns /// * An [`dynamo_async_openai::types::CreateChatCompletionStreamResponse`] instance representing the choice. - #[allow(deprecated)] pub fn create_choice( &mut self, index: u32, @@ -271,6 +270,7 @@ impl DeltaGenerator { } /// Internal method to build a streaming chat completion response with optional tool_calls. + #[allow(deprecated)] fn build_choice( &mut self, index: u32, From 69864826c5378978044ed0298041d7b7e617b24b Mon Sep 17 00:00:00 2001 From: Vladislav Nosivskoy Date: Wed, 3 Dec 2025 19:24:04 +0300 Subject: [PATCH 05/18] fix clippy Signed-off-by: Vladislav Nosivskoy --- lib/llm/tests/tool_choice.rs | 41 +++++++++++++++++------------------- 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/lib/llm/tests/tool_choice.rs b/lib/llm/tests/tool_choice.rs index 5899857e784..722f5dcc193 100644 --- a/lib/llm/tests/tool_choice.rs +++ b/lib/llm/tests/tool_choice.rs @@ -4,16 +4,13 @@ use dynamo_async_openai::types::{ ChatChoiceStream, ChatCompletionMessageToolCallChunk, ChatCompletionNamedToolChoice, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, - ChatCompletionRequestUserMessageContent, ChatCompletionStreamResponseDelta, - ChatCompletionToolChoiceOption, ChatCompletionToolType, CreateChatCompletionRequest, - CreateChatCompletionStreamResponse, FunctionCallStream, FunctionName, + ChatCompletionRequestUserMessageContent, ChatCompletionToolChoiceOption, + ChatCompletionToolType, CreateChatCompletionRequest, FunctionCallStream, FunctionName, }; use dynamo_llm::protocols::common; use dynamo_llm::protocols::common::llm_backend::BackendOutput; use dynamo_llm::protocols::openai::DeltaGeneratorExt; -use dynamo_llm::protocols::openai::chat_completions::{ - NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse, -}; +use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionRequest; fn create_test_request() -> NvCreateChatCompletionRequest { let messages = vec![ChatCompletionRequestMessage::User( @@ -82,7 +79,7 @@ fn test_named_tool_choice_parses_json() { // In streaming mode, we emit 2 chunks: first with id/name, second with arguments assert!( - tool_calls.len() >= 1, + !tool_calls.is_empty(), "Should have at least 1 tool call chunk" ); @@ -215,7 +212,7 @@ fn test_streaming_named_tool_incremental() { // Simulate streaming chunks // For simplicity in testing, send complete JSON in final chunk - let chunks = vec![r#"{"location":"Paris","unit":"celsius"}"#]; + let chunks = [r#"{"location":"Paris","unit":"celsius"}"#]; let mut all_responses = Vec::new(); for (i, chunk) in chunks.iter().enumerate() { @@ -264,7 +261,7 @@ fn test_streaming_required_tool_parallel() { let mut generator = request.response_generator("req-stream-2".to_string()); // Simulate streaming array of tool calls - let chunks = vec![ + let chunks = [ r#"[{"name":"search","parameters":{"query":"rust"}},"#, r#"{"name":"summarize","parameters":{"topic":"memory"}}]"#, ]; @@ -307,14 +304,14 @@ fn test_streaming_required_tool_parallel() { for resp in &all_responses { if let Some(tool_calls) = &resp.choices[0].delta.tool_calls { for tc in tool_calls { - if let Some(func) = &tc.function { - if let Some(name) = &func.name { - if name == "search" { - found_search = true; - } - if name == "summarize" { - found_summarize = true; - } + if let Some(func) = &tc.function + && let Some(name) = &func.name + { + if name == "search" { + found_search = true; + } + if name == "summarize" { + found_summarize = true; } } } @@ -493,7 +490,7 @@ fn test_true_incremental_streaming_parallel() { let mut generator = request.response_generator("req-stream-inc-2".to_string()); // Simulate streaming: array with two tool calls - let chunks = vec![ + let chunks = [ r#"["#, r#"{"name":"search","#, r#""parameters":{"query":"rust"}"#, @@ -533,10 +530,10 @@ fn test_true_incremental_streaming_parallel() { for resp in &responses { if let Some(tool_calls) = &resp.choices[0].delta.tool_calls { for tc in tool_calls { - if let Some(func) = &tc.function { - if let Some(name) = &func.name { - tool_names_seen.insert(name.clone()); - } + if let Some(func) = &tc.function + && let Some(name) = &func.name + { + tool_names_seen.insert(name.clone()); } } } From 941fa997af189b3b3ffb623ce05f39094b2c225e Mon Sep 17 00:00:00 2001 From: Vladislav Nosivskoy Date: Fri, 5 Dec 2025 14:00:43 +0300 Subject: [PATCH 06/18] remove honest fc streaming Signed-off-by: Vladislav Nosivskoy --- .../openai/chat_completions/delta.rs | 182 ++----- lib/llm/tests/tool_choice.rs | 469 ++++-------------- 2 files changed, 128 insertions(+), 523 deletions(-) diff --git a/lib/llm/src/protocols/openai/chat_completions/delta.rs b/lib/llm/src/protocols/openai/chat_completions/delta.rs index 407b0e81857..2a68dfe55c9 100644 --- a/lib/llm/src/protocols/openai/chat_completions/delta.rs +++ b/lib/llm/src/protocols/openai/chat_completions/delta.rs @@ -117,17 +117,6 @@ pub struct DeltaGenerator { options: DeltaGeneratorOptions, /// Buffer for accumulating tool call JSON during streaming tool_call_buffer: String, - /// Length of buffer that was already emitted (for delta calculation) - previous_buffer_len: usize, - /// Previous parsed state (for detecting new tool calls in Required mode) - previous_tool_calls: Vec, -} - -#[derive(Debug, Clone)] -struct ToolCallState { - index: usize, - name: String, - arguments: String, } impl DeltaGenerator { @@ -171,8 +160,6 @@ impl DeltaGenerator { msg_counter: 0, options, tool_call_buffer: String::new(), - previous_buffer_len: 0, - previous_tool_calls: Vec::new(), } } @@ -301,6 +288,8 @@ impl DeltaGenerator { let choices = vec![choice]; + self.msg_counter += 1; + // According to OpenAI spec: when stream_options.include_usage is true, // all intermediate chunks should have usage: null // The final usage chunk will be sent separately with empty choices @@ -344,40 +333,46 @@ impl DeltaGenerator { self.options.enable_usage } - fn process_streaming_tool_calls( + fn process_complete_tool_calls( &mut self, delta_text: &str, - _is_final: bool, - ) -> anyhow::Result> { - // Accumulate the delta into buffer + is_final: bool, + ) -> anyhow::Result>> { self.tool_call_buffer.push_str(delta_text); - // Parse the current buffer state using partial JSON parser + if !is_final { + return Ok(None); + } + let parsed = match loads(&self.tool_call_buffer, AllowPartial::all()) { Ok(value) => value, - Err(_e) => { - // If we can't parse yet, just wait for more data - return Ok(vec![]); + Err(e) => { + tracing::warn!( + error = %e, + buffer = %self.tool_call_buffer, + "failed to parse tool_choice output" + ); + return Ok(None); } }; - // Extract current tool calls from parsed JSON - let current_calls = match &self.options.tool_choice { + let chunks = match &self.options.tool_choice { ToolChoiceContext::Named(name) => { - // For named tool choice, the output is raw JSON parameters - // Use the buffer directly, not serialized if parsed.as_object().is_some() { - vec![ToolCallState { + vec![ChatCompletionMessageToolCallChunk { index: 0, - name: name.clone(), - arguments: self.tool_call_buffer.clone(), + id: Some("call-1".to_string()), + r#type: Some(ChatCompletionToolType::Function), + function: Some(FunctionCallStream { + name: Some(name.clone()), + arguments: Some(self.tool_call_buffer.clone()), + }), }] } else { vec![] } } ToolChoiceContext::Required => { - // For required, parse the array of tool calls if let Some(array) = parsed.as_array() { array .iter() @@ -385,11 +380,15 @@ impl DeltaGenerator { .filter_map(|(idx, entry)| { let name = entry.get("name")?.as_str()?.to_string(); let parameters = entry.get("parameters")?; - let args = serde_json::to_string(parameters).unwrap_or_default(); - Some(ToolCallState { - index: idx, - name, - arguments: args, + let args = serde_json::to_string(parameters).ok()?; + Some(ChatCompletionMessageToolCallChunk { + index: idx as u32, + id: Some(format!("call-{}", idx + 1)), + r#type: Some(ChatCompletionToolType::Function), + function: Some(FunctionCallStream { + name: Some(name), + arguments: Some(args), + }), }) }) .collect() @@ -400,106 +399,7 @@ impl DeltaGenerator { ToolChoiceContext::None => vec![], }; - // Generate deltas by comparing with previous state - let mut chunks = Vec::new(); - - for current in ¤t_calls { - // Check if this is a new tool or existing one - let previous = self - .previous_tool_calls - .iter() - .find(|p| p.index == current.index); - - match previous { - None => { - // New tool - emit first chunk with id, type, name - chunks.push(ChatCompletionMessageToolCallChunk { - index: current.index as u32, - id: Some(format!("call_{}", current.index + 1)), - r#type: Some(ChatCompletionToolType::Function), - function: Some(FunctionCallStream { - name: Some(current.name.clone()), - arguments: Some(String::new()), - }), - }); - - // For Named mode, emit delta from buffer - // For Required mode, emit serialized parameters - if matches!(self.options.tool_choice, ToolChoiceContext::Named(_)) { - // Use raw buffer delta for Named mode - if self.tool_call_buffer.len() > self.previous_buffer_len { - let delta = &self.tool_call_buffer[self.previous_buffer_len..]; - if !delta.is_empty() { - chunks.push(ChatCompletionMessageToolCallChunk { - index: current.index as u32, - id: None, - r#type: None, - function: Some(FunctionCallStream { - name: None, - arguments: Some(delta.to_string()), - }), - }); - } - } - } else { - // For Required mode, emit full arguments (serialized parameters) - if !current.arguments.is_empty() { - chunks.push(ChatCompletionMessageToolCallChunk { - index: current.index as u32, - id: None, - r#type: None, - function: Some(FunctionCallStream { - name: None, - arguments: Some(current.arguments.clone()), - }), - }); - } - } - } - Some(prev) => { - // Existing tool - emit delta of arguments - if matches!(self.options.tool_choice, ToolChoiceContext::Named(_)) { - // For Named mode, use raw buffer delta - if self.tool_call_buffer.len() > self.previous_buffer_len { - let delta = &self.tool_call_buffer[self.previous_buffer_len..]; - if !delta.is_empty() { - chunks.push(ChatCompletionMessageToolCallChunk { - index: current.index as u32, - id: None, - r#type: None, - function: Some(FunctionCallStream { - name: None, - arguments: Some(delta.to_string()), - }), - }); - } - } - } else { - // For Required mode, compute delta from serialized arguments - if current.arguments.len() > prev.arguments.len() { - let delta = ¤t.arguments[prev.arguments.len()..]; - if !delta.is_empty() { - chunks.push(ChatCompletionMessageToolCallChunk { - index: current.index as u32, - id: None, - r#type: None, - function: Some(FunctionCallStream { - name: None, - arguments: Some(delta.to_string()), - }), - }); - } - } - } - } - } - } - - // Update previous state - self.previous_tool_calls = current_calls; - self.previous_buffer_len = self.tool_call_buffer.len(); - - Ok(chunks) + Ok(Some(chunks)) } fn determine_streaming_finish_reason( @@ -616,31 +516,29 @@ impl crate::protocols::openai::DeltaGeneratorExt { + match self.process_complete_tool_calls(raw, is_final) { + Ok(Some(chunks)) if !chunks.is_empty() => { tool_call_chunks = Some(chunks); - delta_text = None; // Don't emit raw text when streaming tools + delta_text = None; } Ok(_) => { - // No chunks yet, suppress text output delta_text = None; } Err(err) => { error!( error = %err, - "failed to parse streaming tool_choice output" + "failed to parse tool_choice output" ); + delta_text = None; } } } - // Override finish reason for tool_choice modes if finish_reason.is_some() { finish_reason = self.determine_streaming_finish_reason(backend_finish_reason); } diff --git a/lib/llm/tests/tool_choice.rs b/lib/llm/tests/tool_choice.rs index 722f5dcc193..71d947f02ae 100644 --- a/lib/llm/tests/tool_choice.rs +++ b/lib/llm/tests/tool_choice.rs @@ -77,34 +77,20 @@ fn test_named_tool_choice_parses_json() { assert!(delta.content.is_none()); let tool_calls = delta.tool_calls.as_ref().unwrap(); - // In streaming mode, we emit 2 chunks: first with id/name, second with arguments - assert!( - !tool_calls.is_empty(), - "Should have at least 1 tool call chunk" - ); - - // Find the chunk with the name (first chunk) - let name_chunk = tool_calls - .iter() - .find(|tc| tc.function.as_ref().and_then(|f| f.name.as_ref()).is_some()); - assert!(name_chunk.is_some(), "Should have chunk with name"); - let name_chunk = name_chunk.unwrap(); + assert_eq!(tool_calls.len(), 1); - assert_eq!(name_chunk.index, 0); - assert_eq!(name_chunk.id.as_deref(), Some("call_1")); + let tool_call = &tool_calls[0]; + assert_eq!(tool_call.index, 0); + assert_eq!(tool_call.id.as_deref(), Some("call-1")); + assert_eq!(tool_call.r#type, Some(ChatCompletionToolType::Function)); assert_eq!( - name_chunk.function.as_ref().unwrap().name.as_deref(), + tool_call.function.as_ref().unwrap().name.as_deref(), Some("get_weather") ); - - // Arguments may be in the same chunk or a subsequent one - let has_arguments = tool_calls.iter().any(|tc| { - tc.function - .as_ref() - .and_then(|f| f.arguments.as_ref()) - .is_some_and(|args| !args.is_empty()) - }); - assert!(has_arguments, "Should have arguments in some chunk"); + assert_eq!( + tool_call.function.as_ref().unwrap().arguments.as_deref(), + Some(r#"{"location":"Paris"}"#) + ); } #[test] @@ -130,52 +116,30 @@ fn test_required_tool_choice_parses_json_array() { assert!(delta.content.is_none()); let tool_calls = delta.tool_calls.as_ref().unwrap(); - // With incremental streaming, we emit separate chunks for name and arguments - // Expected: 4 chunks total (2 per tool: name chunk + arguments chunk) - assert_eq!(tool_calls.len(), 4); + assert_eq!(tool_calls.len(), 2); - // First tool: name chunk assert_eq!(tool_calls[0].index, 0); + assert_eq!(tool_calls[0].id.as_deref(), Some("call-1")); + assert_eq!(tool_calls[0].r#type, Some(ChatCompletionToolType::Function)); assert_eq!( tool_calls[0].function.as_ref().unwrap().name.as_deref(), Some("search") ); - assert!(tool_calls[0].id.is_some()); - - // First tool: arguments chunk - assert_eq!(tool_calls[1].index, 0); - assert!(tool_calls[1].function.as_ref().unwrap().name.is_none()); - assert!( - tool_calls[1] - .function - .as_ref() - .unwrap() - .arguments - .as_ref() - .unwrap() - .contains("rust") + assert_eq!( + tool_calls[0].function.as_ref().unwrap().arguments.as_deref(), + Some(r#"{"query":"rust"}"#) ); - // Second tool: name chunk - assert_eq!(tool_calls[2].index, 1); + assert_eq!(tool_calls[1].index, 1); + assert_eq!(tool_calls[1].id.as_deref(), Some("call-2")); + assert_eq!(tool_calls[1].r#type, Some(ChatCompletionToolType::Function)); assert_eq!( - tool_calls[2].function.as_ref().unwrap().name.as_deref(), + tool_calls[1].function.as_ref().unwrap().name.as_deref(), Some("summarize") ); - assert!(tool_calls[2].id.is_some()); - - // Second tool: arguments chunk - assert_eq!(tool_calls[3].index, 1); - assert!(tool_calls[3].function.as_ref().unwrap().name.is_none()); - assert!( - tool_calls[3] - .function - .as_ref() - .unwrap() - .arguments - .as_ref() - .unwrap() - .contains("memory") + assert_eq!( + tool_calls[1].function.as_ref().unwrap().arguments.as_deref(), + Some(r#"{"topic":"memory"}"#) ); } @@ -191,13 +155,12 @@ fn test_tool_choice_parse_failure_suppresses_text() { .expect("choice generation"); let delta = &response.choices[0].delta; - // When tool_choice is active but parsing fails, we suppress the text output assert!(delta.content.is_none()); assert!(delta.tool_calls.is_none()); } #[test] -fn test_streaming_named_tool_incremental() { +fn test_streaming_named_tool_buffers_until_finish() { let mut request = create_test_request(); request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Named( ChatCompletionNamedToolChoice { @@ -210,9 +173,11 @@ fn test_streaming_named_tool_incremental() { let mut generator = request.response_generator("req-stream-1".to_string()); - // Simulate streaming chunks - // For simplicity in testing, send complete JSON in final chunk - let chunks = [r#"{"location":"Paris","unit":"celsius"}"#]; + let chunks = [ + r#"{"location":""#, + r#"Paris","unit":""#, + r#"celsius"}"#, + ]; let mut all_responses = Vec::new(); for (i, chunk) in chunks.iter().enumerate() { @@ -239,18 +204,23 @@ fn test_streaming_named_tool_incremental() { all_responses.push(response); } - // Last response should have finish_reason + for i in 0..all_responses.len() - 1 { + assert!(all_responses[i].choices[0].delta.tool_calls.is_none()); + } + let last_response = all_responses.last().unwrap(); assert_eq!( last_response.choices[0].finish_reason, Some(dynamo_async_openai::types::FinishReason::Stop) ); - // Should have tool_calls somewhere in the stream - let has_tool_calls = all_responses - .iter() - .any(|r| r.choices[0].delta.tool_calls.is_some()); - assert!(has_tool_calls, "No tool calls found in any response"); + let tool_calls = last_response.choices[0].delta.tool_calls.as_ref().unwrap(); + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].function.as_ref().unwrap().name.as_deref(), Some("get_weather")); + assert_eq!( + tool_calls[0].function.as_ref().unwrap().arguments.as_deref(), + Some(r#"{"location":"Paris","unit":"celsius"}"#) + ); } #[test] @@ -260,7 +230,6 @@ fn test_streaming_required_tool_parallel() { let mut generator = request.response_generator("req-stream-2".to_string()); - // Simulate streaming array of tool calls let chunks = [ r#"[{"name":"search","parameters":{"query":"rust"}},"#, r#"{"name":"summarize","parameters":{"topic":"memory"}}]"#, @@ -291,38 +260,28 @@ fn test_streaming_required_tool_parallel() { all_responses.push(response); } - // Final chunk should have finish_reason = ToolCalls + for i in 0..all_responses.len() - 1 { + assert!(all_responses[i].choices[0].delta.tool_calls.is_none()); + } + let last_response = all_responses.last().unwrap(); assert_eq!( last_response.choices[0].finish_reason, Some(dynamo_async_openai::types::FinishReason::ToolCalls) ); - // Should have detected both tools - let mut found_search = false; - let mut found_summarize = false; - for resp in &all_responses { - if let Some(tool_calls) = &resp.choices[0].delta.tool_calls { - for tc in tool_calls { - if let Some(func) = &tc.function - && let Some(name) = &func.name - { - if name == "search" { - found_search = true; - } - if name == "summarize" { - found_summarize = true; - } - } - } - } - } - assert!(found_search, "Should detect search tool"); - assert!(found_summarize, "Should detect summarize tool"); + let tool_calls = last_response.choices[0].delta.tool_calls.as_ref().unwrap(); + assert_eq!(tool_calls.len(), 2); + + assert_eq!(tool_calls[0].function.as_ref().unwrap().name.as_deref(), Some("search")); + assert_eq!(tool_calls[0].function.as_ref().unwrap().arguments.as_deref(), Some(r#"{"query":"rust"}"#)); + + assert_eq!(tool_calls[1].function.as_ref().unwrap().name.as_deref(), Some("summarize")); + assert_eq!(tool_calls[1].function.as_ref().unwrap().arguments.as_deref(), Some(r#"{"topic":"memory"}"#)); } #[test] -fn test_streaming_with_incremental_arguments() { +fn test_streaming_buffers_until_finish() { let mut request = create_test_request(); request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Named( ChatCompletionNamedToolChoice { @@ -335,11 +294,11 @@ fn test_streaming_with_incremental_arguments() { let mut generator = request.response_generator("req-stream-3".to_string()); - // Character-by-character streaming let full_json = r#"{"query":"rust programming"}"#; let mut responses = Vec::new(); - for ch in full_json.chars() { + for (i, ch) in full_json.chars().enumerate() { + let is_last = i == full_json.len() - 1; let backend_output = BackendOutput { token_ids: vec![], tokens: vec![], @@ -347,7 +306,11 @@ fn test_streaming_with_incremental_arguments() { cum_log_probs: None, log_probs: None, top_logprobs: None, - finish_reason: None, + finish_reason: if is_last { + Some(common::FinishReason::Stop) + } else { + None + }, index: Some(0), completion_usage: None, disaggregated_params: None, @@ -359,16 +322,26 @@ fn test_streaming_with_incremental_arguments() { responses.push(response); } - // Should have suppressed raw text output for resp in &responses { assert!(resp.choices[0].delta.content.is_none()); } + + for i in 0..responses.len() - 1 { + assert!(responses[i].choices[0].delta.tool_calls.is_none()); + } + + let last = responses.last().unwrap(); + assert!(last.choices[0].delta.tool_calls.is_some()); + assert_eq!( + last.choices[0].delta.tool_calls.as_ref().unwrap()[0] + .function.as_ref().unwrap().arguments.as_deref(), + Some(r#"{"query":"rust programming"}"#) + ); } #[test] -fn test_no_streaming_when_tool_choice_none() { +fn test_no_tool_choice_outputs_normal_text() { let request = create_test_request(); - // tool_choice = None (default) let mut generator = request.response_generator("req-stream-4".to_string()); @@ -389,7 +362,6 @@ fn test_no_streaming_when_tool_choice_none() { .choice_from_postprocessor(backend_output) .expect("normal text"); - // Should have text content, not tool_calls assert_eq!( response.choices[0].delta.content.as_deref(), Some("Hello world") @@ -397,165 +369,6 @@ fn test_no_streaming_when_tool_choice_none() { assert!(response.choices[0].delta.tool_calls.is_none()); } -#[test] -fn test_true_incremental_streaming_named() { - let mut request = create_test_request(); - request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Named( - ChatCompletionNamedToolChoice { - r#type: ChatCompletionToolType::Function, - function: FunctionName { - name: "get_weather".to_string(), - }, - }, - )); - - let mut generator = request.response_generator("req-stream-inc-1".to_string()); - - // Simulate realistic token-by-token streaming - let chunks = vec![ - r#"{"#, - r#""location""#, - r#":"#, - r#""Paris""#, - r#","#, - r#""unit""#, - r#":"#, - r#""celsius""#, - r#"}"#, - ]; - - let mut responses = Vec::new(); - for (i, chunk) in chunks.iter().enumerate() { - let backend_output = BackendOutput { - token_ids: vec![], - tokens: vec![], - text: Some(chunk.to_string()), - cum_log_probs: None, - log_probs: None, - top_logprobs: None, - finish_reason: if i == chunks.len() - 1 { - Some(common::FinishReason::Stop) - } else { - None - }, - index: Some(0), - completion_usage: None, - disaggregated_params: None, - }; - - let response = generator - .choice_from_postprocessor(backend_output) - .expect("chunk"); - responses.push(response); - } - - // Should have emitted tool_calls in one of the early chunks - let first_tool_call_idx = responses - .iter() - .position(|r| r.choices[0].delta.tool_calls.is_some()) - .expect("Should find tool_calls in stream"); - - // First tool call should have id, type, name - let first_tc = &responses[first_tool_call_idx].choices[0] - .delta - .tool_calls - .as_ref() - .unwrap()[0]; - assert!(first_tc.id.is_some()); - assert_eq!(first_tc.r#type, Some(ChatCompletionToolType::Function)); - assert_eq!( - first_tc.function.as_ref().unwrap().name.as_deref(), - Some("get_weather") - ); - - // Should have multiple chunks with arguments deltas - let args_chunks: Vec<_> = responses - .iter() - .filter_map(|r| r.choices[0].delta.tool_calls.as_ref()) - .flat_map(|tcs| tcs.iter()) - .filter_map(|tc| tc.function.as_ref()?.arguments.as_ref()) - .collect(); - - assert!( - args_chunks.len() > 1, - "Should have multiple argument delta chunks" - ); -} - -#[test] -fn test_true_incremental_streaming_parallel() { - let mut request = create_test_request(); - request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Required); - - let mut generator = request.response_generator("req-stream-inc-2".to_string()); - - // Simulate streaming: array with two tool calls - let chunks = [ - r#"["#, - r#"{"name":"search","#, - r#""parameters":{"query":"rust"}"#, - r#"},"#, - r#"{"name":"summarize","#, - r#""parameters":{"topic":"memory"}"#, - r#"}]"#, - ]; - - let mut responses = Vec::new(); - for (i, chunk) in chunks.iter().enumerate() { - let backend_output = BackendOutput { - token_ids: vec![], - tokens: vec![], - text: Some(chunk.to_string()), - cum_log_probs: None, - log_probs: None, - top_logprobs: None, - finish_reason: if i == chunks.len() - 1 { - Some(common::FinishReason::Stop) - } else { - None - }, - index: Some(0), - completion_usage: None, - disaggregated_params: None, - }; - - let response = generator - .choice_from_postprocessor(backend_output) - .expect("chunk"); - responses.push(response); - } - - // Count tool call initializations (first chunks with names) - let mut tool_names_seen = std::collections::HashSet::new(); - for resp in &responses { - if let Some(tool_calls) = &resp.choices[0].delta.tool_calls { - for tc in tool_calls { - if let Some(func) = &tc.function - && let Some(name) = &func.name - { - tool_names_seen.insert(name.clone()); - } - } - } - } - - assert_eq!(tool_names_seen.len(), 2, "Should detect both tool calls"); - assert!(tool_names_seen.contains("search")); - assert!(tool_names_seen.contains("summarize")); - - // Verify that tool calls are streamed incrementally, not just at the end - let chunks_with_tool_calls: Vec<_> = responses - .iter() - .enumerate() - .filter(|(_, r)| r.choices[0].delta.tool_calls.is_some()) - .map(|(i, _)| i) - .collect(); - - assert!( - chunks_with_tool_calls.len() > 1, - "Should have multiple chunks with tool_calls (not just final)" - ); -} /// Helper function to create a streaming chunk fn create_chunk( @@ -594,67 +407,33 @@ fn create_chunk( } #[tokio::test] -async fn test_aggregator_named_tool_accumulates_arguments() { +async fn test_aggregator_named_tool() { use dynamo_llm::protocols::Annotated; use dynamo_llm::protocols::openai::ParsingOptions; use dynamo_llm::protocols::openai::chat_completions::aggregator::DeltaAggregator; use futures::stream; - // Simulate streaming chunks for named tool choice: get_weather let chunks = vec![ - // Chunk 1: role create_chunk( 0, Some(dynamo_async_openai::types::Role::Assistant), None, None, ), - // Chunk 2: tool call start (id, type, name, empty arguments) create_chunk( 0, None, Some(ChatCompletionMessageToolCallChunk { index: 0, - id: Some("call_1".to_string()), + id: Some("call-1".to_string()), r#type: Some(ChatCompletionToolType::Function), function: Some(FunctionCallStream { name: Some("get_weather".to_string()), - arguments: Some(String::new()), + arguments: Some(r#"{"location":"Paris","unit":"celsius"}"#.to_string()), }), }), None, ), - // Chunk 3: first part of arguments (raw JSON fragment from buffer) - create_chunk( - 0, - None, - Some(ChatCompletionMessageToolCallChunk { - index: 0, - id: None, - r#type: None, - function: Some(FunctionCallStream { - name: None, - arguments: Some(r#"{"location":"Paris""#.to_string()), - }), - }), - None, - ), - // Chunk 4: second part of arguments (continuation) - create_chunk( - 0, - None, - Some(ChatCompletionMessageToolCallChunk { - index: 0, - id: None, - r#type: None, - function: Some(FunctionCallStream { - name: None, - arguments: Some(r#","unit":"celsius"}"#.to_string()), - }), - }), - None, - ), - // Chunk 5: finish create_chunk( 0, None, @@ -663,7 +442,6 @@ async fn test_aggregator_named_tool_accumulates_arguments() { ), ]; - // Convert to Annotated stream let annotated_chunks: Vec> = chunks .into_iter() .map(|chunk| Annotated { @@ -675,29 +453,23 @@ async fn test_aggregator_named_tool_accumulates_arguments() { .collect(); let stream = Box::pin(stream::iter(annotated_chunks)); - - // Aggregate the stream let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await; assert!(result.is_ok()); let response = result.unwrap(); - // Verify aggregated response assert_eq!(response.choices.len(), 1); let choice = &response.choices[0]; - // Check tool calls assert!(choice.message.tool_calls.is_some()); let tool_calls = choice.message.tool_calls.as_ref().unwrap(); assert_eq!(tool_calls.len(), 1); let tool_call = &tool_calls[0]; - assert_eq!(tool_call.id, "call_1"); + assert_eq!(tool_call.id, "call-1"); assert_eq!(tool_call.function.name, "get_weather"); - // THIS IS THE KEY ASSERTION - arguments should be accumulated! assert_eq!( - tool_call.function.arguments, r#"{"location":"Paris","unit":"celsius"}"#, - "Arguments should be fully accumulated from all chunks" + tool_call.function.arguments, r#"{"location":"Paris","unit":"celsius"}"# ); } @@ -711,91 +483,41 @@ async fn test_aggregator_required_tool_parallel_calls() { use dynamo_llm::protocols::openai::chat_completions::aggregator::DeltaAggregator; use futures::stream; - // Simulate streaming chunks for required tool choice with parallel calls let chunks = vec![ - // Chunk 1: role create_chunk( 0, Some(dynamo_async_openai::types::Role::Assistant), None, None, ), - // Chunk 2: first tool call start create_chunk( 0, None, Some(ChatCompletionMessageToolCallChunk { index: 0, - id: Some("call_1".to_string()), + id: Some("call-1".to_string()), r#type: Some(ChatCompletionToolType::Function), function: Some(FunctionCallStream { name: Some("search".to_string()), - arguments: Some(String::new()), - }), - }), - None, - ), - // Chunk 3: first tool arguments - create_chunk( - 0, - None, - Some(ChatCompletionMessageToolCallChunk { - index: 0, - id: None, - r#type: None, - function: Some(FunctionCallStream { - name: None, arguments: Some(r#"{"query":"rust"}"#.to_string()), }), }), None, ), - // Chunk 4: second tool call start create_chunk( 0, None, Some(ChatCompletionMessageToolCallChunk { index: 1, - id: Some("call_2".to_string()), + id: Some("call-2".to_string()), r#type: Some(ChatCompletionToolType::Function), function: Some(FunctionCallStream { name: Some("summarize".to_string()), - arguments: Some(String::new()), + arguments: Some(r#"{"text":"long article"}"#.to_string()), }), }), None, ), - // Chunk 5: second tool arguments (partial) - create_chunk( - 0, - None, - Some(ChatCompletionMessageToolCallChunk { - index: 1, - id: None, - r#type: None, - function: Some(FunctionCallStream { - name: None, - arguments: Some(r#"{"text":"#.to_string()), - }), - }), - None, - ), - // Chunk 6: second tool arguments (rest) - create_chunk( - 0, - None, - Some(ChatCompletionMessageToolCallChunk { - index: 1, - id: None, - r#type: None, - function: Some(FunctionCallStream { - name: None, - arguments: Some(r#""long article"}"#.to_string()), - }), - }), - None, - ), - // Chunk 7: finish create_chunk( 0, None, @@ -804,7 +526,6 @@ async fn test_aggregator_required_tool_parallel_calls() { ), ]; - // Convert to Annotated stream let annotated_chunks: Vec> = chunks .into_iter() .map(|chunk| Annotated { @@ -816,39 +537,25 @@ async fn test_aggregator_required_tool_parallel_calls() { .collect(); let stream = Box::pin(stream::iter(annotated_chunks)); - - // Aggregate the stream let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await; assert!(result.is_ok()); let response = result.unwrap(); - // Verify aggregated response assert_eq!(response.choices.len(), 1); let choice = &response.choices[0]; - // Check tool calls assert!(choice.message.tool_calls.is_some()); let tool_calls = choice.message.tool_calls.as_ref().unwrap(); - assert_eq!(tool_calls.len(), 2, "Should have 2 tool calls"); + assert_eq!(tool_calls.len(), 2); - // Verify first tool call - let tool_call_1 = &tool_calls[0]; - assert_eq!(tool_call_1.id, "call_1"); - assert_eq!(tool_call_1.function.name, "search"); - assert_eq!( - tool_call_1.function.arguments, r#"{"query":"rust"}"#, - "First tool arguments should be complete" - ); + assert_eq!(tool_calls[0].id, "call-1"); + assert_eq!(tool_calls[0].function.name, "search"); + assert_eq!(tool_calls[0].function.arguments, r#"{"query":"rust"}"#); - // Verify second tool call - THIS IS THE CRITICAL TEST - let tool_call_2 = &tool_calls[1]; - assert_eq!(tool_call_2.id, "call_2"); - assert_eq!(tool_call_2.function.name, "summarize"); - assert_eq!( - tool_call_2.function.arguments, r#"{"text":"long article"}"#, - "Second tool arguments should be accumulated from multiple chunks" - ); + assert_eq!(tool_calls[1].id, "call-2"); + assert_eq!(tool_calls[1].function.name, "summarize"); + assert_eq!(tool_calls[1].function.arguments, r#"{"text":"long article"}"#); assert_eq!( choice.finish_reason, From 33424605a661979f3619ad37284130e1def25660 Mon Sep 17 00:00:00 2001 From: Vladislav Nosivskoy Date: Fri, 5 Dec 2025 15:01:20 +0300 Subject: [PATCH 07/18] refactor on jail stream Signed-off-by: Vladislav Nosivskoy --- lib/llm/src/preprocessor.rs | 39 +- .../openai/chat_completions/delta.rs | 173 +-------- .../protocols/openai/chat_completions/jail.rs | 359 ++++++++++++++---- lib/llm/tests/test_reasoning_parser.rs | 6 +- lib/llm/tests/test_streaming_tool_parsers.rs | 3 +- lib/llm/tests/tool_choice.rs | 188 ++++++--- lib/llm/tests/tool_choice_finish_reasons.rs | 57 ++- 7 files changed, 501 insertions(+), 324 deletions(-) diff --git a/lib/llm/src/preprocessor.rs b/lib/llm/src/preprocessor.rs index 8dd3dcb947d..73c23236609 100644 --- a/lib/llm/src/preprocessor.rs +++ b/lib/llm/src/preprocessor.rs @@ -786,15 +786,36 @@ impl OpenAIPreprocessor { /// Apply tool calling jail to the stream if needed pub fn apply_tool_calling_jail( - tool_call_parser: String, + tool_call_parser: Option, + tool_choice: Option, stream: S, ) -> impl Stream> + Send where S: Stream> + Send + 'static, { - let jail = JailedStream::builder() - .tool_call_parser(tool_call_parser) - .build(); + use dynamo_async_openai::types::ChatCompletionToolChoiceOption; + + let mut builder = JailedStream::builder(); + + // Configure jail based on tool_choice + match tool_choice { + Some(ChatCompletionToolChoiceOption::Named(named)) => { + // Immediate jail mode for named tool choice + builder = builder.tool_choice_named(named.function.name.clone()); + } + Some(ChatCompletionToolChoiceOption::Required) => { + // Immediate jail mode for required tool choice + builder = builder.tool_choice_required(); + } + Some(ChatCompletionToolChoiceOption::Auto) | Some(ChatCompletionToolChoiceOption::None) | None => { + // Traditional marker-based jail for auto/none/unspecified + if let Some(parser) = tool_call_parser { + builder = builder.tool_call_parser(parser); + } + } + } + + let jail = builder.build(); jail.apply_with_finish_reason(stream) } @@ -957,11 +978,11 @@ impl // Apply jail conditionally let transformed_stream: Pin + Send>> = if should_jail { - if let Some(parser) = self.tool_call_parser.clone() { - Box::pin(Self::apply_tool_calling_jail(parser, stream)) - } else { - Box::pin(stream) // Should not happen due to should_jail check - } + Box::pin(Self::apply_tool_calling_jail( + self.tool_call_parser.clone(), + request.inner.tool_choice.clone(), + stream, + )) } else { Box::pin(stream) }; diff --git a/lib/llm/src/protocols/openai/chat_completions/delta.rs b/lib/llm/src/protocols/openai/chat_completions/delta.rs index 2a68dfe55c9..e2ebd93913e 100644 --- a/lib/llm/src/protocols/openai/chat_completions/delta.rs +++ b/lib/llm/src/protocols/openai/chat_completions/delta.rs @@ -6,16 +6,9 @@ use crate::{ local_model::runtime_config::ModelRuntimeConfig, protocols::{ common::{self}, - openai::partial_json::{AllowPartial, loads}, }, types::TokenIdType, }; -use dynamo_async_openai::types::{ - ChatCompletionMessageToolCallChunk, ChatCompletionToolChoiceOption, ChatCompletionToolType, - FunctionCallStream, -}; -use serde_json::{self}; -use tracing::error; /// Provides a method for generating a [`DeltaGenerator`] from a chat completion request. impl NvCreateChatCompletionRequest { @@ -50,14 +43,6 @@ impl NvCreateChatCompletionRequest { /// # Returns /// * [`DeltaGenerator`] configured with model name and response options. pub fn response_generator(&self, request_id: String) -> DeltaGenerator { - let tool_choice_context = match self.inner.tool_choice.as_ref() { - Some(ChatCompletionToolChoiceOption::Named(named)) => { - ToolChoiceContext::Named(named.function.name.clone()) - } - Some(ChatCompletionToolChoiceOption::Required) => ToolChoiceContext::Required, - _ => ToolChoiceContext::None, - }; - let options = DeltaGeneratorOptions { enable_usage: self .inner @@ -68,7 +53,6 @@ impl NvCreateChatCompletionRequest { enable_logprobs: self.inner.logprobs.unwrap_or(false) || self.inner.top_logprobs.unwrap_or(0) > 0, runtime_config: ModelRuntimeConfig::default(), - tool_choice: tool_choice_context, }; DeltaGenerator::new(self.inner.model.clone(), options, request_id) @@ -84,15 +68,6 @@ pub struct DeltaGeneratorOptions { pub enable_logprobs: bool, pub runtime_config: ModelRuntimeConfig, - pub tool_choice: ToolChoiceContext, -} - -#[derive(Debug, Clone, Default)] -pub enum ToolChoiceContext { - #[default] - None, - Named(String), - Required, } /// Generates incremental chat completion responses in a streaming fashion. @@ -115,8 +90,6 @@ pub struct DeltaGenerator { msg_counter: u64, /// Configuration options for response generation. options: DeltaGeneratorOptions, - /// Buffer for accumulating tool call JSON during streaming - tool_call_buffer: String, } impl DeltaGenerator { @@ -159,7 +132,6 @@ impl DeltaGenerator { usage, msg_counter: 0, options, - tool_call_buffer: String::new(), } } @@ -264,7 +236,7 @@ impl DeltaGenerator { text: Option, finish_reason: Option, logprobs: Option, - tool_calls: Option>, + tool_calls: Option>, ) -> NvCreateChatCompletionStreamResponse { let delta = dynamo_async_openai::types::ChatCompletionStreamResponseDelta { content: text, @@ -333,113 +305,6 @@ impl DeltaGenerator { self.options.enable_usage } - fn process_complete_tool_calls( - &mut self, - delta_text: &str, - is_final: bool, - ) -> anyhow::Result>> { - self.tool_call_buffer.push_str(delta_text); - - if !is_final { - return Ok(None); - } - - let parsed = match loads(&self.tool_call_buffer, AllowPartial::all()) { - Ok(value) => value, - Err(e) => { - tracing::warn!( - error = %e, - buffer = %self.tool_call_buffer, - "failed to parse tool_choice output" - ); - return Ok(None); - } - }; - - let chunks = match &self.options.tool_choice { - ToolChoiceContext::Named(name) => { - if parsed.as_object().is_some() { - vec![ChatCompletionMessageToolCallChunk { - index: 0, - id: Some("call-1".to_string()), - r#type: Some(ChatCompletionToolType::Function), - function: Some(FunctionCallStream { - name: Some(name.clone()), - arguments: Some(self.tool_call_buffer.clone()), - }), - }] - } else { - vec![] - } - } - ToolChoiceContext::Required => { - if let Some(array) = parsed.as_array() { - array - .iter() - .enumerate() - .filter_map(|(idx, entry)| { - let name = entry.get("name")?.as_str()?.to_string(); - let parameters = entry.get("parameters")?; - let args = serde_json::to_string(parameters).ok()?; - Some(ChatCompletionMessageToolCallChunk { - index: idx as u32, - id: Some(format!("call-{}", idx + 1)), - r#type: Some(ChatCompletionToolType::Function), - function: Some(FunctionCallStream { - name: Some(name), - arguments: Some(args), - }), - }) - }) - .collect() - } else { - vec![] - } - } - ToolChoiceContext::None => vec![], - }; - - Ok(Some(chunks)) - } - - fn determine_streaming_finish_reason( - &self, - backend_finish: Option, - ) -> Option { - backend_finish.as_ref()?; - - // For critical/error finish reasons, preserve them regardless of tool_choice mode - match backend_finish { - Some(common::FinishReason::Length) => { - return Some(dynamo_async_openai::types::FinishReason::Length); - } - Some(common::FinishReason::ContentFilter) => { - return Some(dynamo_async_openai::types::FinishReason::ContentFilter); - } - _ => {} - } - - // For normal finish reasons (Stop/EoS/Cancelled), apply tool_choice semantics - match &self.options.tool_choice { - ToolChoiceContext::None => match backend_finish { - Some(common::FinishReason::EoS) | Some(common::FinishReason::Stop) => { - Some(dynamo_async_openai::types::FinishReason::Stop) - } - Some(common::FinishReason::Cancelled) => { - Some(dynamo_async_openai::types::FinishReason::Stop) - } - _ => None, - }, - ToolChoiceContext::Named(_) => { - // Named tool choice finishes with "stop" for normal completion - Some(dynamo_async_openai::types::FinishReason::Stop) - } - ToolChoiceContext::Required => { - // Required tool choice finishes with "tool_calls" for normal completion - Some(dynamo_async_openai::types::FinishReason::ToolCalls) - } - } - } } /// Implements the [`crate::protocols::openai::DeltaGeneratorExt`] trait for [`DeltaGenerator`], allowing @@ -490,8 +355,7 @@ impl crate::protocols::openai::DeltaGeneratorExt Some(dynamo_async_openai::types::FinishReason::Stop), Some(common::FinishReason::Stop) => { Some(dynamo_async_openai::types::FinishReason::Stop) @@ -513,39 +377,10 @@ impl crate::protocols::openai::DeltaGeneratorExt { - tool_call_chunks = Some(chunks); - delta_text = None; - } - Ok(_) => { - delta_text = None; - } - Err(err) => { - error!( - error = %err, - "failed to parse tool_choice output" - ); - delta_text = None; - } - } - } - - if finish_reason.is_some() { - finish_reason = self.determine_streaming_finish_reason(backend_finish_reason); - } - } + let delta_text = delta.text; let mut stream_response = - self.build_choice(index, delta_text, finish_reason, logprobs, tool_call_chunks); + self.build_choice(index, delta_text, finish_reason, logprobs, None); // Extract worker_id from disaggregated_params and inject into nvext if present if let Some(worker_id_json) = delta diff --git a/lib/llm/src/protocols/openai/chat_completions/jail.rs b/lib/llm/src/protocols/openai/chat_completions/jail.rs index 7a57b950b90..eeaa635f7ee 100644 --- a/lib/llm/src/protocols/openai/chat_completions/jail.rs +++ b/lib/llm/src/protocols/openai/chat_completions/jail.rs @@ -62,6 +62,24 @@ pub struct JailConfig<'a> { pub tool_call_parser: Option<&'a str>, } +/// Jail activation mode +#[derive(Debug, Clone, PartialEq)] +pub enum JailMode { + /// Traditional: wait for start marker, then jail + MarkerBased, + /// Immediate: start jailed from first token (for tool_choice) + Immediate { format: ToolChoiceFormat }, +} + +/// Format for tool_choice immediate jail mode +#[derive(Debug, Clone, PartialEq)] +pub enum ToolChoiceFormat { + /// tool_choice=named: expect single object {"location": "Paris", ...} + SingleObject { tool_name: String }, + /// tool_choice=required: expect array [{name:"search", parameters:{...}}, ...] + ArrayOfTools, +} + /// State tracking for an individual choice during jail processing #[derive(Debug, Clone)] struct ChoiceJailState { @@ -103,10 +121,10 @@ fn create_choice_stream( impl ChoiceJailState { /// Create a new jail state for a choice - fn new(index: u32) -> Self { + fn new(index: u32, starts_jailed: bool) -> Self { Self { index, - is_jailed: false, + is_jailed: starts_jailed, accumulated_content: String::new(), partial_match_buffer: String::new(), stream_finish_reason: None, @@ -381,7 +399,7 @@ impl ChoiceJailStateCollection { } /// Get or create state for a choice index - fn get_or_create_state(&mut self, index: u32) -> &mut ChoiceJailState { + fn get_or_create_state(&mut self, index: u32, starts_jailed: bool) -> &mut ChoiceJailState { // Find the position where this index should be match self.states.binary_search_by_key(&index, |s| s.index) { Ok(pos) => { @@ -390,7 +408,7 @@ impl ChoiceJailStateCollection { } Err(insert_pos) => { // Need to create new state - let new_state = ChoiceJailState::new(index); + let new_state = ChoiceJailState::new(index, starts_jailed); self.states.insert(insert_pos, new_state); &mut self.states[insert_pos] } @@ -422,6 +440,7 @@ pub struct JailedStream { tool_call_parser: Option, emission_mode: EmissionMode, marker_matcher: MarkerMatcher, + jail_mode: JailMode, } impl JailedStream { @@ -439,8 +458,9 @@ impl JailedStream { where S: Stream> + Send + 'static, { + let jail_mode = self.jail_mode.clone(); let jailed_stream = self.apply(stream); - JailedStream::fix_finish_reason(jailed_stream) + JailedStream::fix_finish_reason(jailed_stream, jail_mode) } /// Apply the jail transformation to a stream of chat completion responses @@ -480,7 +500,8 @@ impl JailedStream { // Process each choice independently using the new architecture for choice in &chat_response.choices { if let Some(ref content) = choice.delta.content { - let choice_state = choice_states.get_or_create_state(choice.index); + let starts_jailed = matches!(self.jail_mode, JailMode::Immediate { .. }); + let choice_state = choice_states.get_or_create_state(choice.index, starts_jailed); // Store metadata when any choice becomes jailed (first time only) if !choice_state.is_jailed && self.should_start_jail(content) @@ -498,14 +519,24 @@ impl JailedStream { all_emissions.extend(emissions); } else { // Handle choices without content (e.g., final chunks with finish_reason) - // These should always pass through - let pass_through_choice = ChatChoiceStream { - index: choice.index, - delta: choice.delta.clone(), - finish_reason: choice.finish_reason, - logprobs: choice.logprobs.clone(), - }; - all_emissions.push(ChoiceEmission::PassThrough(pass_through_choice)); + // Only filter out if this choice was ever jailed and lacks role + // (to avoid aggregator issues with deltas missing role after unjail) + let choice_state = choice_states.get_or_create_state(choice.index, false); + let was_ever_jailed = choice_state.accumulated_content.len() > 0 || choice_state.is_jailed; + + let should_emit = choice.delta.role.is_some() + || choice.delta.tool_calls.is_some() + || !was_ever_jailed; // Always pass through if never jailed + + if should_emit { + let pass_through_choice = ChatChoiceStream { + index: choice.index, + delta: choice.delta.clone(), + finish_reason: choice.finish_reason, + logprobs: choice.logprobs.clone(), + }; + all_emissions.push(ChoiceEmission::PassThrough(pass_through_choice)); + } } } @@ -673,38 +704,67 @@ impl JailedStream { /// Check if accumulated content should end jail async fn should_end_jail(&self, accumulated_content: &str) -> (bool, usize) { - // Path 1: End sequence detected - let end_marker_info = if !self.jail_end_sequences.is_empty() { - self.jail_end_sequences.iter().find_map(|seq| { - accumulated_content - .find(seq) - .map(|pos| (pos + seq.len(), seq.clone())) - }) - } else { - None - }; + match &self.jail_mode { + JailMode::MarkerBased => { + // Path 1: End sequence detected + let end_marker_info = if !self.jail_end_sequences.is_empty() { + self.jail_end_sequences.iter().find_map(|seq| { + accumulated_content + .find(seq) + .map(|pos| (pos + seq.len(), seq.clone())) + }) + } else { + None + }; - // Path 2: Complete tool call(s) can be parsed (early exit) - let early_exit = self.should_exit_jail_early(accumulated_content).await; + // Path 2: Complete tool call(s) can be parsed (early exit) + let early_exit = self.should_exit_jail_early(accumulated_content).await; - if let Some((end_pos, _)) = end_marker_info { - (true, end_pos) - } else if early_exit { - // For early exit, find where the complete tool call ends - if let Some(parser) = &self.tool_call_parser { - if let Ok((_, _)) = - try_tool_call_parse_aggregate(accumulated_content, Some(parser)).await - { - let split_pos = find_tool_call_end_position(accumulated_content, Some(parser)); - (true, split_pos) + if let Some((end_pos, _)) = end_marker_info { + (true, end_pos) + } else if early_exit { + // For early exit, find where the complete tool call ends + if let Some(parser) = &self.tool_call_parser { + if let Ok((_, _)) = + try_tool_call_parse_aggregate(accumulated_content, Some(parser)).await + { + let split_pos = find_tool_call_end_position(accumulated_content, Some(parser)); + (true, split_pos) + } else { + (false, accumulated_content.len()) + } + } else { + (false, accumulated_content.len()) + } } else { (false, accumulated_content.len()) } - } else { - (false, accumulated_content.len()) } - } else { - (false, accumulated_content.len()) + JailMode::Immediate { format } => { + // For tool_choice, check if we have valid complete JSON + match format { + ToolChoiceFormat::SingleObject { .. } => { + // Expect single object: {"location": "Paris", "unit": "celsius"} + if let Ok(value) = serde_json::from_str::(accumulated_content) { + if value.is_object() { + return (true, accumulated_content.len()); + } + } + (false, accumulated_content.len()) + } + ToolChoiceFormat::ArrayOfTools => { + // Expect array: [{"name":"search","parameters":{...}}, ...] + if let Ok(value) = serde_json::from_str::(accumulated_content) { + if let Some(arr) = value.as_array() { + if !arr.is_empty() { + return (true, accumulated_content.len()); + } + } + } + (false, accumulated_content.len()) + } + } + } } } @@ -715,46 +775,131 @@ impl JailedStream { accumulated_content: &str, base_choice: &ChatChoiceStream, ) -> ChatChoiceStream { - if let Ok((tool_calls, normal_text)) = - try_tool_call_parse_aggregate(accumulated_content, self.tool_call_parser.as_deref()) - .await - && !tool_calls.is_empty() - { - // Convert to streaming format - let tool_call_chunks: Vec = tool_calls - .into_iter() - .enumerate() - .map(|(idx, tool_call)| ChatCompletionMessageToolCallChunk { - index: idx as u32, - id: Some(tool_call.id), - r#type: Some(tool_call.r#type), - function: Some(FunctionCallStream { - name: Some(tool_call.function.name), - arguments: Some(tool_call.function.arguments), - }), - }) - .collect(); - // Create choice with tool calls - let choice = create_choice_stream( - choice_index, - Some(Role::Assistant), - normal_text.as_deref().unwrap_or(""), - Some(tool_call_chunks), - None, - None, - ); - return choice; + match &self.jail_mode { + JailMode::MarkerBased => { + // Traditional marker-based tool call parsing + if let Ok((tool_calls, normal_text)) = + try_tool_call_parse_aggregate(accumulated_content, self.tool_call_parser.as_deref()) + .await + && !tool_calls.is_empty() + { + // Convert to streaming format + let tool_call_chunks: Vec = tool_calls + .into_iter() + .enumerate() + .map(|(idx, tool_call)| ChatCompletionMessageToolCallChunk { + index: idx as u32, + id: Some(tool_call.id), + r#type: Some(tool_call.r#type), + function: Some(FunctionCallStream { + name: Some(tool_call.function.name), + arguments: Some(tool_call.function.arguments), + }), + }) + .collect(); + // Create choice with tool calls + let choice = create_choice_stream( + choice_index, + Some(Role::Assistant), + normal_text.as_deref().unwrap_or(""), + Some(tool_call_chunks), + None, + None, + ); + return choice; + } + + // No tool calls found or parsing failed, return content choice + create_choice_stream( + choice_index, + Some(Role::Assistant), + accumulated_content, + None, + base_choice.finish_reason, + base_choice.logprobs.clone(), + ) + } + JailMode::Immediate { format } => { + // tool_choice mode: parse JSON and convert to tool calls + match self.parse_tool_choice_json(accumulated_content, format) { + Ok(tool_call_chunks) if !tool_call_chunks.is_empty() => { + create_choice_stream( + choice_index, + Some(Role::Assistant), + "", + Some(tool_call_chunks), + base_choice.finish_reason, + base_choice.logprobs.clone(), + ) + } + Ok(_) | Err(_) => { + // Parsing failed, return as content + create_choice_stream( + choice_index, + Some(Role::Assistant), + accumulated_content, + None, + base_choice.finish_reason, + base_choice.logprobs.clone(), + ) + } + } + } } + } - // No tool calls found or parsing failed, return content choice - create_choice_stream( - choice_index, - Some(Role::Assistant), - accumulated_content, - None, - base_choice.finish_reason, - base_choice.logprobs.clone(), - ) + /// Parse tool_choice JSON output into tool call chunks + fn parse_tool_choice_json( + &self, + json_content: &str, + format: &ToolChoiceFormat, + ) -> anyhow::Result> { + let parsed = serde_json::from_str::(json_content)?; + + match format { + ToolChoiceFormat::SingleObject { tool_name } => { + // For named tool choice: JSON is the parameters object + if parsed.is_object() { + Ok(vec![ChatCompletionMessageToolCallChunk { + index: 0, + id: Some("call-1".to_string()), + r#type: Some(dynamo_async_openai::types::ChatCompletionToolType::Function), + function: Some(FunctionCallStream { + name: Some(tool_name.clone()), + arguments: Some(json_content.to_string()), + }), + }]) + } else { + Ok(vec![]) + } + } + ToolChoiceFormat::ArrayOfTools => { + // For required tool choice: JSON is array of {name, parameters} + if let Some(array) = parsed.as_array() { + let chunks: Vec = array + .iter() + .enumerate() + .filter_map(|(idx, entry)| { + let name = entry.get("name")?.as_str()?.to_string(); + let parameters = entry.get("parameters")?; + let args = serde_json::to_string(parameters).ok()?; + Some(ChatCompletionMessageToolCallChunk { + index: idx as u32, + id: Some(format!("call-{}", idx + 1)), + r#type: Some(dynamo_async_openai::types::ChatCompletionToolType::Function), + function: Some(FunctionCallStream { + name: Some(name), + arguments: Some(args), + }), + }) + }) + .collect(); + Ok(chunks) + } else { + Ok(vec![]) + } + } + } } /// Check if accumulated content contains complete tool calls that can be parsed @@ -775,8 +920,9 @@ impl JailedStream { /// Post-processor that sets finish_reason to ToolCalls when tool calls were emitted /// This should be called after apply() to fix the finish_reason for tool call chunks - pub fn fix_finish_reason( + fn fix_finish_reason( input_stream: S, + jail_mode: JailMode, ) -> impl Stream> + Send where S: Stream> + Send + 'static, @@ -795,13 +941,39 @@ impl JailedStream { } } - // If this chunk has finish_reason and the choice had tool calls, override to ToolCalls + // Fix finish_reason based on jail mode and whether tool calls were emitted if let Some(ref mut data) = response.data { for choice in &mut data.choices { - if choice.finish_reason.is_some() && choice.finish_reason == Some(FinishReason::Stop) - && has_tool_calls_per_choice.get(&choice.index).copied().unwrap_or(false) - { - choice.finish_reason = Some(FinishReason::ToolCalls); + if let Some(finish) = choice.finish_reason { + // Only modify Stop finish reason, preserve Length/ContentFilter + if finish == FinishReason::Stop { + let has_tool_calls = has_tool_calls_per_choice.get(&choice.index).copied().unwrap_or(false); + + match &jail_mode { + JailMode::MarkerBased => { + // Traditional: if tool calls emitted, change to ToolCalls + if has_tool_calls { + choice.finish_reason = Some(FinishReason::ToolCalls); + } + } + JailMode::Immediate { format } => { + // tool_choice mode: apply specific finish_reason logic + match format { + ToolChoiceFormat::SingleObject { .. } => { + // Named tool choice: keep Stop + // (already Stop, no change needed) + } + ToolChoiceFormat::ArrayOfTools => { + // Required tool choice: change to ToolCalls + if has_tool_calls { + choice.finish_reason = Some(FinishReason::ToolCalls); + } + } + } + } + } + } + // Length and ContentFilter are preserved as-is } } } @@ -818,6 +990,7 @@ pub struct JailedStreamBuilder { jail_end_sequences: Vec, tool_call_parser: Option, emission_mode: EmissionMode, + jail_mode: JailMode, } impl JailedStreamBuilder { @@ -828,6 +1001,7 @@ impl JailedStreamBuilder { jail_end_sequences: Vec::new(), tool_call_parser: None, emission_mode: EmissionMode::default(), + jail_mode: JailMode::MarkerBased, } } @@ -887,6 +1061,22 @@ impl JailedStreamBuilder { self } + /// Enable immediate jail mode for tool_choice=named + pub fn tool_choice_named(mut self, tool_name: String) -> Self { + self.jail_mode = JailMode::Immediate { + format: ToolChoiceFormat::SingleObject { tool_name }, + }; + self + } + + /// Enable immediate jail mode for tool_choice=required + pub fn tool_choice_required(mut self) -> Self { + self.jail_mode = JailMode::Immediate { + format: ToolChoiceFormat::ArrayOfTools, + }; + self + } + /// Build the configured JailedStream pub fn build(mut self) -> JailedStream { // Auto-populate jail sequences from parser config if not manually configured @@ -965,6 +1155,7 @@ impl JailedStreamBuilder { tool_call_parser: self.tool_call_parser, emission_mode: self.emission_mode, marker_matcher, + jail_mode: self.jail_mode, } } } diff --git a/lib/llm/tests/test_reasoning_parser.rs b/lib/llm/tests/test_reasoning_parser.rs index 190fd9badbb..19a0ec328ac 100644 --- a/lib/llm/tests/test_reasoning_parser.rs +++ b/lib/llm/tests/test_reasoning_parser.rs @@ -484,7 +484,8 @@ mod tests { // Step 2: Apply tool calling jail transformation let tool_parsed_stream = OpenAIPreprocessor::apply_tool_calling_jail( - "nemotron_deci".to_string(), + Some("nemotron_deci".to_string()), + None, // No tool_choice in this test reasoning_parsed_stream, ); @@ -596,7 +597,8 @@ mod tests { let reasoning_parsed_stream = stream::iter(debug_chunks); let tool_parsed_stream = OpenAIPreprocessor::apply_tool_calling_jail( - "harmony".to_string(), + Some("harmony".to_string()), + None, // No tool_choice in this test reasoning_parsed_stream, ); diff --git a/lib/llm/tests/test_streaming_tool_parsers.rs b/lib/llm/tests/test_streaming_tool_parsers.rs index 03392d2fcb7..c2214e707b8 100644 --- a/lib/llm/tests/test_streaming_tool_parsers.rs +++ b/lib/llm/tests/test_streaming_tool_parsers.rs @@ -158,7 +158,8 @@ async fn parse_response_stream( > = if tool_parse_enable { if let Some(tool_parser) = tool_parser_str { Box::pin(OpenAIPreprocessor::apply_tool_calling_jail( - tool_parser, + Some(tool_parser), + None, // No tool_choice in this test stream, )) } else { diff --git a/lib/llm/tests/tool_choice.rs b/lib/llm/tests/tool_choice.rs index 71d947f02ae..884013a130f 100644 --- a/lib/llm/tests/tool_choice.rs +++ b/lib/llm/tests/tool_choice.rs @@ -35,6 +35,79 @@ fn create_test_request() -> NvCreateChatCompletionRequest { } } +async fn apply_jail_transformation( + raw_response: dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse, + tool_choice: Option, +) -> dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse { + use dynamo_llm::protocols::openai::chat_completions::jail::JailedStream; + use futures::stream; + use futures::StreamExt; + use dynamo_runtime::protocols::annotated::Annotated; + + let input_stream = stream::iter(vec![Annotated { + data: Some(raw_response), + id: None, + event: None, + comment: None, + }]); + + let mut builder = JailedStream::builder(); + + match tool_choice { + Some(ChatCompletionToolChoiceOption::Named(ref named)) => { + builder = builder.tool_choice_named(named.function.name.clone()); + } + Some(ChatCompletionToolChoiceOption::Required) => { + builder = builder.tool_choice_required(); + } + _ => {} + } + + let jail = builder.build(); + let output_stream = jail.apply_with_finish_reason(input_stream); + + tokio::pin!(output_stream); + output_stream.next().await.unwrap().data.unwrap() +} + +async fn apply_jail_transformation_streaming( + raw_responses: Vec, + tool_choice: Option, +) -> Vec { + use dynamo_llm::protocols::openai::chat_completions::jail::JailedStream; + use futures::stream; + use futures::StreamExt; + use dynamo_runtime::protocols::annotated::Annotated; + + let input_stream = stream::iter(raw_responses.into_iter().map(|r| Annotated { + data: Some(r), + id: None, + event: None, + comment: None, + })); + + let mut builder = JailedStream::builder(); + + match tool_choice { + Some(ChatCompletionToolChoiceOption::Named(ref named)) => { + builder = builder.tool_choice_named(named.function.name.clone()); + } + Some(ChatCompletionToolChoiceOption::Required) => { + builder = builder.tool_choice_required(); + } + _ => {} + } + + let jail = builder.build(); + let output_stream = jail.apply_with_finish_reason(input_stream); + + tokio::pin!(output_stream); + output_stream + .filter_map(|ann| async move { ann.data }) + .collect() + .await +} + fn build_backend_output(text: &str) -> BackendOutput { BackendOutput { token_ids: vec![], @@ -50,10 +123,10 @@ fn build_backend_output(text: &str) -> BackendOutput { } } -#[test] -fn test_named_tool_choice_parses_json() { +#[tokio::test] +async fn test_named_tool_choice_parses_json() { let mut request = create_test_request(); - request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Named( + let tool_choice = Some(ChatCompletionToolChoiceOption::Named( ChatCompletionNamedToolChoice { r#type: ChatCompletionToolType::Function, function: FunctionName { @@ -61,20 +134,23 @@ fn test_named_tool_choice_parses_json() { }, }, )); + request.inner.tool_choice = tool_choice.clone(); let mut generator = request.response_generator("req-1".to_string()); let backend_output = build_backend_output(r#"{"location":"Paris"}"#); - let response = generator + let raw_response = generator .choice_from_postprocessor(backend_output) .expect("choice generation"); + let response = apply_jail_transformation(raw_response, tool_choice).await; let choice = &response.choices[0]; + assert_eq!( choice.finish_reason, Some(dynamo_async_openai::types::FinishReason::Stop) ); let delta = &choice.delta; - assert!(delta.content.is_none()); + assert!(delta.content.is_none() || delta.content.as_deref() == Some("")); let tool_calls = delta.tool_calls.as_ref().unwrap(); assert_eq!(tool_calls.len(), 1); @@ -93,27 +169,30 @@ fn test_named_tool_choice_parses_json() { ); } -#[test] -fn test_required_tool_choice_parses_json_array() { +#[tokio::test] +async fn test_required_tool_choice_parses_json_array() { let mut request = create_test_request(); - request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Required); + let tool_choice = Some(ChatCompletionToolChoiceOption::Required); + request.inner.tool_choice = tool_choice.clone(); let mut generator = request.response_generator("req-2".to_string()); let backend_output = build_backend_output( r#"[{"name":"search","parameters":{"query":"rust"}}, {"name":"summarize","parameters":{"topic":"memory"}}]"#, ); - let response = generator + let raw_response = generator .choice_from_postprocessor(backend_output) .expect("choice generation"); + let response = apply_jail_transformation(raw_response, tool_choice).await; let choice = &response.choices[0]; + assert_eq!( choice.finish_reason, Some(dynamo_async_openai::types::FinishReason::ToolCalls) ); let delta = &choice.delta; - assert!(delta.content.is_none()); + assert!(delta.content.is_none() || delta.content.as_deref() == Some("")); let tool_calls = delta.tool_calls.as_ref().unwrap(); assert_eq!(tool_calls.len(), 2); @@ -143,26 +222,31 @@ fn test_required_tool_choice_parses_json_array() { ); } -#[test] -fn test_tool_choice_parse_failure_suppresses_text() { +#[tokio::test] +async fn test_tool_choice_parse_failure_returns_as_content() { let mut request = create_test_request(); - request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Required); + let tool_choice = Some(ChatCompletionToolChoiceOption::Required); + request.inner.tool_choice = tool_choice.clone(); let mut generator = request.response_generator("req-3".to_string()); let backend_output = build_backend_output("not-json"); - let response = generator + let raw_response = generator .choice_from_postprocessor(backend_output) .expect("choice generation"); + let response = apply_jail_transformation(raw_response, tool_choice).await; let delta = &response.choices[0].delta; - assert!(delta.content.is_none()); + + // Jail stream behavior: if parsing fails, return accumulated content as-is + // This matches marker-based FC behavior + assert_eq!(delta.content.as_deref(), Some("not-json")); assert!(delta.tool_calls.is_none()); } -#[test] -fn test_streaming_named_tool_buffers_until_finish() { +#[tokio::test] +async fn test_streaming_named_tool_buffers_until_finish() { let mut request = create_test_request(); - request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Named( + let tool_choice = Some(ChatCompletionToolChoiceOption::Named( ChatCompletionNamedToolChoice { r#type: ChatCompletionToolType::Function, function: FunctionName { @@ -170,6 +254,7 @@ fn test_streaming_named_tool_buffers_until_finish() { }, }, )); + request.inner.tool_choice = tool_choice.clone(); let mut generator = request.response_generator("req-stream-1".to_string()); @@ -179,7 +264,7 @@ fn test_streaming_named_tool_buffers_until_finish() { r#"celsius"}"#, ]; - let mut all_responses = Vec::new(); + let mut raw_responses = Vec::new(); for (i, chunk) in chunks.iter().enumerate() { let backend_output = BackendOutput { token_ids: vec![], @@ -201,20 +286,21 @@ fn test_streaming_named_tool_buffers_until_finish() { let response = generator .choice_from_postprocessor(backend_output) .expect("streaming chunk"); - all_responses.push(response); + raw_responses.push(response); } - for i in 0..all_responses.len() - 1 { - assert!(all_responses[i].choices[0].delta.tool_calls.is_none()); - } + let all_responses = apply_jail_transformation_streaming(raw_responses, tool_choice).await; - let last_response = all_responses.last().unwrap(); + // Jail stream buffers content until valid JSON, then emits once + assert_eq!(all_responses.len(), 1); + + let response = &all_responses[0]; assert_eq!( - last_response.choices[0].finish_reason, + response.choices[0].finish_reason, Some(dynamo_async_openai::types::FinishReason::Stop) ); - let tool_calls = last_response.choices[0].delta.tool_calls.as_ref().unwrap(); + let tool_calls = response.choices[0].delta.tool_calls.as_ref().unwrap(); assert_eq!(tool_calls.len(), 1); assert_eq!(tool_calls[0].function.as_ref().unwrap().name.as_deref(), Some("get_weather")); assert_eq!( @@ -223,10 +309,11 @@ fn test_streaming_named_tool_buffers_until_finish() { ); } -#[test] -fn test_streaming_required_tool_parallel() { +#[tokio::test] +async fn test_streaming_required_tool_parallel() { let mut request = create_test_request(); - request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Required); + let tool_choice = Some(ChatCompletionToolChoiceOption::Required); + request.inner.tool_choice = tool_choice.clone(); let mut generator = request.response_generator("req-stream-2".to_string()); @@ -235,7 +322,7 @@ fn test_streaming_required_tool_parallel() { r#"{"name":"summarize","parameters":{"topic":"memory"}}]"#, ]; - let mut all_responses = Vec::new(); + let mut raw_responses = Vec::new(); for (i, chunk) in chunks.iter().enumerate() { let backend_output = BackendOutput { token_ids: vec![], @@ -257,20 +344,21 @@ fn test_streaming_required_tool_parallel() { let response = generator .choice_from_postprocessor(backend_output) .expect("streaming chunk"); - all_responses.push(response); + raw_responses.push(response); } - for i in 0..all_responses.len() - 1 { - assert!(all_responses[i].choices[0].delta.tool_calls.is_none()); - } + let all_responses = apply_jail_transformation_streaming(raw_responses, tool_choice).await; + + // Jail stream buffers until complete JSON array + assert_eq!(all_responses.len(), 1); - let last_response = all_responses.last().unwrap(); + let response = &all_responses[0]; assert_eq!( - last_response.choices[0].finish_reason, + response.choices[0].finish_reason, Some(dynamo_async_openai::types::FinishReason::ToolCalls) ); - let tool_calls = last_response.choices[0].delta.tool_calls.as_ref().unwrap(); + let tool_calls = response.choices[0].delta.tool_calls.as_ref().unwrap(); assert_eq!(tool_calls.len(), 2); assert_eq!(tool_calls[0].function.as_ref().unwrap().name.as_deref(), Some("search")); @@ -280,10 +368,10 @@ fn test_streaming_required_tool_parallel() { assert_eq!(tool_calls[1].function.as_ref().unwrap().arguments.as_deref(), Some(r#"{"topic":"memory"}"#)); } -#[test] -fn test_streaming_buffers_until_finish() { +#[tokio::test] +async fn test_streaming_buffers_until_finish() { let mut request = create_test_request(); - request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Named( + let tool_choice = Some(ChatCompletionToolChoiceOption::Named( ChatCompletionNamedToolChoice { r#type: ChatCompletionToolType::Function, function: FunctionName { @@ -291,11 +379,12 @@ fn test_streaming_buffers_until_finish() { }, }, )); + request.inner.tool_choice = tool_choice.clone(); let mut generator = request.response_generator("req-stream-3".to_string()); let full_json = r#"{"query":"rust programming"}"#; - let mut responses = Vec::new(); + let mut raw_responses = Vec::new(); for (i, ch) in full_json.chars().enumerate() { let is_last = i == full_json.len() - 1; @@ -319,21 +408,18 @@ fn test_streaming_buffers_until_finish() { let response = generator .choice_from_postprocessor(backend_output) .expect("char chunk"); - responses.push(response); + raw_responses.push(response); } - for resp in &responses { - assert!(resp.choices[0].delta.content.is_none()); - } + let responses = apply_jail_transformation_streaming(raw_responses, tool_choice).await; - for i in 0..responses.len() - 1 { - assert!(responses[i].choices[0].delta.tool_calls.is_none()); - } + // Jail stream buffers all content until complete JSON + assert_eq!(responses.len(), 1); - let last = responses.last().unwrap(); - assert!(last.choices[0].delta.tool_calls.is_some()); + let response = &responses[0]; + assert!(response.choices[0].delta.tool_calls.is_some()); assert_eq!( - last.choices[0].delta.tool_calls.as_ref().unwrap()[0] + response.choices[0].delta.tool_calls.as_ref().unwrap()[0] .function.as_ref().unwrap().arguments.as_deref(), Some(r#"{"query":"rust programming"}"#) ); diff --git a/lib/llm/tests/tool_choice_finish_reasons.rs b/lib/llm/tests/tool_choice_finish_reasons.rs index 99892f781c7..7b0519e9341 100644 --- a/lib/llm/tests/tool_choice_finish_reasons.rs +++ b/lib/llm/tests/tool_choice_finish_reasons.rs @@ -51,10 +51,45 @@ fn build_backend_output_with_finish(text: &str, finish: common::FinishReason) -> } } -#[test] -fn test_named_tool_choice_preserves_length_finish_reason() { +async fn apply_jail_transformation( + raw_response: dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse, + tool_choice: Option, +) -> dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse { + use dynamo_llm::protocols::openai::chat_completions::jail::JailedStream; + use futures::stream; + use futures::StreamExt; + use dynamo_runtime::protocols::annotated::Annotated; + + let input_stream = stream::iter(vec![Annotated { + data: Some(raw_response), + id: None, + event: None, + comment: None, + }]); + + let mut builder = JailedStream::builder(); + + match tool_choice { + Some(ChatCompletionToolChoiceOption::Named(ref named)) => { + builder = builder.tool_choice_named(named.function.name.clone()); + } + Some(ChatCompletionToolChoiceOption::Required) => { + builder = builder.tool_choice_required(); + } + _ => {} + } + + let jail = builder.build(); + let output_stream = jail.apply_with_finish_reason(input_stream); + + tokio::pin!(output_stream); + output_stream.next().await.unwrap().data.unwrap() +} + +#[tokio::test] +async fn test_named_tool_choice_preserves_length_finish_reason() { let mut request = create_test_request(); - request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Named( + let tool_choice = Some(ChatCompletionToolChoiceOption::Named( ChatCompletionNamedToolChoice { r#type: ChatCompletionToolType::Function, function: FunctionName { @@ -62,6 +97,7 @@ fn test_named_tool_choice_preserves_length_finish_reason() { }, }, )); + request.inner.tool_choice = tool_choice.clone(); let mut generator = request.response_generator("req-length-1".to_string()); let backend_output = build_backend_output_with_finish( @@ -69,10 +105,12 @@ fn test_named_tool_choice_preserves_length_finish_reason() { common::FinishReason::Length, ); - let response = generator + let raw_response = generator .choice_from_postprocessor(backend_output) .expect("choice generation"); + let response = apply_jail_transformation(raw_response, tool_choice).await; + // Critical: Length finish reason should be preserved, NOT replaced with Stop assert_eq!( response.choices[0].finish_reason, @@ -186,10 +224,11 @@ fn test_named_tool_choice_normal_stop_becomes_stop() { ); } -#[test] -fn test_required_tool_choice_normal_stop_becomes_tool_calls() { +#[tokio::test] +async fn test_required_tool_choice_normal_stop_becomes_tool_calls() { let mut request = create_test_request(); - request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Required); + let tool_choice = Some(ChatCompletionToolChoiceOption::Required); + request.inner.tool_choice = tool_choice.clone(); let mut generator = request.response_generator("req-stop-2".to_string()); let backend_output = build_backend_output_with_finish( @@ -197,10 +236,12 @@ fn test_required_tool_choice_normal_stop_becomes_tool_calls() { common::FinishReason::Stop, ); - let response = generator + let raw_response = generator .choice_from_postprocessor(backend_output) .expect("choice generation"); + let response = apply_jail_transformation(raw_response, tool_choice).await; + // Normal completion: Stop should become ToolCalls for required tool choice assert_eq!( response.choices[0].finish_reason, From e91b122fc51516092641df4ccb8922cc2f2257d1 Mon Sep 17 00:00:00 2001 From: Vladislav Nosivskoy Date: Fri, 5 Dec 2025 15:01:50 +0300 Subject: [PATCH 08/18] remove partial json lib Signed-off-by: Vladislav Nosivskoy --- lib/llm/src/protocols/openai/partial_json.rs | 322 ------------------- 1 file changed, 322 deletions(-) delete mode 100644 lib/llm/src/protocols/openai/partial_json.rs diff --git a/lib/llm/src/protocols/openai/partial_json.rs b/lib/llm/src/protocols/openai/partial_json.rs deleted file mode 100644 index b575c711b51..00000000000 --- a/lib/llm/src/protocols/openai/partial_json.rs +++ /dev/null @@ -1,322 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// SPDX-License-Identifier: Apache-2.0 - -//! Partial JSON Parser for streaming tool calls. -//! -//! This implementation is heavily inspired by the `partial-json-parser` library: -//! https://github.com/promplate/partial-json-parser -//! -//! The original Python library is licensed under MIT License. -//! We've adapted the core logic to Rust for use in Dynamo's streaming tool calls functionality. - -use std::collections::VecDeque; - -/// Options for what types of partial JSON are allowed -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct AllowPartial { - pub strings: bool, - pub objects: bool, - pub arrays: bool, -} - -impl Default for AllowPartial { - fn default() -> Self { - Self { - strings: true, - objects: true, - arrays: true, - } - } -} - -impl AllowPartial { - pub fn all() -> Self { - Self::default() - } - - pub fn none() -> Self { - Self { - strings: false, - objects: false, - arrays: false, - } - } -} - -/// Represents a token found during JSON scanning -#[derive(Debug, Clone, PartialEq, Eq)] -struct Token { - index: usize, - char: char, -} - -/// Scans the JSON string for structural characters -fn scan_tokens(json_string: &str) -> Vec { - json_string - .char_indices() - .filter_map(|(i, c)| { - if matches!(c, '"' | '[' | ']' | '{' | '}') { - Some(Token { index: i, char: c }) - } else { - None - } - }) - .collect() -} - -/// Checks if a quote at the given position is escaped -fn is_escaped(json_string: &str, index: usize) -> bool { - let text_before = &json_string[..index]; - let count = index - text_before.trim_end_matches('\\').len(); - count % 2 == 1 -} - -/// Joins closing tokens for unclosed containers -fn join_closing_tokens(stack: &VecDeque) -> String { - stack - .iter() - .rev() - .map(|token| if token.char == '{' { '}' } else { ']' }) - .collect() -} - -/// Completes a partial JSON string by adding necessary closing tokens -/// -/// Returns a tuple of (head, tail) where head is the potentially truncated -/// input and tail is the completion string -pub fn fix_json(json_string: &str, allow: AllowPartial) -> (String, String) { - let tokens = scan_tokens(json_string); - - // Empty or starts with quote - use simple fix - if tokens.is_empty() || tokens[0].char == '"' { - return simple_fix(json_string, allow); - } - - let mut stack: VecDeque = VecDeque::new(); - let mut in_string = false; - let mut last_string_start = None; - let mut last_string_end = None; - - for token in &tokens { - if token.char == '"' { - if !in_string { - in_string = true; - last_string_start = Some(token.index); - } else if !is_escaped(json_string, token.index) { - in_string = false; - last_string_end = Some(token.index); - } - } else if !in_string { - match token.char { - '}' => { - if let Some(open) = stack.pop_back() { - assert_eq!(open.char, '{', "Mismatched braces"); - } - } - ']' => { - if let Some(open) = stack.pop_back() { - assert_eq!(open.char, '[', "Mismatched brackets"); - } - } - _ => { - stack.push_back(token.clone()); - } - } - } - } - - // If stack is empty, JSON is complete - if stack.is_empty() { - return (json_string.to_string(), String::new()); - } - - // Remove trailing comma if present - let mut head = json_string.trim_end(); - if head.ends_with(',') { - head = head[..head.len() - 1].trim_end(); - } - - // Handle unclosed strings - if !allow.strings && in_string { - if let Some(last_container) = stack.back() - && last_container.char == '{' - { - // Truncate before the unclosed string key - return ( - head[..=last_container.index].to_string(), - join_closing_tokens(&stack), - ); - } - - // Find last comma before the unclosed string - if let Some(string_start) = last_string_start { - let last_container_pos = stack.back().map(|t| t.index).unwrap_or(0); - let search_start = last_container_pos.max(last_string_end.unwrap_or(0)) + 1; - - if let Some(comma_pos) = head[search_start..string_start].rfind(',') { - let absolute_comma = search_start + comma_pos; - return ( - head[..absolute_comma].to_string(), - join_closing_tokens(&stack), - ); - } - } - } - - // Simple case: just close all open containers - if in_string - && allow.strings - && let Some(string_start) = last_string_start - { - // Fix the partial string - let partial_str = &head[string_start..]; - let (fixed_head, fixed_tail) = simple_fix(partial_str, allow); - return ( - format!("{}{}", &head[..string_start], fixed_head), - format!("{}{}", fixed_tail, join_closing_tokens(&stack)), - ); - } - - (head.to_string(), join_closing_tokens(&stack)) -} - -/// Simple fix for basic cases (strings, atoms) -fn simple_fix(json_string: &str, allow: AllowPartial) -> (String, String) { - let trimmed = json_string.trim_end(); - - // Handle unclosed strings - if trimmed.starts_with('"') && allow.strings { - // Count how many unescaped quotes we have - let mut escaped = false; - let mut quote_count = 0; - for ch in trimmed.chars() { - if ch == '\\' && !escaped { - escaped = true; - } else { - if ch == '"' && !escaped { - quote_count += 1; - } - escaped = false; - } - } - - if quote_count % 2 == 1 { - // Unclosed string - return (trimmed.to_string(), "\"".to_string()); - } - } - - // Already complete or can't fix - (trimmed.to_string(), String::new()) -} - -/// Ensures the JSON string is complete by adding necessary tokens -pub fn ensure_json(json_string: &str, allow: AllowPartial) -> String { - let (head, tail) = fix_json(json_string, allow); - format!("{}{}", head, tail) -} - -/// Parses partial JSON string into a serde_json::Value -/// -/// This is the main function inspired by partial-json-parser's `loads()`. -/// It completes the partial JSON and then parses it. -pub fn loads( - json_string: &str, - allow: AllowPartial, -) -> Result { - let completed = ensure_json(json_string, allow); - serde_json::from_str(&completed) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_complete_json() { - let result = ensure_json(r#"{"key":"value"}"#, AllowPartial::all()); - assert_eq!(result, r#"{"key":"value"}"#); - } - - #[test] - fn test_unclosed_object() { - let result = ensure_json(r#"{"key":"value""#, AllowPartial::all()); - assert_eq!(result, r#"{"key":"value"}"#); - } - - #[test] - fn test_unclosed_string() { - let result = ensure_json(r#"{"key":"val"#, AllowPartial::all()); - assert_eq!(result, r#"{"key":"val"}"#); - } - - #[test] - fn test_nested_objects() { - let result = ensure_json(r#"{"outer":{"inner":"val"#, AllowPartial::all()); - assert_eq!(result, r#"{"outer":{"inner":"val"}}"#); - } - - #[test] - fn test_array() { - let result = ensure_json(r#"[{"name":"test","args":{"val":"a""#, AllowPartial::all()); - assert_eq!(result, r#"[{"name":"test","args":{"val":"a"}}]"#); - } - - #[test] - fn test_parallel_tool_calls() { - let result = ensure_json( - r#"[{"name":"search","parameters":{"query":"rust"}},{"name":"summ"#, - AllowPartial::all(), - ); - // Should complete both the string and close all containers - assert!(result.contains("search")); - assert!(result.ends_with("}]")); - } - - #[test] - fn test_loads_incremental() { - // Test 1: Unclosed string value - let result1 = loads(r#"{"location":""#, AllowPartial::all()).unwrap(); - assert_eq!(result1["location"], ""); - - // Test 2: Complete first field, starting second - let result2 = loads(r#"{"location":"Paris","#, AllowPartial::all()).unwrap(); - assert_eq!(result2["location"], "Paris"); - - // Test 3: Complete object - let result3 = loads( - r#"{"location":"Paris","unit":"celsius"}"#, - AllowPartial::all(), - ) - .unwrap(); - assert_eq!(result3["location"], "Paris"); - assert_eq!(result3["unit"], "celsius"); - } - - #[test] - fn test_loads_array_incremental() { - // Test 4: Array with unclosed parameter value - let result4 = loads( - r#"[{"name":"search","parameters":{"query":""#, - AllowPartial::all(), - ) - .unwrap(); - assert!(result4.is_array()); - let arr = result4.as_array().unwrap(); - assert_eq!(arr.len(), 1); - assert_eq!(arr[0]["name"], "search"); - assert_eq!(arr[0]["parameters"]["query"], ""); - - // Test 5: Complete first tool, starting second - let result5 = loads( - r#"[{"name":"search","parameters":{"query":"rust"}},"#, - AllowPartial::all(), - ) - .unwrap(); - assert!(result5.is_array()); - let arr = result5.as_array().unwrap(); - assert_eq!(arr.len(), 1); - assert_eq!(arr[0]["name"], "search"); - assert_eq!(arr[0]["parameters"]["query"], "rust"); - } -} From e32318fd73f59419af20d96da126aa9d6b7687ed Mon Sep 17 00:00:00 2001 From: Vladislav Nosivskoy Date: Fri, 5 Dec 2025 15:09:36 +0300 Subject: [PATCH 09/18] clean aggregator diff Signed-off-by: Vladislav Nosivskoy --- lib/llm/src/protocols/openai.rs | 1 - .../openai/chat_completions/aggregator.rs | 199 ++++---------- .../protocols/openai/chat_completions/jail.rs | 22 +- lib/llm/tests/http_namespace_integration.rs | 6 +- lib/llm/tests/tool_choice.rs | 253 +----------------- lib/parsers/src/tool_calling/json/mod.rs | 7 +- lib/runtime/src/distributed.rs | 7 +- 7 files changed, 71 insertions(+), 424 deletions(-) diff --git a/lib/llm/src/protocols/openai.rs b/lib/llm/src/protocols/openai.rs index 9b9b0e8411e..7da56f7613b 100644 --- a/lib/llm/src/protocols/openai.rs +++ b/lib/llm/src/protocols/openai.rs @@ -16,7 +16,6 @@ pub mod completions; pub mod embeddings; pub mod models; pub mod nvext; -pub mod partial_json; pub mod responses; pub mod tools; pub mod validate; diff --git a/lib/llm/src/protocols/openai/chat_completions/aggregator.rs b/lib/llm/src/protocols/openai/chat_completions/aggregator.rs index c2274bb1239..6e178bc25a3 100644 --- a/lib/llm/src/protocols/openai/chat_completions/aggregator.rs +++ b/lib/llm/src/protocols/openai/chat_completions/aggregator.rs @@ -65,6 +65,28 @@ impl Default for DeltaAggregator { } } +fn convert_tool_chunk_to_message_tool_call( + chunk: &dynamo_async_openai::types::ChatCompletionMessageToolCallChunk, +) -> Option { + // Convert ChatCompletionMessageToolCallChunk to ChatCompletionMessageToolCall + if let (Some(id), Some(r#type), Some(function)) = (&chunk.id, &chunk.r#type, &chunk.function) { + if let (Some(name), Some(arguments)) = (&function.name, &function.arguments) { + Some(dynamo_async_openai::types::ChatCompletionMessageToolCall { + id: id.clone(), + r#type: r#type.clone(), + function: dynamo_async_openai::types::FunctionCall { + name: name.clone(), + arguments: arguments.clone(), + }, + }) + } else { + None + } + } else { + None + } +} + impl DeltaAggregator { /// Creates a new, empty [`DeltaAggregator`] instance. pub fn new() -> Self { @@ -153,51 +175,26 @@ impl DeltaAggregator { .push_str(reasoning_content); } - // Aggregate tool calls incrementally - // Each chunk may add a new tool call or append arguments to existing one - if let Some(tool_call_chunks) = &choice.delta.tool_calls - && !tool_call_chunks.is_empty() + // Since one tool call is one chunk, we don't need to aggregate them + // We just need to convert the ChatCompletionMessageToolCallChunk to ChatCompletionMessageToolCall and append to the state_choice.tool_calls + if let Some(tool_calls) = &choice.delta.tool_calls + && !tool_calls.is_empty() { - // Initialize tool_calls vec if needed - let existing_tool_calls = state_choice - .tool_calls - .get_or_insert_with(Vec::new); - - // Process each chunk - for chunk in tool_call_chunks { - let chunk_index = chunk.index as usize; - - // Find or create tool call at this index - if chunk_index >= existing_tool_calls.len() { - // Extend the vec to accommodate this index - existing_tool_calls.resize_with(chunk_index + 1, || { - dynamo_async_openai::types::ChatCompletionMessageToolCall { - id: String::new(), - r#type: dynamo_async_openai::types::ChatCompletionToolType::Function, - function: dynamo_async_openai::types::FunctionCall { - name: String::new(), - arguments: String::new(), - }, - } - }); - } - - let tool_call = &mut existing_tool_calls[chunk_index]; - - // Update fields if present in chunk - if let Some(id) = &chunk.id { - tool_call.id = id.clone(); - } - if let Some(r#type) = &chunk.r#type { - tool_call.r#type = r#type.clone(); - } - if let Some(function) = &chunk.function { - if let Some(name) = &function.name { - tool_call.function.name = name.clone(); - } - if let Some(arguments) = &function.arguments { - tool_call.function.arguments.push_str(arguments); - } + // Convert ChatCompletionMessageToolCallChunk to ChatCompletionMessageToolCall + let converted_tool_calls: Vec< + dynamo_async_openai::types::ChatCompletionMessageToolCall, + > = tool_calls + .iter() + .filter_map(convert_tool_chunk_to_message_tool_call) + .collect(); + + // Initialize and push the converted tool calls to state_choice.tool_calls + // Only set tool_calls to Some if there are actual tool calls + if !converted_tool_calls.is_empty() { + if let Some(existing_tool_calls) = &mut state_choice.tool_calls { + existing_tool_calls.extend(converted_tool_calls); + } else { + state_choice.tool_calls = Some(converted_tool_calls); } } } @@ -273,17 +270,11 @@ impl From for dynamo_async_openai::types::ChatChoice { /// The `function_call` field is deprecated. fn from(delta: DeltaChoice) -> Self { // If tool calls are present and non-empty, finish reason should be ToolCalls - // Unless it's a critical finish reason (Length, ContentFilter, Stop) that should be preserved let finish_reason = if delta .tool_calls .as_ref() .is_some_and(|calls| !calls.is_empty()) - && !matches!( - delta.finish_reason, - Some(dynamo_async_openai::types::FinishReason::Stop) - | Some(dynamo_async_openai::types::FinishReason::Length) - | Some(dynamo_async_openai::types::FinishReason::ContentFilter) - ) { + { Some(dynamo_async_openai::types::FinishReason::ToolCalls) } else { delta.finish_reason @@ -700,8 +691,8 @@ mod tests { } #[tokio::test] - async fn test_tool_calling_finish_reason_respects_explicit_stop() { - // Test that when tool calls are present and finish reason is Stop, it remains Stop + async fn test_tool_calling_finish_reason_override_from_stop() { + // Test that when tool calls are present but finish reason is Stop, it gets overridden to ToolCalls let tool_call_json = r#"{"name": "get_weather", "arguments": {"location": "New York", "unit": "celsius"}}"#; @@ -735,22 +726,23 @@ mod tests { let tool_calls = choice.message.tool_calls.as_ref().unwrap(); assert_eq!(tool_calls.len(), 1); - // Finish reason should remain Stop because it was explicitly provided that way + // Most importantly, verify that finish reason was overridden to ToolCalls despite original being Stop assert_eq!( choice.finish_reason, - Some(dynamo_async_openai::types::FinishReason::Stop) + Some(dynamo_async_openai::types::FinishReason::ToolCalls) ); } #[tokio::test] - async fn test_tool_calling_preserves_length_when_present() { + async fn test_tool_calling_finish_reason_override_from_length() { + // Test that when tool calls are present but finish reason is Length, it gets overridden to ToolCalls let tool_call_json = r#"{"name": "search", "arguments": {"query": "rust programming"}}"#; let annotated_delta = create_test_delta( 0, "Let me search for that.", Some(dynamo_async_openai::types::Role::Assistant), - Some(dynamo_async_openai::types::FinishReason::Length), + Some(dynamo_async_openai::types::FinishReason::Length), // Original finish reason is Length None, Some(tool_call_json), ); @@ -771,13 +763,15 @@ mod tests { assert_eq!(response.choices.len(), 1); let choice = &response.choices[0]; + // Verify tool calls are present assert!(choice.message.tool_calls.is_some()); let tool_calls = choice.message.tool_calls.as_ref().unwrap(); assert_eq!(tool_calls.len(), 1); + // Verify that finish reason was overridden to ToolCalls despite original being Length assert_eq!( choice.finish_reason, - Some(dynamo_async_openai::types::FinishReason::Length) + Some(dynamo_async_openai::types::FinishReason::ToolCalls) ); } @@ -970,8 +964,8 @@ mod tests { } #[tokio::test] - async fn test_tool_calling_finish_reason_stop_remains_when_set() { - // When finish_reason is explicitly Stop, we preserve it even if tool_calls are present + async fn test_tool_calling_finish_reason_override_from_stop_alternative() { + // Test that when tool calls are present but finish reason is Stop, it gets overridden to ToolCalls let tool_call_json = r#"{"name": "get_weather", "arguments": {"location": "New York", "unit": "celsius"}}"#; @@ -997,10 +991,10 @@ mod tests { assert_eq!(response.choices.len(), 1); let choice = &response.choices[0]; - // The finish_reason should remain Stop because it was explicitly provided that way + // The finish_reason should be ToolCalls, not Stop, because tool calls are present assert_eq!( choice.finish_reason, - Some(dynamo_async_openai::types::FinishReason::Stop) + Some(dynamo_async_openai::types::FinishReason::ToolCalls) ); // Verify tool calls are present @@ -1009,85 +1003,4 @@ mod tests { assert_eq!(tool_calls.len(), 1); assert_eq!(tool_calls[0].function.name, "get_weather"); } - - #[tokio::test] - async fn test_tool_calling_preserves_length_finish_reason() { - // Test that Length finish reason is preserved even with tool_calls present - let tool_call_json = r#"{"name": "get_weather", "arguments": {"location": "Paris"}}"#; - - let annotated_delta = create_test_delta( - 0, - "", - Some(dynamo_async_openai::types::Role::Assistant), - Some(dynamo_async_openai::types::FinishReason::Length), // Length finish reason - None, - Some(tool_call_json), - ); - - let data = annotated_delta.data.unwrap(); - let annotated_delta = Annotated { - data: Some(data), - id: Some("test_id".to_string()), - event: None, - comment: None, - }; - let stream = Box::pin(stream::iter(vec![annotated_delta])); - - let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await; - - assert!(result.is_ok()); - let response = result.unwrap(); - assert_eq!(response.choices.len(), 1); - let choice = &response.choices[0]; - - // Critical: Length finish reason MUST be preserved, not replaced with ToolCalls - assert_eq!( - choice.finish_reason, - Some(dynamo_async_openai::types::FinishReason::Length), - "Length finish reason must be preserved even when tool_calls are present" - ); - - // Verify tool calls are still present - assert!(choice.message.tool_calls.is_some()); - } - - #[tokio::test] - async fn test_tool_calling_preserves_content_filter_finish_reason() { - // Test that ContentFilter finish reason is preserved even with tool_calls present - let tool_call_json = r#"{"name": "harmful_action", "arguments": {}}"#; - - let annotated_delta = create_test_delta( - 0, - "", - Some(dynamo_async_openai::types::Role::Assistant), - Some(dynamo_async_openai::types::FinishReason::ContentFilter), // ContentFilter - None, - Some(tool_call_json), - ); - - let data = annotated_delta.data.unwrap(); - let annotated_delta = Annotated { - data: Some(data), - id: Some("test_id".to_string()), - event: None, - comment: None, - }; - let stream = Box::pin(stream::iter(vec![annotated_delta])); - - let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await; - - assert!(result.is_ok()); - let response = result.unwrap(); - let choice = &response.choices[0]; - - // Critical: ContentFilter finish reason MUST be preserved - assert_eq!( - choice.finish_reason, - Some(dynamo_async_openai::types::FinishReason::ContentFilter), - "ContentFilter finish reason must be preserved even when tool_calls are present" - ); - - // Verify tool calls are still present - assert!(choice.message.tool_calls.is_some()); - } } diff --git a/lib/llm/src/protocols/openai/chat_completions/jail.rs b/lib/llm/src/protocols/openai/chat_completions/jail.rs index eeaa635f7ee..56737a2a770 100644 --- a/lib/llm/src/protocols/openai/chat_completions/jail.rs +++ b/lib/llm/src/protocols/openai/chat_completions/jail.rs @@ -418,18 +418,15 @@ impl ChoiceJailStateCollection { /// Emission mode for handling multiple choices #[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Default)] pub enum EmissionMode { /// Pack multiple choices in the same chunk (default, matches original behavior) + #[default] Packed, /// Emit one choice per chunk for OpenAI compatibility SingleChoicePerChunk, } -impl Default for EmissionMode { - fn default() -> Self { - Self::Packed - } -} /// A stream transformer that can "jail" tokens based on configurable start/end sequences /// When jailed, tokens are accumulated rather than yielded immediately @@ -522,7 +519,7 @@ impl JailedStream { // Only filter out if this choice was ever jailed and lacks role // (to avoid aggregator issues with deltas missing role after unjail) let choice_state = choice_states.get_or_create_state(choice.index, false); - let was_ever_jailed = choice_state.accumulated_content.len() > 0 || choice_state.is_jailed; + let was_ever_jailed = !choice_state.accumulated_content.is_empty() || choice_state.is_jailed; let should_emit = choice.delta.role.is_some() || choice.delta.tool_calls.is_some() @@ -745,22 +742,19 @@ impl JailedStream { match format { ToolChoiceFormat::SingleObject { .. } => { // Expect single object: {"location": "Paris", "unit": "celsius"} - if let Ok(value) = serde_json::from_str::(accumulated_content) { - if value.is_object() { + if let Ok(value) = serde_json::from_str::(accumulated_content) + && value.is_object() { return (true, accumulated_content.len()); } - } (false, accumulated_content.len()) } ToolChoiceFormat::ArrayOfTools => { // Expect array: [{"name":"search","parameters":{...}}, ...] - if let Ok(value) = serde_json::from_str::(accumulated_content) { - if let Some(arr) = value.as_array() { - if !arr.is_empty() { + if let Ok(value) = serde_json::from_str::(accumulated_content) + && let Some(arr) = value.as_array() + && !arr.is_empty() { return (true, accumulated_content.len()); } - } - } (false, accumulated_content.len()) } } diff --git a/lib/llm/tests/http_namespace_integration.rs b/lib/llm/tests/http_namespace_integration.rs index 4ff727f3469..93529c087ea 100644 --- a/lib/llm/tests/http_namespace_integration.rs +++ b/lib/llm/tests/http_namespace_integration.rs @@ -73,12 +73,10 @@ fn test_model_discovery_scoping_scenarios() { // Scenario 1: Frontend configured for specific namespace should only see models from that namespace let frontend_namespace = "vllm-agg"; - let available_models = vec![ - create_test_endpoint("vllm-agg", "backend", "generate"), + let available_models = [create_test_endpoint("vllm-agg", "backend", "generate"), create_test_endpoint("vllm-agg", "backend", "generate"), create_test_endpoint("sglang-prod", "backend", "generate"), - create_test_endpoint("dynamo", "backend", "generate"), - ]; + create_test_endpoint("dynamo", "backend", "generate")]; let visible_models: Vec<&EndpointId> = available_models .iter() diff --git a/lib/llm/tests/tool_choice.rs b/lib/llm/tests/tool_choice.rs index 884013a130f..0c94446870e 100644 --- a/lib/llm/tests/tool_choice.rs +++ b/lib/llm/tests/tool_choice.rs @@ -368,63 +368,6 @@ async fn test_streaming_required_tool_parallel() { assert_eq!(tool_calls[1].function.as_ref().unwrap().arguments.as_deref(), Some(r#"{"topic":"memory"}"#)); } -#[tokio::test] -async fn test_streaming_buffers_until_finish() { - let mut request = create_test_request(); - let tool_choice = Some(ChatCompletionToolChoiceOption::Named( - ChatCompletionNamedToolChoice { - r#type: ChatCompletionToolType::Function, - function: FunctionName { - name: "search".to_string(), - }, - }, - )); - request.inner.tool_choice = tool_choice.clone(); - - let mut generator = request.response_generator("req-stream-3".to_string()); - - let full_json = r#"{"query":"rust programming"}"#; - let mut raw_responses = Vec::new(); - - for (i, ch) in full_json.chars().enumerate() { - let is_last = i == full_json.len() - 1; - let backend_output = BackendOutput { - token_ids: vec![], - tokens: vec![], - text: Some(ch.to_string()), - cum_log_probs: None, - log_probs: None, - top_logprobs: None, - finish_reason: if is_last { - Some(common::FinishReason::Stop) - } else { - None - }, - index: Some(0), - completion_usage: None, - disaggregated_params: None, - }; - - let response = generator - .choice_from_postprocessor(backend_output) - .expect("char chunk"); - raw_responses.push(response); - } - - let responses = apply_jail_transformation_streaming(raw_responses, tool_choice).await; - - // Jail stream buffers all content until complete JSON - assert_eq!(responses.len(), 1); - - let response = &responses[0]; - assert!(response.choices[0].delta.tool_calls.is_some()); - assert_eq!( - response.choices[0].delta.tool_calls.as_ref().unwrap()[0] - .function.as_ref().unwrap().arguments.as_deref(), - Some(r#"{"query":"rust programming"}"#) - ); -} - #[test] fn test_no_tool_choice_outputs_normal_text() { let request = create_test_request(); @@ -453,198 +396,4 @@ fn test_no_tool_choice_outputs_normal_text() { Some("Hello world") ); assert!(response.choices[0].delta.tool_calls.is_none()); -} - - -/// Helper function to create a streaming chunk -fn create_chunk( - index: u32, - role: Option, - tool_call_chunk: Option, - finish_reason: Option, -) -> dynamo_async_openai::types::CreateChatCompletionStreamResponse { - use dynamo_async_openai::types::{ - ChatCompletionStreamResponseDelta, CreateChatCompletionStreamResponse, - }; - - CreateChatCompletionStreamResponse { - id: "test".to_string(), - choices: vec![ChatChoiceStream { - index, - delta: ChatCompletionStreamResponseDelta { - role, - content: None, - function_call: None, - tool_calls: tool_call_chunk.map(|chunk| vec![chunk]), - refusal: None, - reasoning_content: None, - }, - finish_reason, - logprobs: None, - }], - created: 1234567890, - model: "test-model".to_string(), - system_fingerprint: None, - object: "chat.completion.chunk".to_string(), - service_tier: None, - usage: None, - nvext: None, - } -} - -#[tokio::test] -async fn test_aggregator_named_tool() { - use dynamo_llm::protocols::Annotated; - use dynamo_llm::protocols::openai::ParsingOptions; - use dynamo_llm::protocols::openai::chat_completions::aggregator::DeltaAggregator; - use futures::stream; - - let chunks = vec![ - create_chunk( - 0, - Some(dynamo_async_openai::types::Role::Assistant), - None, - None, - ), - create_chunk( - 0, - None, - Some(ChatCompletionMessageToolCallChunk { - index: 0, - id: Some("call-1".to_string()), - r#type: Some(ChatCompletionToolType::Function), - function: Some(FunctionCallStream { - name: Some("get_weather".to_string()), - arguments: Some(r#"{"location":"Paris","unit":"celsius"}"#.to_string()), - }), - }), - None, - ), - create_chunk( - 0, - None, - None, - Some(dynamo_async_openai::types::FinishReason::Stop), - ), - ]; - - let annotated_chunks: Vec> = chunks - .into_iter() - .map(|chunk| Annotated { - data: Some(chunk), - id: None, - event: None, - comment: None, - }) - .collect(); - - let stream = Box::pin(stream::iter(annotated_chunks)); - let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await; - - assert!(result.is_ok()); - let response = result.unwrap(); - - assert_eq!(response.choices.len(), 1); - let choice = &response.choices[0]; - - assert!(choice.message.tool_calls.is_some()); - let tool_calls = choice.message.tool_calls.as_ref().unwrap(); - assert_eq!(tool_calls.len(), 1); - - let tool_call = &tool_calls[0]; - assert_eq!(tool_call.id, "call-1"); - assert_eq!(tool_call.function.name, "get_weather"); - assert_eq!( - tool_call.function.arguments, r#"{"location":"Paris","unit":"celsius"}"# - ); -} - -#[tokio::test] -async fn test_aggregator_required_tool_parallel_calls() { - use dynamo_async_openai::types::{ - ChatCompletionMessageToolCallChunk, ChatCompletionToolType, FunctionCallStream, - }; - use dynamo_llm::protocols::Annotated; - use dynamo_llm::protocols::openai::ParsingOptions; - use dynamo_llm::protocols::openai::chat_completions::aggregator::DeltaAggregator; - use futures::stream; - - let chunks = vec![ - create_chunk( - 0, - Some(dynamo_async_openai::types::Role::Assistant), - None, - None, - ), - create_chunk( - 0, - None, - Some(ChatCompletionMessageToolCallChunk { - index: 0, - id: Some("call-1".to_string()), - r#type: Some(ChatCompletionToolType::Function), - function: Some(FunctionCallStream { - name: Some("search".to_string()), - arguments: Some(r#"{"query":"rust"}"#.to_string()), - }), - }), - None, - ), - create_chunk( - 0, - None, - Some(ChatCompletionMessageToolCallChunk { - index: 1, - id: Some("call-2".to_string()), - r#type: Some(ChatCompletionToolType::Function), - function: Some(FunctionCallStream { - name: Some("summarize".to_string()), - arguments: Some(r#"{"text":"long article"}"#.to_string()), - }), - }), - None, - ), - create_chunk( - 0, - None, - None, - Some(dynamo_async_openai::types::FinishReason::ToolCalls), - ), - ]; - - let annotated_chunks: Vec> = chunks - .into_iter() - .map(|chunk| Annotated { - data: Some(chunk), - id: None, - event: None, - comment: None, - }) - .collect(); - - let stream = Box::pin(stream::iter(annotated_chunks)); - let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await; - - assert!(result.is_ok()); - let response = result.unwrap(); - - assert_eq!(response.choices.len(), 1); - let choice = &response.choices[0]; - - assert!(choice.message.tool_calls.is_some()); - let tool_calls = choice.message.tool_calls.as_ref().unwrap(); - assert_eq!(tool_calls.len(), 2); - - assert_eq!(tool_calls[0].id, "call-1"); - assert_eq!(tool_calls[0].function.name, "search"); - assert_eq!(tool_calls[0].function.arguments, r#"{"query":"rust"}"#); - - assert_eq!(tool_calls[1].id, "call-2"); - assert_eq!(tool_calls[1].function.name, "summarize"); - assert_eq!(tool_calls[1].function.arguments, r#"{"text":"long article"}"#); - - assert_eq!( - choice.finish_reason, - Some(dynamo_async_openai::types::FinishReason::ToolCalls) - ); -} +} \ No newline at end of file diff --git a/lib/parsers/src/tool_calling/json/mod.rs b/lib/parsers/src/tool_calling/json/mod.rs index d5f2a3c0def..cdc83f658e2 100644 --- a/lib/parsers/src/tool_calling/json/mod.rs +++ b/lib/parsers/src/tool_calling/json/mod.rs @@ -16,19 +16,16 @@ pub use super::config::JsonParserConfig; pub use super::response::ToolCallResponse; #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +#[derive(Default)] pub enum JsonParserType { // Basic is generic json parser which can handle most of the cases + #[default] Basic, // Model Specific JSON Parsers DeepseekV3, DeepseekV31, } -impl Default for JsonParserType { - fn default() -> Self { - Self::Basic - } -} pub fn try_tool_call_parse_json( message: &str, diff --git a/lib/runtime/src/distributed.rs b/lib/runtime/src/distributed.rs index 407693385f4..cdfe3edff82 100644 --- a/lib/runtime/src/distributed.rs +++ b/lib/runtime/src/distributed.rs @@ -610,8 +610,10 @@ impl DistributedConfig { /// - `Http`: Use HTTP/2 for request distribution /// - `Tcp`: Use raw TCP for request distribution with msgpack support #[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Default)] pub enum RequestPlaneMode { /// Use NATS for request plane (default for backward compatibility) + #[default] Nats, /// Use HTTP/2 for request plane Http, @@ -619,11 +621,6 @@ pub enum RequestPlaneMode { Tcp, } -impl Default for RequestPlaneMode { - fn default() -> Self { - Self::Nats - } -} impl fmt::Display for RequestPlaneMode { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { From ea836bb7ddd481404d9457c26d6bc9872e431368 Mon Sep 17 00:00:00 2001 From: Vladislav Nosivskoy Date: Fri, 5 Dec 2025 15:20:16 +0300 Subject: [PATCH 10/18] fix fmt Signed-off-by: Vladislav Nosivskoy --- lib/llm/src/preprocessor.rs | 4 +- .../openai/chat_completions/delta.rs | 5 +- .../protocols/openai/chat_completions/jail.rs | 57 +++++++------- lib/llm/tests/http_namespace_integration.rs | 6 +- lib/llm/tests/tool_choice.rs | 76 ++++++++++++++----- lib/llm/tests/tool_choice_finish_reasons.rs | 4 +- lib/parsers/src/tool_calling/json/mod.rs | 7 +- lib/runtime/src/distributed.rs | 7 +- 8 files changed, 108 insertions(+), 58 deletions(-) diff --git a/lib/llm/src/preprocessor.rs b/lib/llm/src/preprocessor.rs index 73c23236609..121433e090f 100644 --- a/lib/llm/src/preprocessor.rs +++ b/lib/llm/src/preprocessor.rs @@ -807,7 +807,9 @@ impl OpenAIPreprocessor { // Immediate jail mode for required tool choice builder = builder.tool_choice_required(); } - Some(ChatCompletionToolChoiceOption::Auto) | Some(ChatCompletionToolChoiceOption::None) | None => { + Some(ChatCompletionToolChoiceOption::Auto) + | Some(ChatCompletionToolChoiceOption::None) + | None => { // Traditional marker-based jail for auto/none/unspecified if let Some(parser) = tool_call_parser { builder = builder.tool_call_parser(parser); diff --git a/lib/llm/src/protocols/openai/chat_completions/delta.rs b/lib/llm/src/protocols/openai/chat_completions/delta.rs index e2ebd93913e..d636b9f4a71 100644 --- a/lib/llm/src/protocols/openai/chat_completions/delta.rs +++ b/lib/llm/src/protocols/openai/chat_completions/delta.rs @@ -4,9 +4,7 @@ use super::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}; use crate::{ local_model::runtime_config::ModelRuntimeConfig, - protocols::{ - common::{self}, - }, + protocols::common::{self}, types::TokenIdType, }; @@ -304,7 +302,6 @@ impl DeltaGenerator { pub fn is_usage_enabled(&self) -> bool { self.options.enable_usage } - } /// Implements the [`crate::protocols::openai::DeltaGeneratorExt`] trait for [`DeltaGenerator`], allowing diff --git a/lib/llm/src/protocols/openai/chat_completions/jail.rs b/lib/llm/src/protocols/openai/chat_completions/jail.rs index 56737a2a770..35d22a9e270 100644 --- a/lib/llm/src/protocols/openai/chat_completions/jail.rs +++ b/lib/llm/src/protocols/openai/chat_completions/jail.rs @@ -417,8 +417,7 @@ impl ChoiceJailStateCollection { } /// Emission mode for handling multiple choices -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[derive(Default)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum EmissionMode { /// Pack multiple choices in the same chunk (default, matches original behavior) #[default] @@ -427,7 +426,6 @@ pub enum EmissionMode { SingleChoicePerChunk, } - /// A stream transformer that can "jail" tokens based on configurable start/end sequences /// When jailed, tokens are accumulated rather than yielded immediately /// When the jail ends (via end sequence or stream completion), accumulated content is processed and released @@ -725,7 +723,8 @@ impl JailedStream { if let Ok((_, _)) = try_tool_call_parse_aggregate(accumulated_content, Some(parser)).await { - let split_pos = find_tool_call_end_position(accumulated_content, Some(parser)); + let split_pos = + find_tool_call_end_position(accumulated_content, Some(parser)); (true, split_pos) } else { (false, accumulated_content.len()) @@ -742,19 +741,23 @@ impl JailedStream { match format { ToolChoiceFormat::SingleObject { .. } => { // Expect single object: {"location": "Paris", "unit": "celsius"} - if let Ok(value) = serde_json::from_str::(accumulated_content) - && value.is_object() { - return (true, accumulated_content.len()); - } + if let Ok(value) = + serde_json::from_str::(accumulated_content) + && value.is_object() + { + return (true, accumulated_content.len()); + } (false, accumulated_content.len()) } ToolChoiceFormat::ArrayOfTools => { // Expect array: [{"name":"search","parameters":{...}}, ...] - if let Ok(value) = serde_json::from_str::(accumulated_content) + if let Ok(value) = + serde_json::from_str::(accumulated_content) && let Some(arr) = value.as_array() - && !arr.is_empty() { - return (true, accumulated_content.len()); - } + && !arr.is_empty() + { + return (true, accumulated_content.len()); + } (false, accumulated_content.len()) } } @@ -772,9 +775,11 @@ impl JailedStream { match &self.jail_mode { JailMode::MarkerBased => { // Traditional marker-based tool call parsing - if let Ok((tool_calls, normal_text)) = - try_tool_call_parse_aggregate(accumulated_content, self.tool_call_parser.as_deref()) - .await + if let Ok((tool_calls, normal_text)) = try_tool_call_parse_aggregate( + accumulated_content, + self.tool_call_parser.as_deref(), + ) + .await && !tool_calls.is_empty() { // Convert to streaming format @@ -816,16 +821,14 @@ impl JailedStream { JailMode::Immediate { format } => { // tool_choice mode: parse JSON and convert to tool calls match self.parse_tool_choice_json(accumulated_content, format) { - Ok(tool_call_chunks) if !tool_call_chunks.is_empty() => { - create_choice_stream( - choice_index, - Some(Role::Assistant), - "", - Some(tool_call_chunks), - base_choice.finish_reason, - base_choice.logprobs.clone(), - ) - } + Ok(tool_call_chunks) if !tool_call_chunks.is_empty() => create_choice_stream( + choice_index, + Some(Role::Assistant), + "", + Some(tool_call_chunks), + base_choice.finish_reason, + base_choice.logprobs.clone(), + ), Ok(_) | Err(_) => { // Parsing failed, return as content create_choice_stream( @@ -880,7 +883,9 @@ impl JailedStream { Some(ChatCompletionMessageToolCallChunk { index: idx as u32, id: Some(format!("call-{}", idx + 1)), - r#type: Some(dynamo_async_openai::types::ChatCompletionToolType::Function), + r#type: Some( + dynamo_async_openai::types::ChatCompletionToolType::Function, + ), function: Some(FunctionCallStream { name: Some(name), arguments: Some(args), diff --git a/lib/llm/tests/http_namespace_integration.rs b/lib/llm/tests/http_namespace_integration.rs index 93529c087ea..e0df5859d38 100644 --- a/lib/llm/tests/http_namespace_integration.rs +++ b/lib/llm/tests/http_namespace_integration.rs @@ -73,10 +73,12 @@ fn test_model_discovery_scoping_scenarios() { // Scenario 1: Frontend configured for specific namespace should only see models from that namespace let frontend_namespace = "vllm-agg"; - let available_models = [create_test_endpoint("vllm-agg", "backend", "generate"), + let available_models = [ + create_test_endpoint("vllm-agg", "backend", "generate"), create_test_endpoint("vllm-agg", "backend", "generate"), create_test_endpoint("sglang-prod", "backend", "generate"), - create_test_endpoint("dynamo", "backend", "generate")]; + create_test_endpoint("dynamo", "backend", "generate"), + ]; let visible_models: Vec<&EndpointId> = available_models .iter() diff --git a/lib/llm/tests/tool_choice.rs b/lib/llm/tests/tool_choice.rs index 0c94446870e..298209ffd46 100644 --- a/lib/llm/tests/tool_choice.rs +++ b/lib/llm/tests/tool_choice.rs @@ -40,9 +40,9 @@ async fn apply_jail_transformation( tool_choice: Option, ) -> dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse { use dynamo_llm::protocols::openai::chat_completions::jail::JailedStream; - use futures::stream; - use futures::StreamExt; use dynamo_runtime::protocols::annotated::Annotated; + use futures::StreamExt; + use futures::stream; let input_stream = stream::iter(vec![Annotated { data: Some(raw_response), @@ -71,13 +71,15 @@ async fn apply_jail_transformation( } async fn apply_jail_transformation_streaming( - raw_responses: Vec, + raw_responses: Vec< + dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse, + >, tool_choice: Option, ) -> Vec { use dynamo_llm::protocols::openai::chat_completions::jail::JailedStream; - use futures::stream; - use futures::StreamExt; use dynamo_runtime::protocols::annotated::Annotated; + use futures::StreamExt; + use futures::stream; let input_stream = stream::iter(raw_responses.into_iter().map(|r| Annotated { data: Some(r), @@ -205,7 +207,12 @@ async fn test_required_tool_choice_parses_json_array() { Some("search") ); assert_eq!( - tool_calls[0].function.as_ref().unwrap().arguments.as_deref(), + tool_calls[0] + .function + .as_ref() + .unwrap() + .arguments + .as_deref(), Some(r#"{"query":"rust"}"#) ); @@ -217,7 +224,12 @@ async fn test_required_tool_choice_parses_json_array() { Some("summarize") ); assert_eq!( - tool_calls[1].function.as_ref().unwrap().arguments.as_deref(), + tool_calls[1] + .function + .as_ref() + .unwrap() + .arguments + .as_deref(), Some(r#"{"topic":"memory"}"#) ); } @@ -258,11 +270,7 @@ async fn test_streaming_named_tool_buffers_until_finish() { let mut generator = request.response_generator("req-stream-1".to_string()); - let chunks = [ - r#"{"location":""#, - r#"Paris","unit":""#, - r#"celsius"}"#, - ]; + let chunks = [r#"{"location":""#, r#"Paris","unit":""#, r#"celsius"}"#]; let mut raw_responses = Vec::new(); for (i, chunk) in chunks.iter().enumerate() { @@ -302,9 +310,17 @@ async fn test_streaming_named_tool_buffers_until_finish() { let tool_calls = response.choices[0].delta.tool_calls.as_ref().unwrap(); assert_eq!(tool_calls.len(), 1); - assert_eq!(tool_calls[0].function.as_ref().unwrap().name.as_deref(), Some("get_weather")); assert_eq!( - tool_calls[0].function.as_ref().unwrap().arguments.as_deref(), + tool_calls[0].function.as_ref().unwrap().name.as_deref(), + Some("get_weather") + ); + assert_eq!( + tool_calls[0] + .function + .as_ref() + .unwrap() + .arguments + .as_deref(), Some(r#"{"location":"Paris","unit":"celsius"}"#) ); } @@ -361,11 +377,33 @@ async fn test_streaming_required_tool_parallel() { let tool_calls = response.choices[0].delta.tool_calls.as_ref().unwrap(); assert_eq!(tool_calls.len(), 2); - assert_eq!(tool_calls[0].function.as_ref().unwrap().name.as_deref(), Some("search")); - assert_eq!(tool_calls[0].function.as_ref().unwrap().arguments.as_deref(), Some(r#"{"query":"rust"}"#)); + assert_eq!( + tool_calls[0].function.as_ref().unwrap().name.as_deref(), + Some("search") + ); + assert_eq!( + tool_calls[0] + .function + .as_ref() + .unwrap() + .arguments + .as_deref(), + Some(r#"{"query":"rust"}"#) + ); - assert_eq!(tool_calls[1].function.as_ref().unwrap().name.as_deref(), Some("summarize")); - assert_eq!(tool_calls[1].function.as_ref().unwrap().arguments.as_deref(), Some(r#"{"topic":"memory"}"#)); + assert_eq!( + tool_calls[1].function.as_ref().unwrap().name.as_deref(), + Some("summarize") + ); + assert_eq!( + tool_calls[1] + .function + .as_ref() + .unwrap() + .arguments + .as_deref(), + Some(r#"{"topic":"memory"}"#) + ); } #[test] @@ -396,4 +434,4 @@ fn test_no_tool_choice_outputs_normal_text() { Some("Hello world") ); assert!(response.choices[0].delta.tool_calls.is_none()); -} \ No newline at end of file +} diff --git a/lib/llm/tests/tool_choice_finish_reasons.rs b/lib/llm/tests/tool_choice_finish_reasons.rs index 7b0519e9341..07f28d59626 100644 --- a/lib/llm/tests/tool_choice_finish_reasons.rs +++ b/lib/llm/tests/tool_choice_finish_reasons.rs @@ -56,9 +56,9 @@ async fn apply_jail_transformation( tool_choice: Option, ) -> dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse { use dynamo_llm::protocols::openai::chat_completions::jail::JailedStream; - use futures::stream; - use futures::StreamExt; use dynamo_runtime::protocols::annotated::Annotated; + use futures::StreamExt; + use futures::stream; let input_stream = stream::iter(vec![Annotated { data: Some(raw_response), diff --git a/lib/parsers/src/tool_calling/json/mod.rs b/lib/parsers/src/tool_calling/json/mod.rs index cdc83f658e2..d5f2a3c0def 100644 --- a/lib/parsers/src/tool_calling/json/mod.rs +++ b/lib/parsers/src/tool_calling/json/mod.rs @@ -16,16 +16,19 @@ pub use super::config::JsonParserConfig; pub use super::response::ToolCallResponse; #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] -#[derive(Default)] pub enum JsonParserType { // Basic is generic json parser which can handle most of the cases - #[default] Basic, // Model Specific JSON Parsers DeepseekV3, DeepseekV31, } +impl Default for JsonParserType { + fn default() -> Self { + Self::Basic + } +} pub fn try_tool_call_parse_json( message: &str, diff --git a/lib/runtime/src/distributed.rs b/lib/runtime/src/distributed.rs index cdfe3edff82..407693385f4 100644 --- a/lib/runtime/src/distributed.rs +++ b/lib/runtime/src/distributed.rs @@ -610,10 +610,8 @@ impl DistributedConfig { /// - `Http`: Use HTTP/2 for request distribution /// - `Tcp`: Use raw TCP for request distribution with msgpack support #[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[derive(Default)] pub enum RequestPlaneMode { /// Use NATS for request plane (default for backward compatibility) - #[default] Nats, /// Use HTTP/2 for request plane Http, @@ -621,6 +619,11 @@ pub enum RequestPlaneMode { Tcp, } +impl Default for RequestPlaneMode { + fn default() -> Self { + Self::Nats + } +} impl fmt::Display for RequestPlaneMode { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { From ff607bdc443616dc32919710da8bfda03c8044ee Mon Sep 17 00:00:00 2001 From: Vladislav Nosivskoy Date: Fri, 5 Dec 2025 15:23:48 +0300 Subject: [PATCH 11/18] fix clippy Signed-off-by: Vladislav Nosivskoy --- lib/llm/tests/tool_choice.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/llm/tests/tool_choice.rs b/lib/llm/tests/tool_choice.rs index 298209ffd46..832feb0c872 100644 --- a/lib/llm/tests/tool_choice.rs +++ b/lib/llm/tests/tool_choice.rs @@ -2,10 +2,10 @@ // SPDX-License-Identifier: Apache-2.0 use dynamo_async_openai::types::{ - ChatChoiceStream, ChatCompletionMessageToolCallChunk, ChatCompletionNamedToolChoice, + ChatCompletionNamedToolChoice, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, ChatCompletionRequestUserMessageContent, ChatCompletionToolChoiceOption, - ChatCompletionToolType, CreateChatCompletionRequest, FunctionCallStream, FunctionName, + ChatCompletionToolType, CreateChatCompletionRequest, FunctionName, }; use dynamo_llm::protocols::common; use dynamo_llm::protocols::common::llm_backend::BackendOutput; From 0456aa6950d233b739e0bf392dd65dd78f43f195 Mon Sep 17 00:00:00 2001 From: Vladislav Nosivskoy Date: Fri, 5 Dec 2025 15:26:36 +0300 Subject: [PATCH 12/18] revert clippy change on non-relevant file Signed-off-by: Vladislav Nosivskoy --- lib/llm/tests/http_namespace_integration.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/llm/tests/http_namespace_integration.rs b/lib/llm/tests/http_namespace_integration.rs index e0df5859d38..4ff727f3469 100644 --- a/lib/llm/tests/http_namespace_integration.rs +++ b/lib/llm/tests/http_namespace_integration.rs @@ -73,7 +73,7 @@ fn test_model_discovery_scoping_scenarios() { // Scenario 1: Frontend configured for specific namespace should only see models from that namespace let frontend_namespace = "vllm-agg"; - let available_models = [ + let available_models = vec![ create_test_endpoint("vllm-agg", "backend", "generate"), create_test_endpoint("vllm-agg", "backend", "generate"), create_test_endpoint("sglang-prod", "backend", "generate"), From a21c9ae5c054c09eae105ae2112cca6ec9fbbbf0 Mon Sep 17 00:00:00 2001 From: Vladislav Nosivskoy Date: Fri, 5 Dec 2025 15:34:59 +0300 Subject: [PATCH 13/18] remove change Signed-off-by: Vladislav Nosivskoy --- lib/llm/src/protocols/openai/chat_completions/delta.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/lib/llm/src/protocols/openai/chat_completions/delta.rs b/lib/llm/src/protocols/openai/chat_completions/delta.rs index d636b9f4a71..821dcf7266f 100644 --- a/lib/llm/src/protocols/openai/chat_completions/delta.rs +++ b/lib/llm/src/protocols/openai/chat_completions/delta.rs @@ -258,8 +258,6 @@ impl DeltaGenerator { let choices = vec![choice]; - self.msg_counter += 1; - // According to OpenAI spec: when stream_options.include_usage is true, // all intermediate chunks should have usage: null // The final usage chunk will be sent separately with empty choices From 85ac49d35a931bdf2470920b0412fbe1a9629154 Mon Sep 17 00:00:00 2001 From: Vladislav Nosivskoy Date: Fri, 5 Dec 2025 15:43:20 +0300 Subject: [PATCH 14/18] revert delta.rs and cargo fmt Signed-off-by: Vladislav Nosivskoy --- .../openai/chat_completions/delta.rs | 21 +++---------------- lib/llm/tests/tool_choice.rs | 3 +-- 2 files changed, 4 insertions(+), 20 deletions(-) diff --git a/lib/llm/src/protocols/openai/chat_completions/delta.rs b/lib/llm/src/protocols/openai/chat_completions/delta.rs index 821dcf7266f..186bb7f0950 100644 --- a/lib/llm/src/protocols/openai/chat_completions/delta.rs +++ b/lib/llm/src/protocols/openai/chat_completions/delta.rs @@ -216,30 +216,18 @@ impl DeltaGenerator { /// /// # Returns /// * An [`dynamo_async_openai::types::CreateChatCompletionStreamResponse`] instance representing the choice. - pub fn create_choice( - &mut self, - index: u32, - text: Option, - finish_reason: Option, - logprobs: Option, - ) -> NvCreateChatCompletionStreamResponse { - self.build_choice(index, text, finish_reason, logprobs, None) - } - - /// Internal method to build a streaming chat completion response with optional tool_calls. #[allow(deprecated)] - fn build_choice( + pub fn create_choice( &mut self, index: u32, text: Option, finish_reason: Option, logprobs: Option, - tool_calls: Option>, ) -> NvCreateChatCompletionStreamResponse { let delta = dynamo_async_openai::types::ChatCompletionStreamResponseDelta { content: text, function_call: None, - tool_calls, + tool_calls: None, role: if self.msg_counter == 0 { Some(dynamo_async_openai::types::Role::Assistant) } else { @@ -372,10 +360,7 @@ impl crate::protocols::openai::DeltaGeneratorExt Date: Fri, 5 Dec 2025 18:36:31 +0300 Subject: [PATCH 15/18] fix should apply tool jail Signed-off-by: Vladislav Nosivskoy --- lib/llm/src/preprocessor.rs | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/lib/llm/src/preprocessor.rs b/lib/llm/src/preprocessor.rs index 121433e090f..00ebe0b4c1f 100644 --- a/lib/llm/src/preprocessor.rs +++ b/lib/llm/src/preprocessor.rs @@ -752,25 +752,17 @@ impl OpenAIPreprocessor { has_tools: bool, ) -> std::result::Result { match (tool_call_parser, tool_choice, has_tools) { - // No parser but tools requested - error cases - (None, Some(ChatCompletionToolChoiceOption::Required), true) => { - tracing::warn!( - "Tool choice 'required' specified but no tool parser configured; proceeding without jailing" - ); - Ok(false) - } + // tool_choice=required/named work without parser (use Immediate jail mode) + (None, Some(ChatCompletionToolChoiceOption::Required), true) => Ok(true), + (None, Some(ChatCompletionToolChoiceOption::Named(_)), true) => Ok(true), + + // tool_choice=auto requires a parser (None, Some(ChatCompletionToolChoiceOption::Auto), true) => { tracing::warn!( "Tool choice 'auto' specified but no tool parser configured; proceeding without jailing" ); Ok(false) } - (None, Some(ChatCompletionToolChoiceOption::Named(_)), _) => { - tracing::warn!( - "Named tool choice specified but no tool parser configured; proceeding without jailing" - ); - Ok(false) - } // Parser exists and tools might be called (Some(_), Some(ChatCompletionToolChoiceOption::None), _) => { From 7f6ba72032baa3ed2e78bfe1794b4ca305811827 Mon Sep 17 00:00:00 2001 From: Vladislav Nosivskoy Date: Tue, 9 Dec 2025 14:42:26 +0300 Subject: [PATCH 16/18] add helper function for tool chunk creation Signed-off-by: Vladislav Nosivskoy --- .../protocols/openai/chat_completions/jail.rs | 44 ++++++++++--------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/lib/llm/src/protocols/openai/chat_completions/jail.rs b/lib/llm/src/protocols/openai/chat_completions/jail.rs index 6f82b40096e..2a716cc743e 100644 --- a/lib/llm/src/protocols/openai/chat_completions/jail.rs +++ b/lib/llm/src/protocols/openai/chat_completions/jail.rs @@ -14,6 +14,7 @@ use dynamo_parsers::tool_calling::{ use dynamo_runtime::protocols::annotated::Annotated; use futures::{Stream, StreamExt}; use std::collections::HashMap; +use uuid::Uuid; use crate::utils::{MarkerMatcher, MatchResult}; @@ -874,6 +875,23 @@ impl JailedStream { } } + /// Helper to create a ChatCompletionMessageToolCallChunk + fn create_tool_call_chunk( + index: u32, + name: String, + arguments: String, + ) -> ChatCompletionMessageToolCallChunk { + ChatCompletionMessageToolCallChunk { + index, + id: Some(format!("call-{}", Uuid::new_v4())), + r#type: Some(dynamo_async_openai::types::ChatCompletionToolType::Function), + function: Some(FunctionCallStream { + name: Some(name), + arguments: Some(arguments), + }), + } + } + /// Parse tool_choice JSON output into tool call chunks fn parse_tool_choice_json( &self, @@ -886,15 +904,11 @@ impl JailedStream { ToolChoiceFormat::SingleObject { tool_name } => { // For named tool choice: JSON is the parameters object if parsed.is_object() { - Ok(vec![ChatCompletionMessageToolCallChunk { - index: 0, - id: Some("call-1".to_string()), - r#type: Some(dynamo_async_openai::types::ChatCompletionToolType::Function), - function: Some(FunctionCallStream { - name: Some(tool_name.clone()), - arguments: Some(json_content.to_string()), - }), - }]) + Ok(vec![Self::create_tool_call_chunk( + 0, + tool_name.clone(), + json_content.to_string(), + )]) } else { Ok(vec![]) } @@ -909,17 +923,7 @@ impl JailedStream { let name = entry.get("name")?.as_str()?.to_string(); let parameters = entry.get("parameters")?; let args = serde_json::to_string(parameters).ok()?; - Some(ChatCompletionMessageToolCallChunk { - index: idx as u32, - id: Some(format!("call-{}", idx + 1)), - r#type: Some( - dynamo_async_openai::types::ChatCompletionToolType::Function, - ), - function: Some(FunctionCallStream { - name: Some(name), - arguments: Some(args), - }), - }) + Some(Self::create_tool_call_chunk(idx as u32, name, args)) }) .collect(); Ok(chunks) From 69cdcdf43382a899340c70dc556b03a94883dc94 Mon Sep 17 00:00:00 2001 From: Vladislav Nosivskoy Date: Tue, 9 Dec 2025 14:49:39 +0300 Subject: [PATCH 17/18] add comments in tools.rs Signed-off-by: Vladislav Nosivskoy --- lib/llm/src/protocols/openai/tools.rs | 108 ++++++++++++++++++++++++++ 1 file changed, 108 insertions(+) diff --git a/lib/llm/src/protocols/openai/tools.rs b/lib/llm/src/protocols/openai/tools.rs index ad881d50dc6..457f5b37a99 100644 --- a/lib/llm/src/protocols/openai/tools.rs +++ b/lib/llm/src/protocols/openai/tools.rs @@ -62,11 +62,66 @@ fn clone_parameters(function: &FunctionObject) -> Value { .unwrap_or_else(|| json!({"type": "object", "properties": {}})) } +/// Builds a JSON Schema for `tool_choice=required` that enforces an array of tool calls. +/// +/// # Schema Structure +/// +/// The generated schema looks like: +/// ```json +/// { +/// "type": "array", +/// "minItems": 1, +/// "items": { +/// "type": "object", +/// "anyOf": [ +/// { +/// "properties": { +/// "name": {"type": "string", "enum": ["tool1"]}, +/// "parameters": { /* tool1's parameter schema */ } +/// }, +/// "required": ["name", "parameters"] +/// }, +/// { +/// "properties": { +/// "name": {"type": "string", "enum": ["tool2"]}, +/// "parameters": { /* tool2's parameter schema */ } +/// }, +/// "required": ["name", "parameters"] +/// } +/// ] +/// }, +/// "$defs": { /* shared type definitions from all tools */ } +/// } +/// ``` +/// +/// # $defs Handling +/// +/// `$defs` contains shared JSON Schema definitions that can be referenced via `$ref`. +/// For example, if two tools reference a common type: +/// ```json +/// { +/// "$defs": { +/// "Location": { +/// "type": "object", +/// "properties": { +/// "city": {"type": "string"}, +/// "country": {"type": "string"} +/// } +/// } +/// } +/// } +/// ``` +/// +/// We extract `$defs` from each tool's schema and merge them into a global `$defs` map +/// at the root level. If multiple tools define the same type, we verify they match to +/// avoid conflicts. fn build_required_schema(tools: &[ChatCompletionTool]) -> Result { + // Accumulator for all shared type definitions ($defs) across tools let mut defs: BTreeMap = BTreeMap::new(); let mut any_of = Vec::with_capacity(tools.len()); for tool in tools { + // Extract parameter schema and its $defs (if any) let ParamsAndDefs { schema, defs: new_defs, @@ -84,6 +139,7 @@ fn build_required_schema(tools: &[ChatCompletionTool]) -> Result Result Result>, } +/// Extracts `$defs` from a function's parameter schema, returning both the +/// cleaned schema and the definitions separately. +/// +/// # Example +/// +/// Input schema: +/// ```json +/// { +/// "type": "object", +/// "properties": { +/// "location": {"$ref": "#/$defs/Location"} +/// }, +/// "$defs": { +/// "Location": { +/// "type": "object", +/// "properties": {"city": {"type": "string"}} +/// } +/// } +/// } +/// ``` +/// +/// Returns: +/// - schema: same as input but with `$defs` removed +/// - defs: `Some({"Location": {...}})` fn split_defs(function: &FunctionObject) -> Result { let mut schema = clone_parameters(function); let defs = match &mut schema { @@ -136,6 +224,26 @@ fn convert_defs( } } +/// Merges definitions from one tool into the global `$defs` accumulator. +/// +/// # Conflict Detection +/// +/// If two tools define the same type name but with different schemas, we return +/// an error. This ensures consistency across tool definitions. +/// +/// # Example +/// +/// If `target` contains: +/// ```json +/// {"Location": {"type": "object", "properties": {"city": {"type": "string"}}}} +/// ``` +/// +/// And we try to merge: +/// ```json +/// {"Location": {"type": "object", "properties": {"city": {"type": "number"}}}} +/// ``` +/// +/// This will return `ToolChoiceError::ConflictingDefinition("Location")`. fn merge_defs( target: &mut BTreeMap, defs: Option>, From b3c69af3dfee75e27d1c439c4372e0f7c5ab45d6 Mon Sep 17 00:00:00 2001 From: Vladislav Nosivskoy Date: Tue, 9 Dec 2025 17:51:23 +0300 Subject: [PATCH 18/18] fix id validation in tests Signed-off-by: Vladislav Nosivskoy --- lib/llm/tests/tool_choice.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/llm/tests/tool_choice.rs b/lib/llm/tests/tool_choice.rs index 242d36b9e18..c970108d9b5 100644 --- a/lib/llm/tests/tool_choice.rs +++ b/lib/llm/tests/tool_choice.rs @@ -158,7 +158,7 @@ async fn test_named_tool_choice_parses_json() { let tool_call = &tool_calls[0]; assert_eq!(tool_call.index, 0); - assert_eq!(tool_call.id.as_deref(), Some("call-1")); + assert!(tool_call.id.as_ref().unwrap().starts_with("call-")); assert_eq!(tool_call.r#type, Some(ChatCompletionToolType::Function)); assert_eq!( tool_call.function.as_ref().unwrap().name.as_deref(), @@ -199,7 +199,7 @@ async fn test_required_tool_choice_parses_json_array() { assert_eq!(tool_calls.len(), 2); assert_eq!(tool_calls[0].index, 0); - assert_eq!(tool_calls[0].id.as_deref(), Some("call-1")); + assert!(tool_calls[0].id.as_ref().unwrap().starts_with("call-")); assert_eq!(tool_calls[0].r#type, Some(ChatCompletionToolType::Function)); assert_eq!( tool_calls[0].function.as_ref().unwrap().name.as_deref(), @@ -216,7 +216,7 @@ async fn test_required_tool_choice_parses_json_array() { ); assert_eq!(tool_calls[1].index, 1); - assert_eq!(tool_calls[1].id.as_deref(), Some("call-2")); + assert!(tool_calls[1].id.as_ref().unwrap().starts_with("call-")); assert_eq!(tool_calls[1].r#type, Some(ChatCompletionToolType::Function)); assert_eq!( tool_calls[1].function.as_ref().unwrap().name.as_deref(),