Merged
20 changes: 19 additions & 1 deletion components/backends/vllm/src/dynamo/vllm/args.py
@@ -58,6 +58,10 @@ class Config:
# Connector list from CLI
connector_list: Optional[list] = None

# tool and reasoning parser info
tool_call_parser: Optional[str] = None
reasoning_parser: Optional[str] = None


def parse_args() -> Config:
parser = FlexibleArgumentParser(
@@ -102,6 +106,19 @@ def parse_args() -> Config:
help="List of connectors to use in order (e.g., --connector nixl lmcache). "
"Options: nixl, lmcache, kvbm, null, none. Default: nixl. Order will be preserved in MultiConnector.",
)
# To avoid name conflicts with different backends, we adopt the "dyn-" prefix for Dynamo-specific args
parser.add_argument(
"--dyn-tool-call-parser",
type=str,
default=None,
help="Tool call parser name for the model. Available options: 'hermes', 'nemotron_deci', 'llama3_json', 'mistral', 'phi4'.",
)
parser.add_argument(
"--dyn-reasoning-parser",
type=str,
default=None,
help="Reasoning parser name for the model.",
)

parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
@@ -151,7 +168,8 @@ def parse_args() -> Config:
config.port_range = DynamoPortRange(
min=args.dynamo_port_min, max=args.dynamo_port_max
)

config.tool_call_parser = args.dyn_tool_call_parser
config.reasoning_parser = args.dyn_reasoning_parser
# Check for conflicting flags
has_kv_transfer_config = (
hasattr(engine_args, "kv_transfer_config")
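
A minimal usage sketch for the new flags, assuming the module is importable as `dynamo.vllm.args` per the path above and that no other arguments are mandatory; the model name and parser values are illustrative only.

```python
import sys

# Simulate a worker command line. "--model" is a standard vLLM engine arg;
# the "--dyn-*" flags are the ones added in this diff.
sys.argv = [
    "dynamo-vllm",
    "--model", "Qwen/Qwen2.5-7B-Instruct",
    "--dyn-tool-call-parser", "hermes",
    "--dyn-reasoning-parser", "deepseek_r1",  # illustrative parser name
]

from dynamo.vllm.args import parse_args  # import path assumed from this diff

config = parse_args()
assert config.tool_call_parser == "hermes"
assert config.reasoning_parser == "deepseek_r1"
```
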
2 changes: 2 additions & 0 deletions components/backends/vllm/src/dynamo/vllm/main.py
@@ -234,6 +234,8 @@ async def init(runtime: DistributedRuntime, config: Config):
runtime_config.total_kv_blocks = runtime_values["num_gpu_blocks"]
runtime_config.max_num_seqs = runtime_values["max_num_seqs"]
runtime_config.max_num_batched_tokens = runtime_values["max_num_batched_tokens"]
runtime_config.tool_call_parser = config.tool_call_parser
runtime_config.reasoning_parser = config.reasoning_parser

await register_llm(
ModelType.Backend,
20 changes: 20 additions & 0 deletions lib/bindings/python/rust/llm/local_model.rs
@@ -34,6 +34,16 @@ impl ModelRuntimeConfig {
self.inner.max_num_batched_tokens = Some(max_num_batched_tokens);
}

#[setter]
fn set_tool_call_parser(&mut self, tool_call_parser: Option<String>) {
self.inner.tool_call_parser = tool_call_parser;
}

#[setter]
fn set_reasoning_parser(&mut self, reasoning_parser: Option<String>) {
self.inner.reasoning_parser = reasoning_parser;
}

fn set_engine_specific(&mut self, key: &str, value: String) -> PyResult<()> {
let value: serde_json::Value = serde_json::from_str(&value).map_err(to_pyerr)?;
self.inner
@@ -57,6 +67,16 @@ impl ModelRuntimeConfig {
self.inner.max_num_batched_tokens
}

#[getter]
fn tool_call_parser(&self) -> Option<String> {
self.inner.tool_call_parser.clone()
}

#[getter]
fn reasoning_parser(&self) -> Option<String> {
self.inner.reasoning_parser.clone()
}

#[getter]
fn runtime_data(&self, py: Python<'_>) -> PyResult<PyObject> {
let dict = PyDict::new(py);
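
A sketch of the Python surface these bindings add, assuming `ModelRuntimeConfig` is constructible from Python and exported under `dynamo.llm` (the exact import path and constructor are assumptions):

```python
from dynamo.llm import ModelRuntimeConfig  # import path assumed

rc = ModelRuntimeConfig()
rc.tool_call_parser = "hermes"  # goes through the new #[setter]
rc.reasoning_parser = None      # both parsers are optional
assert rc.tool_call_parser == "hermes"  # read back via the new #[getter]
```
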
12 changes: 12 additions & 0 deletions lib/llm/src/discovery/model_manager.rs
@@ -246,6 +246,18 @@ impl ModelManager {
.insert(model_name.to_string(), new_kv_chooser.clone());
Ok(new_kv_chooser)
}

pub fn get_model_tool_call_parser(&self, model: &str) -> Option<String> {
match self.entries.lock() {
Ok(entries) => entries
.values()
.find(|entry| entry.name == model)
.and_then(|entry| entry.runtime_config.as_ref())
.and_then(|config| config.tool_call_parser.clone()),
Err(_) => None,
}
}
}

pub struct ModelEngines<E> {
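
Note the lookup semantics: the first entry whose name matches wins, and a poisoned lock or a matching entry without a runtime config both yield `None`. A pure-Python mirror of that logic, for illustration only:

```python
def get_model_tool_call_parser(entries, model):
    # First entry with a matching name wins; there is no fallback to
    # later entries if it lacks a runtime_config.
    entry = next((e for e in entries.values() if e.name == model), None)
    if entry is None or entry.runtime_config is None:
        return None
    return entry.runtime_config.tool_call_parser
```
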
70 changes: 43 additions & 27 deletions lib/llm/src/http/service/openai.rs
@@ -37,6 +37,7 @@ use crate::protocols::openai::{
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
responses::{NvCreateResponse, NvResponse},
ParsingOptions,
};
use crate::request_template::RequestTemplate;
use crate::types::Annotated;
@@ -194,6 +195,13 @@ fn get_or_create_request_id(primary: Option<&str>, headers: &HeaderMap) -> String
uuid.to_string()
}

fn get_parsing_options(state: &Arc<service_v2::State>, model: &str) -> ParsingOptions {
let tool_call_parser = state.manager().get_model_tool_call_parser(model);
let reasoning_parser = None; // TODO: Implement reasoning parser

ParsingOptions::new(tool_call_parser, reasoning_parser)
}

/// OpenAI Completions Request Handler
///
/// This method will handle the incoming request for the `/v1/completions endpoint`. The endpoint is a "source"
@@ -267,6 +275,8 @@ async fn completions(
.get_completions_engine(model)
.map_err(|_| ErrorMessage::model_not_found())?;

let parsing_options = get_parsing_options(&state, model);

let mut inflight_guard =
state
.metrics_clone()
@@ -325,7 +335,7 @@ async fn completions(
process_metrics_only(response, &mut response_collector);
});

let response = NvCreateCompletionResponse::from_annotated_stream(stream)
let response = NvCreateCompletionResponse::from_annotated_stream(stream, parsing_options)
.await
.map_err(|e| {
tracing::error!(
@@ -494,6 +504,8 @@ async fn chat_completions(
.get_chat_completions_engine(model)
.map_err(|_| ErrorMessage::model_not_found())?;

let parsing_options = get_parsing_options(&state, model);

let mut inflight_guard =
state
.metrics_clone()
@@ -553,19 +565,20 @@ async fn chat_completions(
process_metrics_only(response, &mut response_collector);
});

let response = NvCreateChatCompletionResponse::from_annotated_stream(stream)
.await
.map_err(|e| {
tracing::error!(
request_id,
"Failed to fold chat completions stream for: {:?}",
e
);
ErrorMessage::internal_server_error(&format!(
"Failed to fold chat completions stream: {}",
e
))
})?;
let response =
NvCreateChatCompletionResponse::from_annotated_stream(stream, parsing_options.clone())
.await
.map_err(|e| {
tracing::error!(
request_id,
"Failed to fold chat completions stream for: {:?}",
e
);
ErrorMessage::internal_server_error(&format!(
"Failed to fold chat completions stream: {}",
e
))
})?;

inflight_guard.mark_ok();
Ok(Json(response).into_response())
@@ -726,6 +739,8 @@ async fn responses(
.get_chat_completions_engine(model)
.map_err(|_| ErrorMessage::model_not_found())?;

let parsing_options = get_parsing_options(&state, model);

let mut inflight_guard =
state
.metrics_clone()
@@ -742,19 +757,20 @@ async fn responses(
.map_err(|e| ErrorMessage::from_anyhow(e, "Failed to generate completions"))?;

// TODO: handle streaming, currently just unary
let response = NvCreateChatCompletionResponse::from_annotated_stream(stream)
.await
.map_err(|e| {
tracing::error!(
request_id,
"Failed to fold chat completions stream for: {:?}",
e
);
ErrorMessage::internal_server_error(&format!(
"Failed to fold chat completions stream: {}",
e
))
})?;
let response =
NvCreateChatCompletionResponse::from_annotated_stream(stream, parsing_options.clone())
.await
.map_err(|e| {
tracing::error!(
request_id,
"Failed to fold chat completions stream for: {:?}",
e
);
ErrorMessage::internal_server_error(&format!(
"Failed to fold chat completions stream: {}",
e
))
})?;

// Convert NvCreateChatCompletionResponse --> NvResponse
let response: NvResponse = response.try_into().map_err(|e| {
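
End to end, the effect is that `/v1/chat/completions` responses can now carry structured tool calls produced by the model's configured parser. A sanity-check sketch, assuming a running Dynamo HTTP frontend and a model registered with `--dyn-tool-call-parser`; host, port, and model name are illustrative:

```python
import requests  # assumes a frontend listening on localhost:8000

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "Qwen/Qwen2.5-7B-Instruct",  # illustrative
        "messages": [{"role": "user", "content": "Weather in Paris?"}],
        "tools": tools,
    },
    timeout=60,
)
choice = resp.json()["choices"][0]
# With a tool-call parser configured, parsed calls should surface here
# rather than staying embedded in message.content as raw text:
print(choice["message"].get("tool_calls"))
```
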
2 changes: 2 additions & 0 deletions lib/llm/src/local_model.rs
@@ -202,6 +202,7 @@ impl LocalModelBuilder {
);
card.migration_limit = self.migration_limit;
card.user_data = self.user_data.take();

return Ok(LocalModel {
card,
full_path: PathBuf::new(),
@@ -392,6 +393,7 @@ impl LocalModel {
let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client.clone()));
let card_store = Arc::new(KeyValueStoreManager::new(kvstore));
let key = self.card.slug().to_string();

card_store
.publish(model_card::ROOT_PATH, None, &key, &mut self.card)
.await?;
4 changes: 4 additions & 0 deletions lib/llm/src/local_model/runtime_config.rs
@@ -13,6 +13,10 @@ pub struct ModelRuntimeConfig {

pub max_num_batched_tokens: Option<u64>,

pub tool_call_parser: Option<String>,

pub reasoning_parser: Option<String>,

/// Mapping of engine-specific runtime configs
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub runtime_data: HashMap<String, serde_json::Value>,
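
Because the two new fields are plain `Option<String>`s with no `skip_serializing_if` attribute, they appear in the serialized config even when unset (as `null`), unlike `runtime_data`, which is skipped when empty. An illustrative sketch of the resulting JSON shape, with made-up values:

```python
import json

serialized = json.dumps({
    "total_kv_blocks": 8192,   # populated from the engine in main.py above
    "max_num_seqs": 256,
    "max_num_batched_tokens": 8192,
    "tool_call_parser": "hermes",
    "reasoning_parser": None,  # serialized as null when unset
})
print(serialized)
```
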
1 change: 0 additions & 1 deletion lib/llm/src/preprocessor.rs
@@ -101,7 +101,6 @@ impl OpenAIPreprocessor {
let mdcsum = mdc.mdcsum();
let formatter = PromptFormatter::from_mdc(mdc.clone()).await?;
let PromptFormatter::OAI(formatter) = formatter;

let tokenizer = match &mdc.tokenizer {
Some(TokenizerKind::HfTokenizerJson(file)) => HuggingFaceTokenizer::from_file(file)?,
Some(TokenizerKind::GGUF(tokenizer)) => {
16 changes: 16 additions & 0 deletions lib/llm/src/protocols/openai.rs
@@ -193,3 +193,19 @@ pub trait DeltaGeneratorExt<ResponseType: Send + 'static + std::fmt::Debug>:
/// Gets the current prompt token count (Input Sequence Length).
fn get_isl(&self) -> Option<u32>;
}

#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct ParsingOptions {
pub tool_call_parser: Option<String>,

pub reasoning_parser: Option<String>,
}

impl ParsingOptions {
pub fn new(tool_call_parser: Option<String>, reasoning_parser: Option<String>) -> Self {
Self {
tool_call_parser,
reasoning_parser,
}
}
}
28 changes: 19 additions & 9 deletions lib/llm/src/protocols/openai/chat_completions/aggregator.rs
@@ -19,7 +19,9 @@ use std::collections::HashMap;
use super::{NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse};
use crate::protocols::{
codec::{Message, SseCodecError},
convert_sse_stream, Annotated,
convert_sse_stream,
openai::ParsingOptions,
Annotated,
};

use dynamo_parsers::tool_calling::try_tool_call_parse_aggregate;
@@ -99,6 +101,7 @@ impl DeltaAggregator {
/// * `Err(String)` if an error occurs during processing.
pub async fn apply(
stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>,
parsing_options: ParsingOptions,
) -> Result<NvCreateChatCompletionResponse, String> {
let aggregator = stream
.fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
@@ -175,7 +178,10 @@ impl DeltaAggregator {
// After aggregation, inspect each choice's text for tool call syntax
for choice in aggregator.choices.values_mut() {
if choice.tool_calls.is_none() {
if let Ok(tool_calls) = try_tool_call_parse_aggregate(&choice.text, None) {
if let Ok(tool_calls) = try_tool_call_parse_aggregate(
&choice.text,
parsing_options.tool_call_parser.as_deref(),
) {
if tool_calls.is_empty() {
continue;
}
@@ -262,6 +268,7 @@ pub trait ChatCompletionAggregator {
/// * `Err(String)` if an error occurs.
async fn from_annotated_stream(
stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>,
parsing_options: ParsingOptions,
) -> Result<NvCreateChatCompletionResponse, String>;

/// Converts an SSE stream into a [`NvCreateChatCompletionResponse`].
@@ -274,21 +281,24 @@ pub trait ChatCompletionAggregator {
/// * `Err(String)` if an error occurs.
async fn from_sse_stream(
stream: DataStream<Result<Message, SseCodecError>>,
parsing_options: ParsingOptions,
) -> Result<NvCreateChatCompletionResponse, String>;
}

impl ChatCompletionAggregator for dynamo_async_openai::types::CreateChatCompletionResponse {
async fn from_annotated_stream(
stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>,
parsing_options: ParsingOptions,
) -> Result<NvCreateChatCompletionResponse, String> {
DeltaAggregator::apply(stream).await
DeltaAggregator::apply(stream, parsing_options).await
}

async fn from_sse_stream(
stream: DataStream<Result<Message, SseCodecError>>,
parsing_options: ParsingOptions,
) -> Result<NvCreateChatCompletionResponse, String> {
let stream = convert_sse_stream::<NvCreateChatCompletionStreamResponse>(stream);
NvCreateChatCompletionResponse::from_annotated_stream(stream).await
NvCreateChatCompletionResponse::from_annotated_stream(stream, parsing_options).await
}
}

@@ -347,7 +357,7 @@ mod tests {
Box::pin(stream::empty());

// Call DeltaAggregator::apply
let result = DeltaAggregator::apply(stream).await;
let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;

// Check the result
assert!(result.is_ok());
@@ -377,7 +387,7 @@ mod tests {
let stream = Box::pin(stream::iter(vec![annotated_delta]));

// Call DeltaAggregator::apply
let result = DeltaAggregator::apply(stream).await;
let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;

// Check the result
assert!(result.is_ok());
@@ -421,7 +431,7 @@ mod tests {
let stream = Box::pin(stream::iter(annotated_deltas));

// Call DeltaAggregator::apply
let result = DeltaAggregator::apply(stream).await;
let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;

// Check the result
assert!(result.is_ok());
@@ -492,7 +502,7 @@ mod tests {
let stream = Box::pin(stream::iter(vec![annotated_delta]));

// Call DeltaAggregator::apply
let result = DeltaAggregator::apply(stream).await;
let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;

// Check the result
assert!(result.is_ok());
@@ -550,7 +560,7 @@ mod tests {
let stream = Box::pin(stream::iter(vec![annotated_delta]));

// Call DeltaAggregator::apply
let result = DeltaAggregator::apply(stream).await;
let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;

// Check the result
assert!(result.is_ok());