diff --git a/lib/llm/src/preprocessor.rs b/lib/llm/src/preprocessor.rs index ac9ff9b3860e..b453ec9d4bc5 100644 --- a/lib/llm/src/preprocessor.rs +++ b/lib/llm/src/preprocessor.rs @@ -752,25 +752,17 @@ impl OpenAIPreprocessor { has_tools: bool, ) -> std::result::Result { match (tool_call_parser, tool_choice, has_tools) { - // No parser but tools requested - error cases - (None, Some(ChatCompletionToolChoiceOption::Required), true) => { - tracing::warn!( - "Tool choice 'required' specified but no tool parser configured; proceeding without jailing" - ); - Ok(false) - } + // tool_choice=required/named work without parser (use Immediate jail mode) + (None, Some(ChatCompletionToolChoiceOption::Required), true) => Ok(true), + (None, Some(ChatCompletionToolChoiceOption::Named(_)), true) => Ok(true), + + // tool_choice=auto requires a parser (None, Some(ChatCompletionToolChoiceOption::Auto), true) => { tracing::warn!( "Tool choice 'auto' specified but no tool parser configured; proceeding without jailing" ); Ok(false) } - (None, Some(ChatCompletionToolChoiceOption::Named(_)), _) => { - tracing::warn!( - "Named tool choice specified but no tool parser configured; proceeding without jailing" - ); - Ok(false) - } // Parser exists and tools might be called (Some(_), Some(ChatCompletionToolChoiceOption::None), _) => { @@ -786,15 +778,38 @@ impl OpenAIPreprocessor { /// Apply tool calling jail to the stream if needed pub fn apply_tool_calling_jail( - tool_call_parser: String, + tool_call_parser: Option, + tool_choice: Option, stream: S, ) -> impl Stream> + Send where S: Stream> + Send + 'static, { - let jail = JailedStream::builder() - .tool_call_parser(tool_call_parser) - .build(); + use dynamo_async_openai::types::ChatCompletionToolChoiceOption; + + let mut builder = JailedStream::builder(); + + // Configure jail based on tool_choice + match tool_choice { + Some(ChatCompletionToolChoiceOption::Named(named)) => { + // Immediate jail mode for named tool choice + builder = 
builder.tool_choice_named(named.function.name.clone()); + } + Some(ChatCompletionToolChoiceOption::Required) => { + // Immediate jail mode for required tool choice + builder = builder.tool_choice_required(); + } + Some(ChatCompletionToolChoiceOption::Auto) + | Some(ChatCompletionToolChoiceOption::None) + | None => { + // Traditional marker-based jail for auto/none/unspecified + if let Some(parser) = tool_call_parser { + builder = builder.tool_call_parser(parser); + } + } + } + + let jail = builder.build(); jail.apply_with_finish_reason(stream) } @@ -957,11 +972,11 @@ impl // Apply jail conditionally let transformed_stream: Pin + Send>> = if should_jail { - if let Some(parser) = self.tool_call_parser.clone() { - Box::pin(Self::apply_tool_calling_jail(parser, stream)) - } else { - Box::pin(stream) // Should not happen due to should_jail check - } + Box::pin(Self::apply_tool_calling_jail( + self.tool_call_parser.clone(), + request.inner.tool_choice.clone(), + stream, + )) } else { Box::pin(stream) }; diff --git a/lib/llm/src/protocols/openai.rs b/lib/llm/src/protocols/openai.rs index 574d7db31cdf..7da56f7613b4 100644 --- a/lib/llm/src/protocols/openai.rs +++ b/lib/llm/src/protocols/openai.rs @@ -17,6 +17,7 @@ pub mod embeddings; pub mod models; pub mod nvext; pub mod responses; +pub mod tools; pub mod validate; use validate::{ @@ -131,7 +132,7 @@ impl SamplingOptionsProvid let guided_whitespace_pattern = self.get_guided_whitespace_pattern(); let guided_decoding = match common::GuidedDecodingOptions::from_optional( - guided_json.cloned(), + guided_json, guided_regex, guided_choice, guided_grammar, diff --git a/lib/llm/src/protocols/openai/chat_completions.rs b/lib/llm/src/protocols/openai/chat_completions.rs index bb712086ecbd..cab55032cbe9 100644 --- a/lib/llm/src/protocols/openai/chat_completions.rs +++ b/lib/llm/src/protocols/openai/chat_completions.rs @@ -12,7 +12,7 @@ use super::{ common_ext::{CommonExt, CommonExtProvider}, nvext::NvExt, nvext::NvExtProvider, - 
validate, + tools, validate, }; pub mod aggregator; @@ -159,8 +159,24 @@ impl CommonExtProvider for NvCreateChatCompletionRequest { } /// Guided Decoding Options - fn get_guided_json(&self) -> Option<&serde_json::Value> { - self.common.guided_json.as_ref() + fn get_guided_json(&self) -> Option { + if let Some(value) = self.common.guided_json.clone() { + return Some(value); + } + + let tool_choice = self.inner.tool_choice.as_ref()?; + let tools = self.inner.tools.as_deref()?; + + match tools::get_json_schema_from_tools(Some(tool_choice), Some(tools)) { + Ok(schema) => schema, + Err(err) => { + tracing::warn!( + error = %err, + "failed to derive guided_json from tool_choice" + ); + None + } + } } fn get_guided_regex(&self) -> Option { diff --git a/lib/llm/src/protocols/openai/chat_completions/jail.rs b/lib/llm/src/protocols/openai/chat_completions/jail.rs index d25d637684a8..2a716cc743e3 100644 --- a/lib/llm/src/protocols/openai/chat_completions/jail.rs +++ b/lib/llm/src/protocols/openai/chat_completions/jail.rs @@ -14,6 +14,7 @@ use dynamo_parsers::tool_calling::{ use dynamo_runtime::protocols::annotated::Annotated; use futures::{Stream, StreamExt}; use std::collections::HashMap; +use uuid::Uuid; use crate::utils::{MarkerMatcher, MatchResult}; @@ -62,6 +63,24 @@ pub struct JailConfig<'a> { pub tool_call_parser: Option<&'a str>, } +/// Jail activation mode +#[derive(Debug, Clone, PartialEq)] +pub enum JailMode { + /// Traditional: wait for start marker, then jail + MarkerBased, + /// Immediate: start jailed from first token (for tool_choice) + Immediate { format: ToolChoiceFormat }, +} + +/// Format for tool_choice immediate jail mode +#[derive(Debug, Clone, PartialEq)] +pub enum ToolChoiceFormat { + /// tool_choice=named: expect single object {"location": "Paris", ...} + SingleObject { tool_name: String }, + /// tool_choice=required: expect array [{name:"search", parameters:{...}}, ...] 
+ ArrayOfTools, +} + /// State tracking for an individual choice during jail processing #[derive(Debug, Clone)] struct ChoiceJailState { @@ -105,10 +124,10 @@ fn create_choice_stream( impl ChoiceJailState { /// Create a new jail state for a choice - fn new(index: u32) -> Self { + fn new(index: u32, starts_jailed: bool) -> Self { Self { index, - is_jailed: false, + is_jailed: starts_jailed, accumulated_content: String::new(), partial_match_buffer: String::new(), stream_finish_reason: None, @@ -409,7 +428,7 @@ impl ChoiceJailStateCollection { } /// Get or create state for a choice index - fn get_or_create_state(&mut self, index: u32) -> &mut ChoiceJailState { + fn get_or_create_state(&mut self, index: u32, starts_jailed: bool) -> &mut ChoiceJailState { // Find the position where this index should be match self.states.binary_search_by_key(&index, |s| s.index) { Ok(pos) => { @@ -418,7 +437,7 @@ impl ChoiceJailStateCollection { } Err(insert_pos) => { // Need to create new state - let new_state = ChoiceJailState::new(index); + let new_state = ChoiceJailState::new(index, starts_jailed); self.states.insert(insert_pos, new_state); &mut self.states[insert_pos] } @@ -427,20 +446,15 @@ impl ChoiceJailStateCollection { } /// Emission mode for handling multiple choices -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum EmissionMode { /// Pack multiple choices in the same chunk (default, matches original behavior) + #[default] Packed, /// Emit one choice per chunk for OpenAI compatibility SingleChoicePerChunk, } -impl Default for EmissionMode { - fn default() -> Self { - Self::Packed - } -} - /// A stream transformer that can "jail" tokens based on configurable start/end sequences /// When jailed, tokens are accumulated rather than yielded immediately /// When the jail ends (via end sequence or stream completion), accumulated content is processed and released @@ -450,6 +464,7 @@ pub struct JailedStream { tool_call_parser: 
Option, emission_mode: EmissionMode, marker_matcher: MarkerMatcher, + jail_mode: JailMode, } impl JailedStream { @@ -467,8 +482,9 @@ impl JailedStream { where S: Stream> + Send + 'static, { + let jail_mode = self.jail_mode.clone(); let jailed_stream = self.apply(stream); - JailedStream::fix_finish_reason(jailed_stream) + JailedStream::fix_finish_reason(jailed_stream, jail_mode) } /// Apply the jail transformation to a stream of chat completion responses @@ -508,7 +524,8 @@ impl JailedStream { // Process each choice independently using the new architecture for choice in &chat_response.choices { if let Some(ref content) = choice.delta.content { - let choice_state = choice_states.get_or_create_state(choice.index); + let starts_jailed = matches!(self.jail_mode, JailMode::Immediate { .. }); + let choice_state = choice_states.get_or_create_state(choice.index, starts_jailed); // Store metadata when any choice becomes jailed (first time only) if !choice_state.is_jailed && self.should_start_jail(content) @@ -526,14 +543,24 @@ impl JailedStream { all_emissions.extend(emissions); } else { // Handle choices without content (e.g., final chunks with finish_reason) - // These should always pass through - let pass_through_choice = ChatChoiceStream { - index: choice.index, - delta: choice.delta.clone(), - finish_reason: choice.finish_reason, - logprobs: choice.logprobs.clone(), - }; - all_emissions.push(ChoiceEmission::PassThrough(pass_through_choice)); + // Only filter out if this choice was ever jailed and lacks role + // (to avoid aggregator issues with deltas missing role after unjail) + let choice_state = choice_states.get_or_create_state(choice.index, false); + let was_ever_jailed = !choice_state.accumulated_content.is_empty() || choice_state.is_jailed; + + let should_emit = choice.delta.role.is_some() + || choice.delta.tool_calls.is_some() + || !was_ever_jailed; // Always pass through if never jailed + + if should_emit { + let pass_through_choice = ChatChoiceStream { + 
index: choice.index, + delta: choice.delta.clone(), + finish_reason: choice.finish_reason, + logprobs: choice.logprobs.clone(), + }; + all_emissions.push(ChoiceEmission::PassThrough(pass_through_choice)); + } } } @@ -701,38 +728,69 @@ impl JailedStream { /// Check if accumulated content should end jail async fn should_end_jail(&self, accumulated_content: &str) -> (bool, usize) { - // Path 1: End sequence detected - let end_marker_info = if !self.jail_end_sequences.is_empty() { - self.jail_end_sequences.iter().find_map(|seq| { - accumulated_content - .find(seq) - .map(|pos| (pos + seq.len(), seq.clone())) - }) - } else { - None - }; + match &self.jail_mode { + JailMode::MarkerBased => { + // Path 1: End sequence detected + let end_marker_info = if !self.jail_end_sequences.is_empty() { + self.jail_end_sequences.iter().find_map(|seq| { + accumulated_content + .find(seq) + .map(|pos| (pos + seq.len(), seq.clone())) + }) + } else { + None + }; - // Path 2: Complete tool call(s) can be parsed (early exit) - let early_exit = self.should_exit_jail_early(accumulated_content).await; + // Path 2: Complete tool call(s) can be parsed (early exit) + let early_exit = self.should_exit_jail_early(accumulated_content).await; - if let Some((end_pos, _)) = end_marker_info { - (true, end_pos) - } else if early_exit { - // For early exit, find where the complete tool call ends - if let Some(parser) = &self.tool_call_parser { - if let Ok((_, _)) = - try_tool_call_parse_aggregate(accumulated_content, Some(parser)).await - { - let split_pos = find_tool_call_end_position(accumulated_content, Some(parser)); - (true, split_pos) + if let Some((end_pos, _)) = end_marker_info { + (true, end_pos) + } else if early_exit { + // For early exit, find where the complete tool call ends + if let Some(parser) = &self.tool_call_parser { + if let Ok((_, _)) = + try_tool_call_parse_aggregate(accumulated_content, Some(parser)).await + { + let split_pos = + find_tool_call_end_position(accumulated_content, 
Some(parser)); + (true, split_pos) + } else { + (false, accumulated_content.len()) + } + } else { + (false, accumulated_content.len()) + } } else { (false, accumulated_content.len()) } - } else { - (false, accumulated_content.len()) } - } else { - (false, accumulated_content.len()) + JailMode::Immediate { format } => { + // For tool_choice, check if we have valid complete JSON + match format { + ToolChoiceFormat::SingleObject { .. } => { + // Expect single object: {"location": "Paris", "unit": "celsius"} + if let Ok(value) = + serde_json::from_str::(accumulated_content) + && value.is_object() + { + return (true, accumulated_content.len()); + } + (false, accumulated_content.len()) + } + ToolChoiceFormat::ArrayOfTools => { + // Expect array: [{"name":"search","parameters":{...}}, ...] + if let Ok(value) = + serde_json::from_str::(accumulated_content) + && let Some(arr) = value.as_array() + && !arr.is_empty() + { + return (true, accumulated_content.len()); + } + (false, accumulated_content.len()) + } + } + } } } @@ -744,46 +802,136 @@ impl JailedStream { base_choice: &ChatChoiceStream, tool_call_offset: usize, ) -> ChatChoiceStream { - if let Ok((tool_calls, normal_text)) = - try_tool_call_parse_aggregate(accumulated_content, self.tool_call_parser.as_deref()) + match &self.jail_mode { + JailMode::MarkerBased => { + // Traditional marker-based tool call parsing + if let Ok((tool_calls, normal_text)) = try_tool_call_parse_aggregate( + accumulated_content, + self.tool_call_parser.as_deref(), + ) .await - && !tool_calls.is_empty() - { - // Convert to streaming format - let tool_call_chunks: Vec = tool_calls - .into_iter() - .enumerate() - .map(|(idx, tool_call)| ChatCompletionMessageToolCallChunk { - index: (tool_call_offset + idx) as u32, - id: Some(tool_call.id), - r#type: Some(tool_call.r#type), - function: Some(FunctionCallStream { - name: Some(tool_call.function.name), - arguments: Some(tool_call.function.arguments), - }), - }) - .collect(); - // Create choice with 
tool calls - let choice = create_choice_stream( - choice_index, - Some(Role::Assistant), - normal_text.as_deref().unwrap_or(""), - Some(tool_call_chunks), - None, - None, - ); - return choice; + && !tool_calls.is_empty() + { + // Convert to streaming format + let tool_call_chunks: Vec = tool_calls + .into_iter() + .enumerate() + .map(|(idx, tool_call)| ChatCompletionMessageToolCallChunk { + index: (tool_call_offset + idx) as u32, + id: Some(tool_call.id), + r#type: Some(tool_call.r#type), + function: Some(FunctionCallStream { + name: Some(tool_call.function.name), + arguments: Some(tool_call.function.arguments), + }), + }) + .collect(); + // Create choice with tool calls + let choice = create_choice_stream( + choice_index, + Some(Role::Assistant), + normal_text.as_deref().unwrap_or(""), + Some(tool_call_chunks), + None, + None, + ); + return choice; + } + + // No tool calls found or parsing failed, return content choice + create_choice_stream( + choice_index, + Some(Role::Assistant), + accumulated_content, + None, + base_choice.finish_reason, + base_choice.logprobs.clone(), + ) + } + JailMode::Immediate { format } => { + // tool_choice mode: parse JSON and convert to tool calls + match self.parse_tool_choice_json(accumulated_content, format) { + Ok(tool_call_chunks) if !tool_call_chunks.is_empty() => create_choice_stream( + choice_index, + Some(Role::Assistant), + "", + Some(tool_call_chunks), + base_choice.finish_reason, + base_choice.logprobs.clone(), + ), + Ok(_) | Err(_) => { + // Parsing failed, return as content + create_choice_stream( + choice_index, + Some(Role::Assistant), + accumulated_content, + None, + base_choice.finish_reason, + base_choice.logprobs.clone(), + ) + } + } + } } + } - // No tool calls found or parsing failed, return content choice - create_choice_stream( - choice_index, - Some(Role::Assistant), - accumulated_content, - None, - base_choice.finish_reason, - base_choice.logprobs.clone(), - ) + /// Helper to create a 
ChatCompletionMessageToolCallChunk + fn create_tool_call_chunk( + index: u32, + name: String, + arguments: String, + ) -> ChatCompletionMessageToolCallChunk { + ChatCompletionMessageToolCallChunk { + index, + id: Some(format!("call-{}", Uuid::new_v4())), + r#type: Some(dynamo_async_openai::types::ChatCompletionToolType::Function), + function: Some(FunctionCallStream { + name: Some(name), + arguments: Some(arguments), + }), + } + } + + /// Parse tool_choice JSON output into tool call chunks + fn parse_tool_choice_json( + &self, + json_content: &str, + format: &ToolChoiceFormat, + ) -> anyhow::Result> { + let parsed = serde_json::from_str::(json_content)?; + + match format { + ToolChoiceFormat::SingleObject { tool_name } => { + // For named tool choice: JSON is the parameters object + if parsed.is_object() { + Ok(vec![Self::create_tool_call_chunk( + 0, + tool_name.clone(), + json_content.to_string(), + )]) + } else { + Ok(vec![]) + } + } + ToolChoiceFormat::ArrayOfTools => { + // For required tool choice: JSON is array of {name, parameters} + if let Some(array) = parsed.as_array() { + let chunks: Vec = array + .iter() + .enumerate() + .filter_map(|(idx, entry)| { + let name = entry.get("name")?.as_str()?.to_string(); + let parameters = entry.get("parameters")?; + let args = serde_json::to_string(parameters).ok()?; + Some(Self::create_tool_call_chunk(idx as u32, name, args)) + }) + .collect(); + Ok(chunks) + } else { + Ok(vec![]) + } + } + } } /// Check if accumulated content contains complete tool calls that can be parsed @@ -804,8 +952,9 @@ impl JailedStream { /// Post-processor that sets finish_reason to ToolCalls when tool calls were emitted /// This should be called after apply() to fix the finish_reason for tool call chunks - pub fn fix_finish_reason( + fn fix_finish_reason( input_stream: S, + jail_mode: JailMode, ) -> impl Stream> + Send where S: Stream> + Send + 'static, @@ -824,13 +973,39 @@ impl JailedStream { } } - // If this chunk has finish_reason and the 
choice had tool calls, override to ToolCalls + // Fix finish_reason based on jail mode and whether tool calls were emitted if let Some(ref mut data) = response.data { for choice in &mut data.choices { - if choice.finish_reason.is_some() && choice.finish_reason == Some(FinishReason::Stop) - && has_tool_calls_per_choice.get(&choice.index).copied().unwrap_or(false) - { - choice.finish_reason = Some(FinishReason::ToolCalls); + if let Some(finish) = choice.finish_reason { + // Only modify Stop finish reason, preserve Length/ContentFilter + if finish == FinishReason::Stop { + let has_tool_calls = has_tool_calls_per_choice.get(&choice.index).copied().unwrap_or(false); + + match &jail_mode { + JailMode::MarkerBased => { + // Traditional: if tool calls emitted, change to ToolCalls + if has_tool_calls { + choice.finish_reason = Some(FinishReason::ToolCalls); + } + } + JailMode::Immediate { format } => { + // tool_choice mode: apply specific finish_reason logic + match format { + ToolChoiceFormat::SingleObject { .. 
} => { + // Named tool choice: keep Stop + // (already Stop, no change needed) + } + ToolChoiceFormat::ArrayOfTools => { + // Required tool choice: change to ToolCalls + if has_tool_calls { + choice.finish_reason = Some(FinishReason::ToolCalls); + } + } + } + } + } + } + // Length and ContentFilter are preserved as-is } } } @@ -847,6 +1022,7 @@ pub struct JailedStreamBuilder { jail_end_sequences: Vec, tool_call_parser: Option, emission_mode: EmissionMode, + jail_mode: JailMode, } impl JailedStreamBuilder { @@ -857,6 +1033,7 @@ impl JailedStreamBuilder { jail_end_sequences: Vec::new(), tool_call_parser: None, emission_mode: EmissionMode::default(), + jail_mode: JailMode::MarkerBased, } } @@ -916,6 +1093,22 @@ impl JailedStreamBuilder { self } + /// Enable immediate jail mode for tool_choice=named + pub fn tool_choice_named(mut self, tool_name: String) -> Self { + self.jail_mode = JailMode::Immediate { + format: ToolChoiceFormat::SingleObject { tool_name }, + }; + self + } + + /// Enable immediate jail mode for tool_choice=required + pub fn tool_choice_required(mut self) -> Self { + self.jail_mode = JailMode::Immediate { + format: ToolChoiceFormat::ArrayOfTools, + }; + self + } + /// Build the configured JailedStream pub fn build(mut self) -> JailedStream { // Auto-populate jail sequences from parser config if not manually configured @@ -994,6 +1187,7 @@ impl JailedStreamBuilder { tool_call_parser: self.tool_call_parser, emission_mode: self.emission_mode, marker_matcher, + jail_mode: self.jail_mode, } } } diff --git a/lib/llm/src/protocols/openai/common_ext.rs b/lib/llm/src/protocols/openai/common_ext.rs index a77f765ae6b2..51d9ea0a997b 100644 --- a/lib/llm/src/protocols/openai/common_ext.rs +++ b/lib/llm/src/protocols/openai/common_ext.rs @@ -94,7 +94,7 @@ pub trait CommonExtProvider { fn common_ext(&self) -> Option<&CommonExt>; /// Guided Decoding Options - fn get_guided_json(&self) -> Option<&serde_json::Value>; + fn get_guided_json(&self) -> Option; fn 
get_guided_regex(&self) -> Option; fn get_guided_grammar(&self) -> Option; fn get_guided_choice(&self) -> Option>; diff --git a/lib/llm/src/protocols/openai/completions.rs b/lib/llm/src/protocols/openai/completions.rs index 902c41fc25cf..056c2a3a2d82 100644 --- a/lib/llm/src/protocols/openai/completions.rs +++ b/lib/llm/src/protocols/openai/completions.rs @@ -183,8 +183,8 @@ impl CommonExtProvider for NvCreateCompletionRequest { } /// Guided Decoding Options - fn get_guided_json(&self) -> Option<&serde_json::Value> { - self.common.guided_json.as_ref() + fn get_guided_json(&self) -> Option { + self.common.guided_json.clone() } fn get_guided_regex(&self) -> Option { diff --git a/lib/llm/src/protocols/openai/tools.rs b/lib/llm/src/protocols/openai/tools.rs new file mode 100644 index 000000000000..457f5b37a99f --- /dev/null +++ b/lib/llm/src/protocols/openai/tools.rs @@ -0,0 +1,404 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::collections::BTreeMap; + +use dynamo_async_openai::types::{ + ChatCompletionTool, ChatCompletionToolChoiceOption, FunctionObject, +}; +use serde_json::{Value, json}; +use thiserror::Error; + +/// Errors that can occur when deriving JSON schemas for tool_choice requests. +#[derive(Debug, Error, PartialEq, Eq)] +pub enum ToolChoiceError { + #[error("tool_choice requires a matching `tools` array")] + MissingTools, + #[error("tool `{0}` was not provided in `tools`")] + ToolNotFound(String), + #[error("$defs for tool `{0}` must be an object")] + InvalidDefinitionMap(String), + #[error("duplicate $defs entry `{0}` has conflicting schemas")] + ConflictingDefinition(String), + #[error("tool_choice `required` needs at least one tool definition")] + EmptyTools, +} + +/// Builds the JSON schema enforced by Guided Decoding for the given tool_choice/tools pair. 
+pub fn get_json_schema_from_tools( + tool_choice: Option<&ChatCompletionToolChoiceOption>, + tools: Option<&[ChatCompletionTool]>, +) -> Result, ToolChoiceError> { + let Some(choice) = tool_choice else { + return Ok(None); + }; + + match choice { + ChatCompletionToolChoiceOption::None | ChatCompletionToolChoiceOption::Auto => Ok(None), + ChatCompletionToolChoiceOption::Named(named) => { + let tools = tools.ok_or(ToolChoiceError::MissingTools)?; + let tool = find_tool(tools, &named.function.name) + .ok_or_else(|| ToolChoiceError::ToolNotFound(named.function.name.clone()))?; + Ok(Some(clone_parameters(&tool.function))) + } + ChatCompletionToolChoiceOption::Required => { + let tools = tools.ok_or(ToolChoiceError::MissingTools)?; + if tools.is_empty() { + return Err(ToolChoiceError::EmptyTools); + } + build_required_schema(tools).map(Some) + } + } +} + +fn find_tool<'a>(tools: &'a [ChatCompletionTool], name: &str) -> Option<&'a ChatCompletionTool> { + tools.iter().find(|tool| tool.function.name == name) +} + +fn clone_parameters(function: &FunctionObject) -> Value { + function + .parameters + .clone() + .unwrap_or_else(|| json!({"type": "object", "properties": {}})) +} + +/// Builds a JSON Schema for `tool_choice=required` that enforces an array of tool calls. 
+/// +/// # Schema Structure +/// +/// The generated schema looks like: +/// ```json +/// { +/// "type": "array", +/// "minItems": 1, +/// "items": { +/// "type": "object", +/// "anyOf": [ +/// { +/// "properties": { +/// "name": {"type": "string", "enum": ["tool1"]}, +/// "parameters": { /* tool1's parameter schema */ } +/// }, +/// "required": ["name", "parameters"] +/// }, +/// { +/// "properties": { +/// "name": {"type": "string", "enum": ["tool2"]}, +/// "parameters": { /* tool2's parameter schema */ } +/// }, +/// "required": ["name", "parameters"] +/// } +/// ] +/// }, +/// "$defs": { /* shared type definitions from all tools */ } +/// } +/// ``` +/// +/// # $defs Handling +/// +/// `$defs` contains shared JSON Schema definitions that can be referenced via `$ref`. +/// For example, if two tools reference a common type: +/// ```json +/// { +/// "$defs": { +/// "Location": { +/// "type": "object", +/// "properties": { +/// "city": {"type": "string"}, +/// "country": {"type": "string"} +/// } +/// } +/// } +/// } +/// ``` +/// +/// We extract `$defs` from each tool's schema and merge them into a global `$defs` map +/// at the root level. If multiple tools define the same type, we verify they match to +/// avoid conflicts. 
+fn build_required_schema(tools: &[ChatCompletionTool]) -> Result { + // Accumulator for all shared type definitions ($defs) across tools + let mut defs: BTreeMap = BTreeMap::new(); + let mut any_of = Vec::with_capacity(tools.len()); + + for tool in tools { + // Extract parameter schema and its $defs (if any) + let ParamsAndDefs { + schema, + defs: new_defs, + } = split_defs(&tool.function)?; + merge_defs(&mut defs, new_defs)?; + any_of.push(json!({ + "properties": { + "name": { + "type": "string", + "enum": [tool.function.name], + }, + "parameters": schema, + }, + "required": ["name", "parameters"], + })); + } + + // Build the top-level array schema with anyOf constraints + let mut result = json!({ + "type": "array", + "minItems": 1, + "items": { + "type": "object", + "anyOf": any_of, + }, + }); + + // Attach the merged $defs at the root level if any were collected + if !defs.is_empty() + && let Value::Object(map) = &mut result + { + map.insert( + "$defs".to_string(), + Value::Object(defs.into_iter().collect()), + ); + } + + Ok(result) +} + +/// Holds a tool's parameter schema and its extracted $defs (if any). +/// +/// When a tool's parameters reference shared types via `$ref`, those types +/// are defined in a `$defs` section within the schema. We extract them separately +/// to merge into a global definitions map. +struct ParamsAndDefs { + /// The parameter schema with `$defs` removed (if it had one) + schema: Value, + /// Extracted `$defs` map, or None if the schema had no definitions + defs: Option>, +} + +/// Extracts `$defs` from a function's parameter schema, returning both the +/// cleaned schema and the definitions separately. 
+/// +/// # Example +/// +/// Input schema: +/// ```json +/// { +/// "type": "object", +/// "properties": { +/// "location": {"$ref": "#/$defs/Location"} +/// }, +/// "$defs": { +/// "Location": { +/// "type": "object", +/// "properties": {"city": {"type": "string"}} +/// } +/// } +/// } +/// ``` +/// +/// Returns: +/// - schema: same as input but with `$defs` removed +/// - defs: `Some({"Location": {...}})` +fn split_defs(function: &FunctionObject) -> Result { + let mut schema = clone_parameters(function); + let defs = match &mut schema { + Value::Object(obj) => { + if let Some(value) = obj.remove("$defs") { + Some(convert_defs(function, value)?) + } else { + None + } + } + _ => None, + }; + + Ok(ParamsAndDefs { schema, defs }) +} + +fn convert_defs( + function: &FunctionObject, + defs_value: Value, +) -> Result, ToolChoiceError> { + match defs_value { + Value::Object(map) => Ok(map.into_iter().collect()), + _ => Err(ToolChoiceError::InvalidDefinitionMap(function.name.clone())), + } +} + +/// Merges definitions from one tool into the global `$defs` accumulator. +/// +/// # Conflict Detection +/// +/// If two tools define the same type name but with different schemas, we return +/// an error. This ensures consistency across tool definitions. +/// +/// # Example +/// +/// If `target` contains: +/// ```json +/// {"Location": {"type": "object", "properties": {"city": {"type": "string"}}}} +/// ``` +/// +/// And we try to merge: +/// ```json +/// {"Location": {"type": "object", "properties": {"city": {"type": "number"}}}} +/// ``` +/// +/// This will return `ToolChoiceError::ConflictingDefinition("Location")`. 
+fn merge_defs( + target: &mut BTreeMap, + defs: Option>, +) -> Result<(), ToolChoiceError> { + let Some(defs) = defs else { + return Ok(()); + }; + + for (name, schema) in defs { + if let Some(existing) = target.get(&name) { + if existing != &schema { + return Err(ToolChoiceError::ConflictingDefinition(name)); + } + } else { + target.insert(name, schema); + } + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use dynamo_async_openai::types::{ChatCompletionToolChoiceOption, ChatCompletionToolType}; + + fn sample_tools() -> Vec { + vec![ + ChatCompletionTool { + r#type: ChatCompletionToolType::Function, + function: FunctionObject { + name: "add_numbers".to_string(), + description: Some("Add two integers".to_string()), + parameters: Some(json!({ + "type": "object", + "properties": { + "a": {"type": "integer"}, + "b": {"type": "integer"}, + }, + "required": ["a", "b"], + })), + strict: None, + }, + }, + ChatCompletionTool { + r#type: ChatCompletionToolType::Function, + function: FunctionObject { + name: "get_weather".to_string(), + description: Some("Get weather".to_string()), + parameters: Some(json!({ + "type": "object", + "properties": { + "location": {"type": "string"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location", "unit"], + })), + strict: None, + }, + }, + ] + } + + #[test] + fn named_choice_returns_parameters() { + let tools = sample_tools(); + let tool_choice = ChatCompletionToolChoiceOption::Named( + dynamo_async_openai::types::ChatCompletionNamedToolChoice { + r#type: ChatCompletionToolType::Function, + function: dynamo_async_openai::types::FunctionName { + name: "get_weather".to_string(), + }, + }, + ); + let schema = get_json_schema_from_tools(Some(&tool_choice), Some(&tools)).expect("schema"); + + assert_eq!( + schema.unwrap(), + json!({ + "type": "object", + "properties": { + "location": {"type": "string"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": 
["location", "unit"], + }) + ); + } + + #[test] + fn required_choice_builds_any_of_schema() { + let tools = sample_tools(); + let schema = get_json_schema_from_tools( + Some(&ChatCompletionToolChoiceOption::Required), + Some(&tools), + ) + .expect("schema"); + + let schema = schema.expect("required schema"); + assert_eq!(schema["type"], "array"); + assert_eq!(schema["minItems"], 1); + assert!(schema["items"]["anyOf"].is_array()); + + let any_of = schema["items"]["anyOf"].as_array().unwrap(); + assert_eq!(any_of.len(), 2); + assert_eq!( + any_of[0]["properties"]["name"], + json!({"type": "string", "enum": ["add_numbers"]}) + ); + } + + #[test] + fn missing_tool_errors() { + let tools = sample_tools(); + let tool_choice = ChatCompletionToolChoiceOption::Named( + dynamo_async_openai::types::ChatCompletionNamedToolChoice { + r#type: ChatCompletionToolType::Function, + function: dynamo_async_openai::types::FunctionName { + name: "unknown".to_string(), + }, + }, + ); + let err = get_json_schema_from_tools(Some(&tool_choice), Some(&tools)).unwrap_err(); + assert_eq!(err, ToolChoiceError::ToolNotFound("unknown".to_string())); + } + + #[test] + fn conflicting_defs_errors() { + let tool = ChatCompletionTool { + r#type: ChatCompletionToolType::Function, + function: FunctionObject { + name: "foo".to_string(), + description: None, + parameters: Some(json!({ + "type": "object", + "$defs": { + "shared": {"type": "string"} + } + })), + strict: None, + }, + }; + + let mut tool_with_conflict = tool.clone(); + tool_with_conflict.function.parameters = Some(json!({ + "type": "object", + "$defs": { + "shared": {"type": "number"} + } + })); + + let tools = vec![tool, tool_with_conflict]; + let err = build_required_schema(&tools).unwrap_err(); + assert_eq!( + err, + ToolChoiceError::ConflictingDefinition("shared".to_string()) + ); + } +} diff --git a/lib/llm/tests/test_common_ext.rs b/lib/llm/tests/test_common_ext.rs index 933e486a86cb..149a20549182 100644 --- 
a/lib/llm/tests/test_common_ext.rs +++ b/lib/llm/tests/test_common_ext.rs @@ -92,7 +92,7 @@ fn test_chat_completions_guided_decoding_from_common() { ); assert_eq!( request.get_guided_json(), - Some(&serde_json::json!({"key": "value"})) + Some(serde_json::json!({"key": "value"})) ); // Test guided_regex can be specified at root level diff --git a/lib/llm/tests/test_reasoning_parser.rs b/lib/llm/tests/test_reasoning_parser.rs index 190fd9badbba..19a0ec328ac2 100644 --- a/lib/llm/tests/test_reasoning_parser.rs +++ b/lib/llm/tests/test_reasoning_parser.rs @@ -484,7 +484,8 @@ mod tests { // Step 2: Apply tool calling jail transformation let tool_parsed_stream = OpenAIPreprocessor::apply_tool_calling_jail( - "nemotron_deci".to_string(), + Some("nemotron_deci".to_string()), + None, // No tool_choice in this test reasoning_parsed_stream, ); @@ -596,7 +597,8 @@ mod tests { let reasoning_parsed_stream = stream::iter(debug_chunks); let tool_parsed_stream = OpenAIPreprocessor::apply_tool_calling_jail( - "harmony".to_string(), + Some("harmony".to_string()), + None, // No tool_choice in this test reasoning_parsed_stream, ); diff --git a/lib/llm/tests/test_streaming_tool_parsers.rs b/lib/llm/tests/test_streaming_tool_parsers.rs index 03392d2fcb73..c2214e707b81 100644 --- a/lib/llm/tests/test_streaming_tool_parsers.rs +++ b/lib/llm/tests/test_streaming_tool_parsers.rs @@ -158,7 +158,8 @@ async fn parse_response_stream( > = if tool_parse_enable { if let Some(tool_parser) = tool_parser_str { Box::pin(OpenAIPreprocessor::apply_tool_calling_jail( - tool_parser, + Some(tool_parser), + None, // No tool_choice in this test stream, )) } else { diff --git a/lib/llm/tests/tool_choice.rs b/lib/llm/tests/tool_choice.rs new file mode 100644 index 000000000000..c970108d9b59 --- /dev/null +++ b/lib/llm/tests/tool_choice.rs @@ -0,0 +1,436 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0
+
+use dynamo_async_openai::types::{
+    ChatCompletionNamedToolChoice, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage,
+    ChatCompletionRequestUserMessageContent, ChatCompletionToolChoiceOption,
+    ChatCompletionToolType, CreateChatCompletionRequest, FunctionName,
+};
+use dynamo_llm::protocols::common;
+use dynamo_llm::protocols::common::llm_backend::BackendOutput;
+use dynamo_llm::protocols::openai::DeltaGeneratorExt;
+use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionRequest;
+
+/// Build a non-streaming request for "test-model" with a single user message ("test").
+fn create_test_request() -> NvCreateChatCompletionRequest {
+    let messages = vec![ChatCompletionRequestMessage::User(
+        ChatCompletionRequestUserMessage {
+            content: ChatCompletionRequestUserMessageContent::Text("test".to_string()),
+            name: None,
+        },
+    )];
+
+    NvCreateChatCompletionRequest {
+        inner: CreateChatCompletionRequest {
+            model: "test-model".to_string(),
+            messages,
+            stream: Some(false),
+            stream_options: None,
+            ..Default::default()
+        },
+        common: Default::default(),
+        nvext: None,
+        chat_template_args: None,
+        unsupported_fields: Default::default(),
+    }
+}
+
+/// Run one response through a `JailedStream` and return the first emitted payload.
+///
+/// Named and required tool choices select the matching builder option; any other value
+/// keeps the builder defaults. Panics if no item is emitted or the first item has no data.
+async fn apply_jail_transformation(
+    raw_response: dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse,
+    tool_choice: Option<ChatCompletionToolChoiceOption>,
+) -> dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse {
+    use dynamo_llm::protocols::openai::chat_completions::jail::JailedStream;
+    use dynamo_runtime::protocols::annotated::Annotated;
+    use futures::StreamExt;
+    use futures::stream;
+
+    let input_stream = stream::iter(vec![Annotated {
+        data: Some(raw_response),
+        id: None,
+        event: None,
+        comment: None,
+    }]);
+
+    let mut builder = JailedStream::builder();
+
+    match tool_choice {
+        Some(ChatCompletionToolChoiceOption::Named(ref named)) => {
+            builder = builder.tool_choice_named(named.function.name.clone());
+        }
+        Some(ChatCompletionToolChoiceOption::Required) => {
+            builder = builder.tool_choice_required();
+        }
+        _ => {}
+    }
+
+    let jail = builder.build();
+    let output_stream = jail.apply_with_finish_reason(input_stream);
+
+    tokio::pin!(output_stream);
+    output_stream.next().await.unwrap().data.unwrap()
+}
+
+/// Run a sequence of responses through a `JailedStream` and collect every emitted payload.
+///
+/// The jail is configured from `tool_choice` exactly as in `apply_jail_transformation`.
+/// Emitted items that carry no data are dropped from the result.
+async fn apply_jail_transformation_streaming(
+    raw_responses: Vec<
+        dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse,
+    >,
+    tool_choice: Option<ChatCompletionToolChoiceOption>,
+) -> Vec<dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse> {
+    use dynamo_llm::protocols::openai::chat_completions::jail::JailedStream;
+    use dynamo_runtime::protocols::annotated::Annotated;
+    use futures::StreamExt;
+    use futures::stream;
+
+    let input_stream = stream::iter(raw_responses.into_iter().map(|r| Annotated {
+        data: Some(r),
+        id: None,
+        event: None,
+        comment: None,
+    }));
+
+    let mut builder = JailedStream::builder();
+
+    match tool_choice {
+        Some(ChatCompletionToolChoiceOption::Named(ref named)) => {
+            builder = builder.tool_choice_named(named.function.name.clone());
+        }
+        Some(ChatCompletionToolChoiceOption::Required) => {
+            builder = builder.tool_choice_required();
+        }
+        _ => {}
+    }
+
+    let jail = builder.build();
+    let output_stream = jail.apply_with_finish_reason(input_stream);
+
+    tokio::pin!(output_stream);
+    output_stream
+        .filter_map(|ann| async move { ann.data })
+        .collect()
+        .await
+}
+
+/// Build a single-choice `BackendOutput` carrying `text` and finish reason `Stop`.
+fn build_backend_output(text: &str) -> BackendOutput {
+    BackendOutput {
+        token_ids: vec![],
+        tokens: vec![],
+        text: Some(text.to_string()),
+        cum_log_probs: None,
+        log_probs: None,
+        top_logprobs: None,
+        finish_reason: Some(common::FinishReason::Stop),
+        index: Some(0),
+        completion_usage: None,
+        disaggregated_params: None,
+    }
+}
+
+// Named tool choice: the whole output is the argument object of the named function.
+#[tokio::test]
+async fn test_named_tool_choice_parses_json() {
+    let mut request = create_test_request();
+    let tool_choice = Some(ChatCompletionToolChoiceOption::Named(
+        ChatCompletionNamedToolChoice {
+            r#type: ChatCompletionToolType::Function,
+            function: FunctionName {
+                name: "get_weather".to_string(),
+            },
+        },
+    ));
+    request.inner.tool_choice = tool_choice.clone();
+
+    let mut generator = request.response_generator("req-1".to_string());
+    let backend_output = build_backend_output(r#"{"location":"Paris"}"#);
+    let raw_response = generator
+        .choice_from_postprocessor(backend_output)
+        .expect("choice generation");
+
+    let response = apply_jail_transformation(raw_response, tool_choice).await;
+    let choice = &response.choices[0];
+
+    assert_eq!(
+        choice.finish_reason,
+        Some(dynamo_async_openai::types::FinishReason::Stop)
+    );
+    let delta = &choice.delta;
+    assert!(delta.content.is_none() || delta.content.as_deref() == Some(""));
+    let tool_calls = delta.tool_calls.as_ref().unwrap();
+
+    assert_eq!(tool_calls.len(), 1);
+
+    let tool_call = &tool_calls[0];
+    assert_eq!(tool_call.index, 0);
+    assert!(tool_call.id.as_ref().unwrap().starts_with("call-"));
+    assert_eq!(tool_call.r#type, Some(ChatCompletionToolType::Function));
+    assert_eq!(
+        tool_call.function.as_ref().unwrap().name.as_deref(),
+        Some("get_weather")
+    );
+    assert_eq!(
+        tool_call.function.as_ref().unwrap().arguments.as_deref(),
+        Some(r#"{"location":"Paris"}"#)
+    );
+}
+
+// Required tool choice: the output is a JSON array of `{name, parameters}` objects,
+// each of which becomes one tool call.
+#[tokio::test]
+async fn test_required_tool_choice_parses_json_array() {
+    let mut request = create_test_request();
+    let tool_choice = Some(ChatCompletionToolChoiceOption::Required);
+    request.inner.tool_choice = tool_choice.clone();
+
+    let mut generator = request.response_generator("req-2".to_string());
+    let backend_output = build_backend_output(
+        r#"[{"name":"search","parameters":{"query":"rust"}},
+            {"name":"summarize","parameters":{"topic":"memory"}}]"#,
+    );
+    let raw_response = generator
+        .choice_from_postprocessor(backend_output)
+        .expect("choice generation");
+
+    let response = apply_jail_transformation(raw_response, tool_choice).await;
+    let choice = &response.choices[0];
+
+    assert_eq!(
+        choice.finish_reason,
+        Some(dynamo_async_openai::types::FinishReason::ToolCalls)
+    );
+    let delta = &choice.delta;
+    assert!(delta.content.is_none() || delta.content.as_deref() == Some(""));
+    let tool_calls = delta.tool_calls.as_ref().unwrap();
+
+    assert_eq!(tool_calls.len(), 2);
+
+    assert_eq!(tool_calls[0].index, 0);
+    assert!(tool_calls[0].id.as_ref().unwrap().starts_with("call-"));
+    assert_eq!(tool_calls[0].r#type, Some(ChatCompletionToolType::Function));
+    assert_eq!(
+        tool_calls[0].function.as_ref().unwrap().name.as_deref(),
+        Some("search")
+    );
+    assert_eq!(
+        tool_calls[0]
+            .function
+            .as_ref()
+            .unwrap()
+            .arguments
+            .as_deref(),
+        Some(r#"{"query":"rust"}"#)
+    );
+
+    assert_eq!(tool_calls[1].index, 1);
+    assert!(tool_calls[1].id.as_ref().unwrap().starts_with("call-"));
+    assert_eq!(tool_calls[1].r#type, Some(ChatCompletionToolType::Function));
+    assert_eq!(
+        tool_calls[1].function.as_ref().unwrap().name.as_deref(),
+        Some("summarize")
+    );
+    assert_eq!(
+        tool_calls[1]
+            .function
+            .as_ref()
+            .unwrap()
+            .arguments
+            .as_deref(),
+        Some(r#"{"topic":"memory"}"#)
+    );
+}
+
+#[tokio::test]
+async fn test_tool_choice_parse_failure_returns_as_content() {
+    let mut request = create_test_request();
+    let tool_choice = Some(ChatCompletionToolChoiceOption::Required);
+    request.inner.tool_choice = tool_choice.clone();
+
+    let mut generator = request.response_generator("req-3".to_string());
+    let backend_output = build_backend_output("not-json");
+    let raw_response = generator
+        .choice_from_postprocessor(backend_output)
+        .expect("choice generation");
+
+    let response = apply_jail_transformation(raw_response, tool_choice).await;
+    let delta = &response.choices[0].delta;
+
+    // Jail stream behavior: if parsing fails, return accumulated content as-is
+    // This matches marker-based FC behavior
+    assert_eq!(delta.content.as_deref(), Some("not-json"));
+    assert!(delta.tool_calls.is_none());
+}
+
+// The argument object arrives split over three chunks; only the last chunk carries a
+// finish reason.
+#[tokio::test]
+async fn test_streaming_named_tool_buffers_until_finish() {
+    let mut request = create_test_request();
+    let tool_choice = Some(ChatCompletionToolChoiceOption::Named(
+        ChatCompletionNamedToolChoice {
+            r#type: ChatCompletionToolType::Function,
+            function: FunctionName {
+                name: "get_weather".to_string(),
+            },
+        },
+    ));
+    request.inner.tool_choice = tool_choice.clone();
+
+    let mut generator = request.response_generator("req-stream-1".to_string());
+
+    let chunks = [r#"{"location":""#, r#"Paris","unit":""#, r#"celsius"}"#];
+
+    let mut raw_responses = Vec::new();
+    for (i, chunk) in chunks.iter().enumerate() {
+        let backend_output = BackendOutput {
+            token_ids: vec![],
+            tokens: vec![],
+            text: Some(chunk.to_string()),
+            cum_log_probs: None,
+            log_probs: None,
+            top_logprobs: None,
+            finish_reason: if i == chunks.len() - 1 {
+                Some(common::FinishReason::Stop)
+            } else {
+                None
+            },
+            index: Some(0),
+            completion_usage: None,
+            disaggregated_params: None,
+        };
+
+        let response = generator
+            .choice_from_postprocessor(backend_output)
+            .expect("streaming chunk");
+        raw_responses.push(response);
+    }
+
+    let all_responses = apply_jail_transformation_streaming(raw_responses, tool_choice).await;
+
+    // Jail stream buffers content until valid JSON, then emits once
+    assert_eq!(all_responses.len(), 1);
+
+    let response = &all_responses[0];
+    assert_eq!(
+        response.choices[0].finish_reason,
+        Some(dynamo_async_openai::types::FinishReason::Stop)
+    );
+
+    let tool_calls = response.choices[0].delta.tool_calls.as_ref().unwrap();
+    assert_eq!(tool_calls.len(), 1);
+    assert_eq!(
+        tool_calls[0].function.as_ref().unwrap().name.as_deref(),
+        Some("get_weather")
+    );
+    assert_eq!(
+        tool_calls[0]
+            .function
+            .as_ref()
+            .unwrap()
+            .arguments
+            .as_deref(),
+        Some(r#"{"location":"Paris","unit":"celsius"}"#)
+    );
+}
+
+// The JSON array of two calls arrives split over two chunks; only the last chunk
+// carries a finish reason.
+#[tokio::test]
+async fn test_streaming_required_tool_parallel() {
+    let mut request = create_test_request();
+    let tool_choice = Some(ChatCompletionToolChoiceOption::Required);
+    request.inner.tool_choice = tool_choice.clone();
+
+    let mut generator = request.response_generator("req-stream-2".to_string());
+
+    let chunks = [
+        r#"[{"name":"search","parameters":{"query":"rust"}},"#,
+        r#"{"name":"summarize","parameters":{"topic":"memory"}}]"#,
+    ];
+
+    let mut raw_responses = Vec::new();
+    for (i, chunk) in chunks.iter().enumerate() {
+        let backend_output = BackendOutput {
+            token_ids: vec![],
+            tokens: vec![],
+            text: Some(chunk.to_string()),
+            cum_log_probs: None,
+            log_probs: None,
+            top_logprobs: None,
+            finish_reason: if i == chunks.len() - 1 {
+                Some(common::FinishReason::Stop)
+            } else {
+                None
+            },
+            index: Some(0),
+            completion_usage: None,
+            disaggregated_params: None,
+        };
+
+        let response = generator
+            .choice_from_postprocessor(backend_output)
+            .expect("streaming chunk");
+        raw_responses.push(response);
+    }
+
+    let all_responses = apply_jail_transformation_streaming(raw_responses, tool_choice).await;
+
+    // Jail stream buffers until complete JSON array
+    assert_eq!(all_responses.len(), 1);
+
+    let response = &all_responses[0];
+    assert_eq!(
+        response.choices[0].finish_reason,
+        Some(dynamo_async_openai::types::FinishReason::ToolCalls)
+    );
+
+    let tool_calls = response.choices[0].delta.tool_calls.as_ref().unwrap();
+    assert_eq!(tool_calls.len(), 2);
+
+    assert_eq!(
+        tool_calls[0].function.as_ref().unwrap().name.as_deref(),
+        Some("search")
+    );
+    assert_eq!(
+        tool_calls[0]
+            .function
+            .as_ref()
+            .unwrap()
+            .arguments
+            .as_deref(),
+        Some(r#"{"query":"rust"}"#)
+    );
+
+    assert_eq!(
+        tool_calls[1].function.as_ref().unwrap().name.as_deref(),
+        Some("summarize")
+    );
+    assert_eq!(
+        tool_calls[1]
+            .function
+            .as_ref()
+            .unwrap()
+            .arguments
+            .as_deref(),
+        Some(r#"{"topic":"memory"}"#)
+    );
+}
+
+// Without a tool_choice the response generator passes the text through as content.
+// The jail is not involved in this test.
+#[test]
+fn test_no_tool_choice_outputs_normal_text() {
+    let request = create_test_request();
+
+    let mut generator = request.response_generator("req-stream-4".to_string());
+
+    let backend_output = BackendOutput {
+        token_ids: vec![],
+        tokens: vec![],
+        text: Some("Hello world".to_string()),
+        cum_log_probs: None,
+        log_probs: None,
+        top_logprobs: None,
+        finish_reason: None,
+        index: Some(0),
+        completion_usage: None,
+        disaggregated_params: None,
+    };
+
+    let response = generator
+        .choice_from_postprocessor(backend_output)
+        .expect("normal text");
+
+    assert_eq!(
+        response.choices[0].delta.content.as_deref(),
+        Some("Hello world")
+    );
+    assert!(response.choices[0].delta.tool_calls.is_none());
+}
diff --git a/lib/llm/tests/tool_choice_finish_reasons.rs b/lib/llm/tests/tool_choice_finish_reasons.rs
new file mode 100644
index 000000000000..07f28d59626a
--- /dev/null
+++ b/lib/llm/tests/tool_choice_finish_reasons.rs
@@ -0,0 +1,250 @@
+// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+//! Tests for tool_choice finish_reason handling.
+
+use dynamo_async_openai::types::{
+    ChatCompletionNamedToolChoice, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage,
+    ChatCompletionRequestUserMessageContent, ChatCompletionToolChoiceOption,
+    ChatCompletionToolType, CreateChatCompletionRequest, FunctionName,
+};
+use dynamo_llm::protocols::common;
+use dynamo_llm::protocols::common::llm_backend::BackendOutput;
+use dynamo_llm::protocols::openai::DeltaGeneratorExt;
+use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionRequest;
+
+/// Build a non-streaming request for "test-model" with a single user message ("test").
+fn create_test_request() -> NvCreateChatCompletionRequest {
+    let messages = vec![ChatCompletionRequestMessage::User(
+        ChatCompletionRequestUserMessage {
+            content: ChatCompletionRequestUserMessageContent::Text("test".to_string()),
+            name: None,
+        },
+    )];
+
+    NvCreateChatCompletionRequest {
+        inner: CreateChatCompletionRequest {
+            model: "test-model".to_string(),
+            messages,
+            stream: Some(false),
+            stream_options: None,
+            ..Default::default()
+        },
+        common: Default::default(),
+        nvext: None,
+        chat_template_args: None,
+        unsupported_fields: Default::default(),
+    }
+}
+
+/// Build a single-choice `BackendOutput` carrying `text` and the given finish reason.
+fn build_backend_output_with_finish(text: &str, finish: common::FinishReason) -> BackendOutput {
+    BackendOutput {
+        token_ids: vec![],
+        tokens: vec![],
+        text: Some(text.to_string()),
+        cum_log_probs: None,
+        log_probs: None,
+        top_logprobs: None,
+        finish_reason: Some(finish),
+        index: Some(0),
+        completion_usage: None,
+        disaggregated_params: None,
+    }
+}
+
+/// Run one response through a `JailedStream` and return the first emitted payload.
+///
+/// Named and required tool choices select the matching builder option; any other value
+/// keeps the builder defaults. Panics if no item is emitted or the first item has no data.
+async fn apply_jail_transformation(
+    raw_response: dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse,
+    tool_choice: Option<ChatCompletionToolChoiceOption>,
+) -> dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse {
+    use dynamo_llm::protocols::openai::chat_completions::jail::JailedStream;
+    use dynamo_runtime::protocols::annotated::Annotated;
+    use futures::StreamExt;
+    use futures::stream;
+
+    let input_stream = stream::iter(vec![Annotated {
+        data: Some(raw_response),
+        id: None,
+        event: None,
+        comment: None,
+    }]);
+
+    let mut builder = JailedStream::builder();
+
+    match tool_choice {
+        Some(ChatCompletionToolChoiceOption::Named(ref named)) => {
+            builder = builder.tool_choice_named(named.function.name.clone());
+        }
+        Some(ChatCompletionToolChoiceOption::Required) => {
+            builder = builder.tool_choice_required();
+        }
+        _ => {}
+    }
+
+    let jail = builder.build();
+    let output_stream = jail.apply_with_finish_reason(input_stream);
+
+    tokio::pin!(output_stream);
+    output_stream.next().await.unwrap().data.unwrap()
+}
+
+#[tokio::test]
+async fn test_named_tool_choice_preserves_length_finish_reason() {
+    let mut request = create_test_request();
+    let tool_choice = Some(ChatCompletionToolChoiceOption::Named(
+        ChatCompletionNamedToolChoice {
+            r#type: ChatCompletionToolType::Function,
+            function: FunctionName {
+                name: "get_weather".to_string(),
+            },
+        },
+    ));
+    request.inner.tool_choice = tool_choice.clone();
+
+    let mut generator = request.response_generator("req-length-1".to_string());
+    let backend_output = build_backend_output_with_finish(
+        r#"{"location":"Par"#, // Incomplete due to length limit
+        common::FinishReason::Length,
+    );
+
+    let raw_response = generator
+        .choice_from_postprocessor(backend_output)
+        .expect("choice generation");
+
+    let response = apply_jail_transformation(raw_response, tool_choice).await;
+
+    // Critical: Length finish reason should be preserved, NOT replaced with Stop
+    assert_eq!(
+        response.choices[0].finish_reason,
+        Some(dynamo_async_openai::types::FinishReason::Length),
+        "Length finish reason must be preserved for tool_choice=named"
+    );
+}
+
+#[test]
+fn test_required_tool_choice_preserves_length_finish_reason() {
+    let mut request = create_test_request();
+    request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Required);
+
+    let mut generator = request.response_generator("req-length-2".to_string());
+    let backend_output = build_backend_output_with_finish(
+        r#"[{"name":"search","parameters":{"query":"incomplete"#, // Incomplete due to length
+        common::FinishReason::Length,
+    );
+
+    // NOTE(review): the jail is not applied here; this only checks the finish_reason
+    // produced by the response generator itself.
+    let response = generator
+        .choice_from_postprocessor(backend_output)
+        .expect("choice generation");
+
+    // Critical: Length finish reason should be preserved, NOT replaced with ToolCalls
+    assert_eq!(
+        response.choices[0].finish_reason,
+        Some(dynamo_async_openai::types::FinishReason::Length),
+        "Length finish reason must be preserved for tool_choice=required"
+    );
+}
+
+#[test]
+fn test_named_tool_choice_preserves_content_filter() {
+    let mut request = create_test_request();
+    request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Named(
+        ChatCompletionNamedToolChoice {
+            r#type: ChatCompletionToolType::Function,
+            function: FunctionName {
+                name: "search".to_string(),
+            },
+        },
+    ));
+
+    let mut generator = request.response_generator("req-filter-1".to_string());
+    let backend_output = build_backend_output_with_finish(
+        r#"{"query":"filtered content"#,
+        common::FinishReason::ContentFilter,
+    );
+
+    // NOTE(review): the jail is not applied here; this only checks the finish_reason
+    // produced by the response generator itself.
+    let response = generator
+        .choice_from_postprocessor(backend_output)
+        .expect("choice generation");
+
+    // Critical: ContentFilter finish reason should be preserved
+    assert_eq!(
+        response.choices[0].finish_reason,
+        Some(dynamo_async_openai::types::FinishReason::ContentFilter),
+        "ContentFilter finish reason must be preserved for tool_choice=named"
+    );
+}
+
+#[test]
+fn test_required_tool_choice_preserves_content_filter() {
+    let mut request = create_test_request();
+    request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Required);
+
+    let mut generator = request.response_generator("req-filter-2".to_string());
+    let backend_output = build_backend_output_with_finish(
+        r#"[{"name":"harmful_action"#,
+        common::FinishReason::ContentFilter,
+    );
+
+    // NOTE(review): the jail is not applied here; this only checks the finish_reason
+    // produced by the response generator itself.
+    let response = generator
+        .choice_from_postprocessor(backend_output)
+        .expect("choice generation");
+
+    // Critical: ContentFilter finish reason should be preserved
+    assert_eq!(
+        response.choices[0].finish_reason,
+        Some(dynamo_async_openai::types::FinishReason::ContentFilter),
+        "ContentFilter finish reason must be preserved for tool_choice=required"
+    );
+}
+
+#[test]
+fn test_named_tool_choice_normal_stop_becomes_stop() {
+    let mut request = create_test_request();
+    request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Named(
+        ChatCompletionNamedToolChoice {
+            r#type: ChatCompletionToolType::Function,
+            function: FunctionName {
+                name: "get_weather".to_string(),
+            },
+        },
+    ));
+
+    let mut generator = request.response_generator("req-stop-1".to_string());
+    let backend_output = build_backend_output_with_finish(
+        r#"{"location":"Paris","unit":"celsius"}"#,
+        common::FinishReason::Stop,
+    );
+
+    // NOTE(review): the jail is not applied here; this only checks the finish_reason
+    // produced by the response generator itself.
+    let response = generator
+        .choice_from_postprocessor(backend_output)
+        .expect("choice generation");
+
+    // Normal completion: Stop should remain Stop for named tool choice
+    assert_eq!(
+        response.choices[0].finish_reason,
+        Some(dynamo_async_openai::types::FinishReason::Stop),
+    );
+}
+
+#[tokio::test]
+async fn test_required_tool_choice_normal_stop_becomes_tool_calls() {
+    let mut request = create_test_request();
+    let tool_choice = Some(ChatCompletionToolChoiceOption::Required);
+    request.inner.tool_choice = tool_choice.clone();
+
+    let mut generator = request.response_generator("req-stop-2".to_string());
+    let backend_output = build_backend_output_with_finish(
+        r#"[{"name":"search","parameters":{"query":"rust"}}]"#,
+        common::FinishReason::Stop,
+    );
+
+    let raw_response = generator
+        .choice_from_postprocessor(backend_output)
+        .expect("choice generation");
+
+    let response = apply_jail_transformation(raw_response, tool_choice).await;
+
+    // Normal completion: Stop should become ToolCalls for required tool choice
+    assert_eq!(
+        response.choices[0].finish_reason,
+        Some(dynamo_async_openai::types::FinishReason::ToolCalls),
+    );
+}