Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 37 additions & 22 deletions lib/llm/src/preprocessor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -752,25 +752,17 @@ impl OpenAIPreprocessor {
has_tools: bool,
) -> std::result::Result<bool, Error> {
match (tool_call_parser, tool_choice, has_tools) {
// No parser but tools requested - error cases
(None, Some(ChatCompletionToolChoiceOption::Required), true) => {
tracing::warn!(
"Tool choice 'required' specified but no tool parser configured; proceeding without jailing"
);
Ok(false)
}
// tool_choice=required/named work without parser (use Immediate jail mode)
(None, Some(ChatCompletionToolChoiceOption::Required), true) => Ok(true),
(None, Some(ChatCompletionToolChoiceOption::Named(_)), true) => Ok(true),

// tool_choice=auto requires a parser
(None, Some(ChatCompletionToolChoiceOption::Auto), true) => {
tracing::warn!(
"Tool choice 'auto' specified but no tool parser configured; proceeding without jailing"
);
Ok(false)
}
(None, Some(ChatCompletionToolChoiceOption::Named(_)), _) => {
tracing::warn!(
"Named tool choice specified but no tool parser configured; proceeding without jailing"
);
Ok(false)
}

// Parser exists and tools might be called
(Some(_), Some(ChatCompletionToolChoiceOption::None), _) => {
Expand All @@ -786,15 +778,38 @@ impl OpenAIPreprocessor {

/// Apply tool calling jail to the stream if needed
pub fn apply_tool_calling_jail<S>(
tool_call_parser: String,
tool_call_parser: Option<String>,
tool_choice: Option<dynamo_async_openai::types::ChatCompletionToolChoiceOption>,
stream: S,
) -> impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send
where
S: Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send + 'static,
{
let jail = JailedStream::builder()
.tool_call_parser(tool_call_parser)
.build();
use dynamo_async_openai::types::ChatCompletionToolChoiceOption;

let mut builder = JailedStream::builder();

// Configure jail based on tool_choice
match tool_choice {
Some(ChatCompletionToolChoiceOption::Named(named)) => {
// Immediate jail mode for named tool choice
builder = builder.tool_choice_named(named.function.name.clone());
}
Some(ChatCompletionToolChoiceOption::Required) => {
// Immediate jail mode for required tool choice
builder = builder.tool_choice_required();
}
Some(ChatCompletionToolChoiceOption::Auto)
| Some(ChatCompletionToolChoiceOption::None)
| None => {
// Traditional marker-based jail for auto/none/unspecified
if let Some(parser) = tool_call_parser {
builder = builder.tool_call_parser(parser);
}
}
}

let jail = builder.build();
jail.apply_with_finish_reason(stream)
}

Expand Down Expand Up @@ -957,11 +972,11 @@ impl

// Apply jail conditionally
let transformed_stream: Pin<Box<dyn Stream<Item = _> + Send>> = if should_jail {
if let Some(parser) = self.tool_call_parser.clone() {
Box::pin(Self::apply_tool_calling_jail(parser, stream))
} else {
Box::pin(stream) // Should not happen due to should_jail check
}
Box::pin(Self::apply_tool_calling_jail(
self.tool_call_parser.clone(),
request.inner.tool_choice.clone(),
stream,
))
} else {
Box::pin(stream)
};
Expand Down
3 changes: 2 additions & 1 deletion lib/llm/src/protocols/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub mod embeddings;
pub mod models;
pub mod nvext;
pub mod responses;
pub mod tools;
pub mod validate;

use validate::{
Expand Down Expand Up @@ -131,7 +132,7 @@ impl<T: OpenAISamplingOptionsProvider + CommonExtProvider> SamplingOptionsProvid
let guided_whitespace_pattern = self.get_guided_whitespace_pattern();

let guided_decoding = match common::GuidedDecodingOptions::from_optional(
guided_json.cloned(),
guided_json,
guided_regex,
guided_choice,
guided_grammar,
Expand Down
22 changes: 19 additions & 3 deletions lib/llm/src/protocols/openai/chat_completions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use super::{
common_ext::{CommonExt, CommonExtProvider},
nvext::NvExt,
nvext::NvExtProvider,
validate,
tools, validate,
};

pub mod aggregator;
Expand Down Expand Up @@ -159,8 +159,24 @@ impl CommonExtProvider for NvCreateChatCompletionRequest {
}

/// Guided Decoding Options
fn get_guided_json(&self) -> Option<&serde_json::Value> {
self.common.guided_json.as_ref()
fn get_guided_json(&self) -> Option<serde_json::Value> {
if let Some(value) = self.common.guided_json.clone() {
return Some(value);
}

let tool_choice = self.inner.tool_choice.as_ref()?;
let tools = self.inner.tools.as_deref()?;

match tools::get_json_schema_from_tools(Some(tool_choice), Some(tools)) {
Ok(schema) => schema,
Err(err) => {
tracing::warn!(
error = %err,
"failed to derive guided_json from tool_choice"
);
None
}
}
}

fn get_guided_regex(&self) -> Option<String> {
Expand Down
Loading
Loading