Skip to content
29 changes: 25 additions & 4 deletions lib/llm/src/preprocessor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1108,14 +1108,35 @@ impl OpenAIPreprocessor {
}

// Configure jail based on tool_choice
//
// When a tool_call_parser is configured, always use marker-based mode
// so that format-specific parsers (e.g. qwen3_coder XML) are invoked.
// Immediate JSON mode is only a fallback for required/named when no
// parser exists (the model is expected to emit raw JSON in that case).
match tool_choice {
Some(ChatCompletionToolChoiceOption::Named(named)) => {
// Immediate jail mode for named tool choice
builder = builder.tool_choice_named(named.function.name.clone());
if let Some(parser) = tool_call_parser {
// Parser-aware path: use marker-based jail so the parser
// handles format-specific output (XML, pythonic, etc.).
// Also install a named-tool filter so that if the model emits
// the wrong tool, the parsed call is rejected before emission.
builder = builder
.tool_call_parser(parser)
.named_tool_filter(named.function.name.clone());
} else {
// No parser: fall back to Immediate JSON jail mode.
builder = builder.tool_choice_named(named.function.name.clone());
}
}
Some(ChatCompletionToolChoiceOption::Required) => {
// Immediate jail mode for required tool choice
builder = builder.tool_choice_required();
if let Some(parser) = tool_call_parser {
// Parser-aware path: use marker-based jail so the parser
// handles format-specific output (XML, pythonic, etc.).
builder = builder.tool_call_parser(parser);
} else {
// No parser: fall back to Immediate JSON jail mode.
builder = builder.tool_choice_required();
}
}
Some(ChatCompletionToolChoiceOption::Auto)
| Some(ChatCompletionToolChoiceOption::None)
Expand Down
55 changes: 52 additions & 3 deletions lib/llm/src/protocols/openai/chat_completions/jail.rs
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,9 @@ pub struct JailedStream {
jail_start_sequences: Vec<String>,
jail_end_sequences: Vec<String>,
tool_call_parser: Option<String>,
/// When set, only tool calls with this name are emitted (enforces tool_choice=named
/// when a tool_call_parser is active and the parser-aware MarkerBased path is used).
named_tool_name: Option<String>,
tool_definitions: Option<Vec<dynamo_parsers::tool_calling::ToolDefinition>>,
emission_mode: EmissionMode,
marker_matcher: MarkerMatcher,
Expand All @@ -492,8 +495,9 @@ impl JailedStream {
S: Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send + 'static,
{
let jail_mode = self.jail_mode.clone();
let named_tool_active = self.named_tool_name.is_some();
let jailed_stream = self.apply(stream);
JailedStream::fix_finish_reason(jailed_stream, jail_mode)
JailedStream::fix_finish_reason(jailed_stream, jail_mode, named_tool_active)
}

/// Apply the jail transformation to a stream of chat completion responses
Expand Down Expand Up @@ -856,6 +860,37 @@ impl JailedStream {
if let Ok((tool_calls, normal_text)) = parse_result
&& !tool_calls.is_empty()
{
// If a named tool filter is set (tool_choice=named + parser path), reject
// tool calls that don't match the required tool name.
let tool_calls = if let Some(ref required_name) = self.named_tool_name {
let filtered: Vec<_> = tool_calls
.into_iter()
.filter(|tc| tc.function.name == *required_name)
.collect();
if filtered.is_empty() {
tracing::warn!(
required = %required_name,
"tool_choice=named: parser emitted no matching tool calls; dropping jail output"
);
}
filtered
} else {
tool_calls
};

if tool_calls.is_empty() {
// All parsed calls were for the wrong tool — return content choice
return create_choice_stream(
choice_index,
Some(Role::Assistant),
accumulated_content,
None,
base_choice.finish_reason,
base_choice.stop_reason.clone(),
base_choice.logprobs.clone(),
);
}

// Convert to streaming format
let tool_call_chunks: Vec<ChatCompletionMessageToolCallChunk> = tool_calls
.into_iter()
Expand Down Expand Up @@ -1004,6 +1039,7 @@ impl JailedStream {
fn fix_finish_reason<S>(
input_stream: S,
jail_mode: JailMode,
named_tool_active: bool,
) -> impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send
where
S: Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send + 'static,
Expand Down Expand Up @@ -1032,10 +1068,10 @@ impl JailedStream {

match &jail_mode {
JailMode::MarkerBased => {
// Traditional: if tool calls emitted, change to ToolCalls
if has_tool_calls {
if has_tool_calls && !named_tool_active {
choice.finish_reason = Some(FinishReason::ToolCalls);
}
// When named_tool_active, keep Stop (OpenAI spec for tool_choice=named)
}
JailMode::Immediate { format } => {
// tool_choice mode: apply specific finish_reason logic
Expand Down Expand Up @@ -1070,6 +1106,9 @@ pub struct JailedStreamBuilder {
jail_start_sequences: Vec<String>,
jail_end_sequences: Vec<String>,
tool_call_parser: Option<String>,
/// When set, only tool calls with this name are emitted (enforces tool_choice=named
/// when a tool_call_parser is active and the parser-aware MarkerBased path is used).
named_tool_name: Option<String>,
tool_definitions: Option<Vec<dynamo_parsers::tool_calling::ToolDefinition>>,
emission_mode: EmissionMode,
jail_mode: JailMode,
Expand All @@ -1082,6 +1121,7 @@ impl JailedStreamBuilder {
jail_start_sequences: Vec::new(),
jail_end_sequences: Vec::new(),
tool_call_parser: None,
named_tool_name: None,
tool_definitions: None,
emission_mode: EmissionMode::default(),
jail_mode: JailMode::MarkerBased,
Expand Down Expand Up @@ -1126,6 +1166,14 @@ impl JailedStreamBuilder {
self
}

/// Restrict emitted tool calls to one specific tool name.
///
/// Intended for the `tool_choice=named` + `tool_call_parser` combination. Once set,
/// the jail drops every parsed tool call whose function name differs from `tool_name`;
/// if nothing survives the filter, a warning is logged and the accumulated text is
/// emitted as a regular content choice instead.
pub fn named_tool_filter(mut self, tool_name: impl Into<String>) -> Self {
    let required_name: String = tool_name.into();
    self.named_tool_name = Some(required_name);
    self
}

/// Set the tool definitions for runtime validation and parsing
pub fn tool_definitions(
mut self,
Expand Down Expand Up @@ -1245,6 +1293,7 @@ impl JailedStreamBuilder {
jail_start_sequences: self.jail_start_sequences,
jail_end_sequences: self.jail_end_sequences,
tool_call_parser: self.tool_call_parser,
named_tool_name: self.named_tool_name,
tool_definitions: self.tool_definitions,
emission_mode: self.emission_mode,
marker_matcher,
Expand Down
191 changes: 190 additions & 1 deletion lib/llm/tests/test_jail.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use dynamo_llm::preprocessor::OpenAIPreprocessor;
use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse;
use dynamo_llm::protocols::openai::chat_completions::jail::JailedStream;
use dynamo_protocols::types::{
ChatChoiceStream, ChatCompletionStreamResponseDelta, CompletionUsage, FinishReason, Role,
ChatChoiceStream, ChatCompletionStreamResponseDelta, ChatCompletionToolChoiceOption,
CompletionUsage, FinishReason, Role,
};
use dynamo_runtime::protocols::annotated::Annotated;

Expand Down Expand Up @@ -3080,4 +3082,191 @@ mod parallel_jail_tests {
"Should have no tool calls for empty array"
);
}

/// Regression test for #6821: tool_choice=required with qwen3_coder parser.
///
/// When tool_choice=required AND a tool_call_parser (e.g. qwen3_coder) is
/// configured, the jail must use marker-based mode so the parser handles the
/// XML output. Previously this fell through to Immediate JSON mode which
/// could not parse qwen3_coder XML, causing raw XML to leak as content.
#[tokio::test]
async fn test_tool_choice_required_with_qwen3_coder_parser() {
// Simulate qwen3_coder XML output for a single tool call
let xml_output = r#"<tool_call>
<function=get_weather>
<parameter=city>
San Francisco
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>"#;

let input_chunks = vec![test_utils::create_mock_response_chunk(
xml_output.to_string(),
0,
)];

let input_stream = stream::iter(input_chunks);
let results: Vec<_> = OpenAIPreprocessor::apply_tool_calling_jail(
Some("qwen3_coder".to_string()),
Some(ChatCompletionToolChoiceOption::Required),
None,
input_stream,
)
.collect()
.await;

// Should have parsed a tool call, not leaked raw XML as content
let tool_call_count: usize = results
.iter()
.map(|r| {
r.data.as_ref().map_or(0, |d| {
d.inner
.choices
.iter()
.map(|c: &ChatChoiceStream| {
c.delta.tool_calls.as_ref().map_or(0, |tc| tc.len())
})
.sum::<usize>()
})
})
.sum();

assert!(
tool_call_count >= 1,
"tool_choice=required with qwen3_coder should produce at least one tool call, got {}",
tool_call_count
);

// Verify the tool call was parsed correctly
for r in &results {
if let Some(data) = &r.data {
for choice in &data.inner.choices {
if let Some(tool_calls) = &choice.delta.tool_calls {
for tc in tool_calls {
assert_eq!(
tc.function.as_ref().unwrap().name.as_deref(),
Some("get_weather"),
"Tool call name should be 'get_weather'"
);
}
}
// Content should be empty, not raw XML
if let Some(content) = &choice.delta.content {
let text = test_utils::extract_text(content);
assert!(
!text.contains("<tool_call>"),
"Raw XML should not leak as content, got: {}",
text
);
}
}
}
}
}

/// Test for tool_choice=named with qwen3_coder parser and named_tool_filter.
///
/// When tool_choice=named is used with a specific tool_name, the
/// preprocessor decision logic should apply the named_tool_filter to ensure
/// only the requested tool is parsed, even if the model emits other tools.
#[tokio::test]
async fn test_tool_choice_named_with_qwen3_coder_parser() {
// Simulate qwen3_coder XML output for a single tool call
let xml_output = r#"<tool_call>
<function=get_weather>
<parameter=city>
San Francisco
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>"#;

let input_chunks = vec![
test_utils::create_mock_response_chunk(xml_output.to_string(), 0),
test_utils::create_final_response_chunk(0),
];

let input_stream = stream::iter(input_chunks);

// Apply tool_choice=named for get_weather
let results: Vec<_> = OpenAIPreprocessor::apply_tool_calling_jail(
Some("qwen3_coder".to_string()),
Some(ChatCompletionToolChoiceOption::Named(
"get_weather".to_string().into(),
)),
None,
input_stream,
)
.collect()
.await;

// Should have parsed the named tool call
let tool_call_count: usize = results
.iter()
.map(|r| {
r.data.as_ref().map_or(0, |d| {
d.inner
.choices
.iter()
.map(|c: &ChatChoiceStream| {
c.delta.tool_calls.as_ref().map_or(0, |tc| tc.len())
})
.sum::<usize>()
})
})
.sum();

assert!(
tool_call_count >= 1,
"tool_choice=named with qwen3_coder should produce at least one tool call, got {}",
tool_call_count
);

// Verify the tool call was parsed correctly and matches the named tool
for r in &results {
if let Some(data) = &r.data {
for choice in &data.inner.choices {
if let Some(tool_calls) = &choice.delta.tool_calls {
for tc in tool_calls {
assert_eq!(
tc.function.as_ref().unwrap().name.as_deref(),
Some("get_weather"),
"Tool call name should match the named tool choice"
);
}
}
// Content should be empty, not raw XML
if let Some(content) = &choice.delta.content {
let text = test_utils::extract_text(content);
assert!(
!text.contains("<tool_call>"),
"Raw XML should not leak as content, got: {}",
text
);
}
}
}
}

// Verify finish_reason is Stop (not ToolCalls) for named tool choice
let finish_reasons: Vec<_> = results
.iter()
.filter_map(|r| {
r.data
.as_ref()
.and_then(|d| d.inner.choices.first().and_then(|c| c.finish_reason))
})
.collect();

// For tool_choice=named, finish_reason should be Stop (OpenAI spec)
assert!(
finish_reasons.contains(&FinishReason::Stop),
"tool_choice=named should have Stop finish reason"
);
}
}
Loading
Loading