diff --git a/components/backends/vllm/src/dynamo/vllm/args.py b/components/backends/vllm/src/dynamo/vllm/args.py
index d0f0c9b3414..381fcb38fc1 100644
--- a/components/backends/vllm/src/dynamo/vllm/args.py
+++ b/components/backends/vllm/src/dynamo/vllm/args.py
@@ -58,6 +58,10 @@ class Config:
     # Connector list from CLI
     connector_list: Optional[list] = None
 
+    # tool and reasoning parser info
+    tool_call_parser: Optional[str] = None
+    reasoning_parser: Optional[str] = None
+
 
 def parse_args() -> Config:
     parser = FlexibleArgumentParser(
@@ -102,6 +106,19 @@ def parse_args() -> Config:
         help="List of connectors to use in order (e.g., --connector nixl lmcache). "
         "Options: nixl, lmcache, kvbm, null, none. Default: nixl. Order will be preserved in MultiConnector.",
     )
+    # To avoid name conflicts with different backends, adopted prefix "dyn-" for dynamo-specific args
+    parser.add_argument(
+        "--dyn-tool-call-parser",
+        type=str,
+        default=None,
+        help="Tool call parser name for the model. Available options: 'hermes', 'nemotron_deci', 'llama3_json', 'mistral', 'phi4'.",
+    )
+    parser.add_argument(
+        "--dyn-reasoning-parser",
+        type=str,
+        default=None,
+        help="Reasoning parser name for the model.",
+    )
 
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()
@@ -151,7 +168,8 @@ def parse_args() -> Config:
     config.port_range = DynamoPortRange(
         min=args.dynamo_port_min, max=args.dynamo_port_max
     )
-
+    config.tool_call_parser = args.dyn_tool_call_parser
+    config.reasoning_parser = args.dyn_reasoning_parser
     # Check for conflicting flags
     has_kv_transfer_config = (
         hasattr(engine_args, "kv_transfer_config")
diff --git a/components/backends/vllm/src/dynamo/vllm/main.py b/components/backends/vllm/src/dynamo/vllm/main.py
index 7e0486c915c..5be1830c99d 100644
--- a/components/backends/vllm/src/dynamo/vllm/main.py
+++ b/components/backends/vllm/src/dynamo/vllm/main.py
@@ -234,6 +234,8 @@ async def init(runtime: DistributedRuntime, config: Config):
     runtime_config.total_kv_blocks = runtime_values["num_gpu_blocks"]
     runtime_config.max_num_seqs = runtime_values["max_num_seqs"]
     runtime_config.max_num_batched_tokens = runtime_values["max_num_batched_tokens"]
+    runtime_config.tool_call_parser = config.tool_call_parser
+    runtime_config.reasoning_parser = config.reasoning_parser
 
     await register_llm(
         ModelType.Backend,
diff --git a/lib/bindings/python/rust/llm/local_model.rs b/lib/bindings/python/rust/llm/local_model.rs
index 2fdc1a153be..fc1f365906b 100644
--- a/lib/bindings/python/rust/llm/local_model.rs
+++ b/lib/bindings/python/rust/llm/local_model.rs
@@ -34,6 +34,16 @@ impl ModelRuntimeConfig {
         self.inner.max_num_batched_tokens = Some(max_num_batched_tokens);
     }
 
+    #[setter]
+    fn set_tool_call_parser(&mut self, tool_call_parser: Option<String>) {
+        self.inner.tool_call_parser = tool_call_parser;
+    }
+
+    #[setter]
+    fn set_reasoning_parser(&mut self, reasoning_parser: Option<String>) {
+        self.inner.reasoning_parser = reasoning_parser;
+    }
+
     fn set_engine_specific(&mut self, key: &str, value: String) -> PyResult<()> {
         let value: serde_json::Value = serde_json::from_str(&value).map_err(to_pyerr)?;
         self.inner
@@ -57,6 +67,16 @@ impl ModelRuntimeConfig {
         self.inner.max_num_batched_tokens
     }
 
+    #[getter]
+    fn tool_call_parser(&self) -> Option<String> {
+        self.inner.tool_call_parser.clone()
+    }
+
+    #[getter]
+    fn reasoning_parser(&self) -> Option<String> {
+        self.inner.reasoning_parser.clone()
+    }
+
     #[getter]
     fn runtime_data(&self, py: Python<'_>) -> PyResult {
         let dict = PyDict::new(py);
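Usage sketch (illustration, not part of the patch): the bindings above expose `tool_call_parser` and `reasoning_parser` as plain getter/setter properties on `ModelRuntimeConfig`, which is how the `main.py` change forwards the CLI values before `register_llm`. The snippet assumes the class is importable from `dynamo.llm` and is default-constructible, as the worker code suggests; adjust the import and construction to your build if either assumption does not hold.

```python
# Illustration only -- not part of the diff.
# Assumption: ModelRuntimeConfig is exported by the dynamo.llm Python package
# (built from lib/bindings/python) and can be constructed with no arguments.
from dynamo.llm import ModelRuntimeConfig

runtime_config = ModelRuntimeConfig()

# These map to the #[setter] methods added above; the values normally come from
# --dyn-tool-call-parser / --dyn-reasoning-parser via Config in args.py.
runtime_config.tool_call_parser = "hermes"
runtime_config.reasoning_parser = None

# The matching #[getter] methods let the worker read the values back before register_llm().
assert runtime_config.tool_call_parser == "hermes"
```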
diff --git a/lib/llm/src/discovery/model_manager.rs b/lib/llm/src/discovery/model_manager.rs
index b934a75ccc3..95e05baca0b 100644
--- a/lib/llm/src/discovery/model_manager.rs
+++ b/lib/llm/src/discovery/model_manager.rs
@@ -246,6 +246,18 @@ impl ModelManager {
             .insert(model_name.to_string(), new_kv_chooser.clone());
         Ok(new_kv_chooser)
     }
+
+    pub fn get_model_tool_call_parser(&self, model: &str) -> Option<String> {
+        match self.entries.lock() {
+            Ok(entries) => entries
+                .values()
+                .find(|entry| entry.name == model)
+                .and_then(|entry| entry.runtime_config.as_ref())
+                .and_then(|config| config.tool_call_parser.clone())
+                .map(|parser| parser.to_string()),
+            Err(_) => None,
+        }
+    }
 }
 
 pub struct ModelEngines {
diff --git a/lib/llm/src/http/service/openai.rs b/lib/llm/src/http/service/openai.rs
index ee3691f09fa..8a2c9c53a0c 100644
--- a/lib/llm/src/http/service/openai.rs
+++ b/lib/llm/src/http/service/openai.rs
@@ -37,6 +37,7 @@ use crate::protocols::openai::{
     completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
     embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
     responses::{NvCreateResponse, NvResponse},
+    ParsingOptions,
 };
 use crate::request_template::RequestTemplate;
 use crate::types::Annotated;
@@ -194,6 +195,13 @@ fn get_or_create_request_id(primary: Option<&str>, headers: &HeaderMap) -> String {
     uuid.to_string()
 }
 
+fn get_parsing_options(state: &Arc<service_v2::State>, model: &str) -> ParsingOptions {
+    let tool_call_parser = state.manager().get_model_tool_call_parser(model);
+    let reasoning_parser = None; // TODO: Implement reasoning parser
+
+    ParsingOptions::new(tool_call_parser, reasoning_parser)
+}
+
 /// OpenAI Completions Request Handler
 ///
 /// This method will handle the incoming request for the `/v1/completions endpoint`. The endpoint is a "source"
@@ -267,6 +275,8 @@ async fn completions(
         .get_completions_engine(model)
         .map_err(|_| ErrorMessage::model_not_found())?;
 
+    let parsing_options = get_parsing_options(&state, model);
+
     let mut inflight_guard = state
         .metrics_clone()
@@ -325,7 +335,7 @@ async fn completions(
         process_metrics_only(response, &mut response_collector);
     });
 
-    let response = NvCreateCompletionResponse::from_annotated_stream(stream)
+    let response = NvCreateCompletionResponse::from_annotated_stream(stream, parsing_options)
         .await
         .map_err(|e| {
             tracing::error!(
@@ -494,6 +504,8 @@ async fn chat_completions(
         .get_chat_completions_engine(model)
         .map_err(|_| ErrorMessage::model_not_found())?;
 
+    let parsing_options = get_parsing_options(&state, model);
+
     let mut inflight_guard = state
         .metrics_clone()
@@ -553,19 +565,20 @@ async fn chat_completions(
         process_metrics_only(response, &mut response_collector);
     });
 
-    let response = NvCreateChatCompletionResponse::from_annotated_stream(stream)
-        .await
-        .map_err(|e| {
-            tracing::error!(
-                request_id,
-                "Failed to fold chat completions stream for: {:?}",
-                e
-            );
-            ErrorMessage::internal_server_error(&format!(
-                "Failed to fold chat completions stream: {}",
-                e
-            ))
-        })?;
+    let response =
+        NvCreateChatCompletionResponse::from_annotated_stream(stream, parsing_options.clone())
+            .await
+            .map_err(|e| {
+                tracing::error!(
+                    request_id,
+                    "Failed to fold chat completions stream for: {:?}",
+                    e
+                );
+                ErrorMessage::internal_server_error(&format!(
+                    "Failed to fold chat completions stream: {}",
+                    e
+                ))
+            })?;
 
     inflight_guard.mark_ok();
     Ok(Json(response).into_response())
@@ -726,6 +739,8 @@ async fn responses(
         .get_chat_completions_engine(model)
         .map_err(|_| ErrorMessage::model_not_found())?;
 
+    let parsing_options = get_parsing_options(&state, model);
+
     let mut inflight_guard = state
         .metrics_clone()
@@ -742,19 +757,20 @@ async fn responses(
         .map_err(|e| ErrorMessage::from_anyhow(e, "Failed to generate completions"))?;
 
     // TODO: handle streaming, currently just unary
-    let response = NvCreateChatCompletionResponse::from_annotated_stream(stream)
-        .await
-        .map_err(|e| {
-            tracing::error!(
-                request_id,
-                "Failed to fold chat completions stream for: {:?}",
-                e
-            );
-            ErrorMessage::internal_server_error(&format!(
-                "Failed to fold chat completions stream: {}",
-                e
-            ))
-        })?;
+    let response =
+        NvCreateChatCompletionResponse::from_annotated_stream(stream, parsing_options.clone())
+            .await
+            .map_err(|e| {
+                tracing::error!(
+                    request_id,
+                    "Failed to fold chat completions stream for: {:?}",
+                    e
+                );
+                ErrorMessage::internal_server_error(&format!(
+                    "Failed to fold chat completions stream: {}",
+                    e
+                ))
+            })?;
 
     // Convert NvCreateChatCompletionResponse --> NvResponse
     let response: NvResponse = response.try_into().map_err(|e| {
diff --git a/lib/llm/src/local_model.rs b/lib/llm/src/local_model.rs
index 72708070e81..ab1a0c5b91e 100644
--- a/lib/llm/src/local_model.rs
+++ b/lib/llm/src/local_model.rs
@@ -202,6 +202,7 @@
         );
         card.migration_limit = self.migration_limit;
         card.user_data = self.user_data.take();
+
         return Ok(LocalModel {
             card,
             full_path: PathBuf::new(),
@@ -392,6 +393,7 @@
        let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client.clone()));
        let card_store = Arc::new(KeyValueStoreManager::new(kvstore));
        let key = self.card.slug().to_string();
+
        card_store
            .publish(model_card::ROOT_PATH, None, &key, &mut self.card)
            .await?;
diff --git a/lib/llm/src/local_model/runtime_config.rs b/lib/llm/src/local_model/runtime_config.rs
index 4421ff4022e..8c5a6a434f9 100644
--- a/lib/llm/src/local_model/runtime_config.rs
+++ b/lib/llm/src/local_model/runtime_config.rs
@@ -13,6 +13,10 @@ pub struct ModelRuntimeConfig {
 
     pub max_num_batched_tokens: Option<u64>,
 
+    pub tool_call_parser: Option<String>,
+
+    pub reasoning_parser: Option<String>,
+
     /// Mapping of engine-specific runtime configs
     #[serde(default, skip_serializing_if = "HashMap::is_empty")]
     pub runtime_data: HashMap<String, serde_json::Value>,
diff --git a/lib/llm/src/preprocessor.rs b/lib/llm/src/preprocessor.rs
index 917fcf0c50c..f600d08c248 100644
--- a/lib/llm/src/preprocessor.rs
+++ b/lib/llm/src/preprocessor.rs
@@ -101,7 +101,6 @@
         let mdcsum = mdc.mdcsum();
         let formatter = PromptFormatter::from_mdc(mdc.clone()).await?;
         let PromptFormatter::OAI(formatter) = formatter;
-
         let tokenizer = match &mdc.tokenizer {
             Some(TokenizerKind::HfTokenizerJson(file)) => HuggingFaceTokenizer::from_file(file)?,
             Some(TokenizerKind::GGUF(tokenizer)) => {
diff --git a/lib/llm/src/protocols/openai.rs b/lib/llm/src/protocols/openai.rs
index 7c3166dc4cd..668d8e69336 100644
--- a/lib/llm/src/protocols/openai.rs
+++ b/lib/llm/src/protocols/openai.rs
@@ -193,3 +193,19 @@ pub trait DeltaGeneratorExt:
     /// Gets the current prompt token count (Input Sequence Length).
     fn get_isl(&self) -> Option<u32>;
 }
+
+#[derive(Clone, Debug, Serialize, Deserialize, Default)]
+pub struct ParsingOptions {
+    pub tool_call_parser: Option<String>,
+
+    pub reasoning_parser: Option<String>,
+}
+
+impl ParsingOptions {
+    pub fn new(tool_call_parser: Option<String>, reasoning_parser: Option<String>) -> Self {
+        Self {
+            tool_call_parser,
+            reasoning_parser,
+        }
+    }
+}
diff --git a/lib/llm/src/protocols/openai/chat_completions/aggregator.rs b/lib/llm/src/protocols/openai/chat_completions/aggregator.rs
index ed15b7d69ee..a99b3e1ddac 100644
--- a/lib/llm/src/protocols/openai/chat_completions/aggregator.rs
+++ b/lib/llm/src/protocols/openai/chat_completions/aggregator.rs
@@ -19,7 +19,9 @@ use std::collections::HashMap;
 use super::{NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse};
 use crate::protocols::{
     codec::{Message, SseCodecError},
-    convert_sse_stream, Annotated,
+    convert_sse_stream,
+    openai::ParsingOptions,
+    Annotated,
 };
 
 use dynamo_parsers::tool_calling::try_tool_call_parse_aggregate;
@@ -99,6 +101,7 @@
     /// * Err(String) if an error occurs during processing.
     pub async fn apply(
         stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>,
+        parsing_options: ParsingOptions,
     ) -> Result<NvCreateChatCompletionResponse, String> {
         let aggregator = stream
             .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
@@ -175,7 +178,10 @@
         // After aggregation, inspect each choice's text for tool call syntax
         for choice in aggregator.choices.values_mut() {
             if choice.tool_calls.is_none() {
-                if let Ok(tool_calls) = try_tool_call_parse_aggregate(&choice.text, None) {
+                if let Ok(tool_calls) = try_tool_call_parse_aggregate(
+                    &choice.text,
+                    parsing_options.tool_call_parser.as_deref(),
+                ) {
                     if tool_calls.is_empty() {
                         continue;
                     }
@@ -262,6 +268,7 @@ pub trait ChatCompletionAggregator {
     /// * Err(String) if an error occurs.
     async fn from_annotated_stream(
         stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>,
+        parsing_options: ParsingOptions,
     ) -> Result<NvCreateChatCompletionResponse, String>;
 
     /// Converts an SSE stream into a [NvCreateChatCompletionResponse].
@@ -274,21 +281,24 @@
     /// * Err(String) if an error occurs.
     async fn from_sse_stream(
         stream: DataStream<Result<Message, SseCodecError>>,
+        parsing_options: ParsingOptions,
     ) -> Result<NvCreateChatCompletionResponse, String>;
 }
 
 impl ChatCompletionAggregator for dynamo_async_openai::types::CreateChatCompletionResponse {
     async fn from_annotated_stream(
         stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>,
+        parsing_options: ParsingOptions,
     ) -> Result<NvCreateChatCompletionResponse, String> {
-        DeltaAggregator::apply(stream).await
+        DeltaAggregator::apply(stream, parsing_options).await
     }
 
     async fn from_sse_stream(
         stream: DataStream<Result<Message, SseCodecError>>,
+        parsing_options: ParsingOptions,
     ) -> Result<NvCreateChatCompletionResponse, String> {
         let stream = convert_sse_stream::<NvCreateChatCompletionStreamResponse>(stream);
-        NvCreateChatCompletionResponse::from_annotated_stream(stream).await
+        NvCreateChatCompletionResponse::from_annotated_stream(stream, parsing_options).await
     }
 }
 
@@ -347,7 +357,7 @@ mod tests {
            Box::pin(stream::empty());
 
        // Call DeltaAggregator::apply
-        let result = DeltaAggregator::apply(stream).await;
+        let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
 
        // Check the result
        assert!(result.is_ok());
@@ -377,7 +387,7 @@
        let stream = Box::pin(stream::iter(vec![annotated_delta]));
 
        // Call DeltaAggregator::apply
-        let result = DeltaAggregator::apply(stream).await;
+        let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
 
        // Check the result
        assert!(result.is_ok());
@@ -421,7 +431,7 @@
        let stream = Box::pin(stream::iter(annotated_deltas));
 
        // Call DeltaAggregator::apply
-        let result = DeltaAggregator::apply(stream).await;
+        let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
 
        // Check the result
        assert!(result.is_ok());
@@ -492,7 +502,7 @@
        let stream = Box::pin(stream::iter(vec![annotated_delta]));
 
        // Call DeltaAggregator::apply
-        let result = DeltaAggregator::apply(stream).await;
+        let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
 
        // Check the result
        assert!(result.is_ok());
@@ -550,7 +560,7 @@
        let stream = Box::pin(stream::iter(vec![annotated_delta]));
 
        // Call DeltaAggregator::apply
-        let result = DeltaAggregator::apply(stream).await;
+        let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
 
        // Check the result
        assert!(result.is_ok());
diff --git a/lib/llm/src/protocols/openai/completions/aggregator.rs b/lib/llm/src/protocols/openai/completions/aggregator.rs
index 7bb0c44aea9..e72fc072c27 100644
--- a/lib/llm/src/protocols/openai/completions/aggregator.rs
+++ b/lib/llm/src/protocols/openai/completions/aggregator.rs
@@ -22,7 +22,9 @@ use super::NvCreateCompletionResponse;
 use crate::protocols::{
     codec::{Message, SseCodecError},
     common::FinishReason,
-    convert_sse_stream, Annotated, DataStream,
+    convert_sse_stream,
+    openai::ParsingOptions,
+    Annotated, DataStream,
 };
 
 /// Aggregates a stream of [`CompletionResponse`]s into a single [`CompletionResponse`].
@@ -65,7 +67,9 @@ impl DeltaAggregator {
     /// Aggregates a stream of [`Annotated`]s into a single [`CompletionResponse`].
     pub async fn apply(
         stream: impl Stream<Item = Annotated<NvCreateCompletionResponse>>,
+        parsing_options: ParsingOptions,
     ) -> Result<NvCreateCompletionResponse, String> {
+        tracing::debug!("Tool Call Parser: {:?}", parsing_options.tool_call_parser); // TODO: remove this once completion has tool call support
         let aggregator = stream
             .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
                 let delta = match delta.ok() {
@@ -177,15 +181,17 @@ impl From for dynamo_async_openai::types::Choice {
 impl NvCreateCompletionResponse {
     pub async fn from_sse_stream(
         stream: DataStream<Result<Message, SseCodecError>>,
+        parsing_options: ParsingOptions,
     ) -> Result<NvCreateCompletionResponse, String> {
         let stream = convert_sse_stream::<NvCreateCompletionResponse>(stream);
-        NvCreateCompletionResponse::from_annotated_stream(stream).await
+        NvCreateCompletionResponse::from_annotated_stream(stream, parsing_options).await
     }
 
     pub async fn from_annotated_stream(
         stream: impl Stream<Item = Annotated<NvCreateCompletionResponse>>,
+        parsing_options: ParsingOptions,
     ) -> Result<NvCreateCompletionResponse, String> {
-        DeltaAggregator::apply(stream).await
+        DeltaAggregator::apply(stream, parsing_options).await
     }
 }
 
@@ -241,7 +247,7 @@
        let stream: DataStream<Annotated<NvCreateCompletionResponse>> = Box::pin(stream::empty());
 
        // Call DeltaAggregator::apply
-        let result = DeltaAggregator::apply(stream).await;
+        let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
 
        // Check the result
        assert!(result.is_ok());
@@ -265,7 +271,7 @@
        let stream = Box::pin(stream::iter(vec![annotated_delta]));
 
        // Call DeltaAggregator::apply
-        let result = DeltaAggregator::apply(stream).await;
+        let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
 
        // Check the result
        assert!(result.is_ok());
@@ -305,7 +311,7 @@
        let stream = Box::pin(stream::iter(annotated_deltas));
 
        // Call DeltaAggregator::apply
-        let result = DeltaAggregator::apply(stream).await;
+        let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
 
        // Check the result
        assert!(result.is_ok());
@@ -365,7 +371,7 @@
        let stream = Box::pin(stream::iter(vec![annotated_delta]));
 
        // Call DeltaAggregator::apply
-        let result = DeltaAggregator::apply(stream).await;
+        let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
 
        // Check the result
        assert!(result.is_ok());
diff --git a/lib/llm/tests/aggregators.rs b/lib/llm/tests/aggregators.rs
index 5f16715e43a..c6ad39dfa9c 100644
--- a/lib/llm/tests/aggregators.rs
+++ b/lib/llm/tests/aggregators.rs
@@ -18,6 +18,7 @@ use dynamo_llm::protocols::{
     openai::{
         chat_completions::{aggregator::ChatCompletionAggregator, NvCreateChatCompletionResponse},
         completions::NvCreateCompletionResponse,
+        ParsingOptions,
     },
     ContentProvider, DataStream,
 };
@@ -37,9 +38,12 @@
     // note: we are only taking the first 16 messages to keep the size of the response small
     let stream = create_message_stream(&data).take(16);
 
-    let result = NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream))
-        .await
-        .unwrap();
+    let result = NvCreateChatCompletionResponse::from_sse_stream(
+        Box::pin(stream),
+        ParsingOptions::default(),
+    )
+    .await
+    .unwrap();
 
     // todo: provide a cleaner way to extract the content from choices
     assert_eq!(
@@ -59,9 +63,12 @@
 #[tokio::test]
 async fn test_openai_chat_edge_case_multi_line_data() {
     let stream = create_stream(CHAT_ROOT_PATH, "edge_cases/valid-multi-line-data");
-    let result = NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream))
-        .await
-        .unwrap();
+    let result = NvCreateChatCompletionResponse::from_sse_stream(
+        Box::pin(stream),
+        ParsingOptions::default(),
+    )
+    .await
+    .unwrap();
 
     assert_eq!(
         result
@@ -79,9 +86,12 @@ async fn test_openai_chat_edge_case_multi_line_data() {
 #[tokio::test]
 async fn test_openai_chat_edge_case_comments_per_response() {
     let stream = create_stream(CHAT_ROOT_PATH, "edge_cases/valid-comments_per_response");
-    let result = NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream))
-        .await
-        .unwrap();
+    let result = NvCreateChatCompletionResponse::from_sse_stream(
+        Box::pin(stream),
+        ParsingOptions::default(),
+    )
+    .await
+    .unwrap();
 
     assert_eq!(
         result
@@ -99,7 +109,11 @@
 #[tokio::test]
 async fn test_openai_chat_edge_case_invalid_deserialize_error() {
     let stream = create_stream(CHAT_ROOT_PATH, "edge_cases/invalid-deserialize_error");
-    let result = NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream)).await;
+    let result = NvCreateChatCompletionResponse::from_sse_stream(
+        Box::pin(stream),
+        ParsingOptions::default(),
+    )
+    .await;
 
     assert!(result.is_err());
     // insta::assert_debug_snapshot!(result);
@@ -112,9 +126,10 @@
 #[tokio::test]
 async fn test_openai_cmpl_stream() {
     let stream = create_stream(CMPL_ROOT_PATH, "completion.streaming.1").take(16);
-    let result = NvCreateCompletionResponse::from_sse_stream(Box::pin(stream))
-        .await
-        .unwrap();
+    let result =
+        NvCreateCompletionResponse::from_sse_stream(Box::pin(stream), ParsingOptions::default())
+            .await
+            .unwrap();
 
     // todo: provide a cleaner way to extract the content from choices
     assert_eq!(
diff --git a/lib/parsers/src/tool_calling/tools.rs b/lib/parsers/src/tool_calling/tools.rs
index 96b88a59b4c..7f326b46ad0 100644
--- a/lib/parsers/src/tool_calling/tools.rs
+++ b/lib/parsers/src/tool_calling/tools.rs
@@ -14,6 +14,11 @@ pub fn try_tool_call_parse_aggregate(
     message: &str,
     parser_str: Option<&str>,
 ) -> anyhow::Result<Vec<ToolCallResponse>> {
+    if parser_str.is_none() {
+        tracing::info!("No tool parser provided. Trying parsing with default parser.");
+    } else {
+        tracing::info!("Using tool parser: {:?}", parser_str);
+    }
     let parsed = detect_and_parse_tool_call(message, parser_str)?;
     if parsed.is_empty() {
         return Ok(vec![]);
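End-to-end sketch (illustration, not part of the patch): with a vLLM worker launched with `--dyn-tool-call-parser hermes` (and a model that actually emits Hermes-style tool-call markup), the frontend's non-streaming chat path now folds the stream through `DeltaAggregator::apply` with the model's registered parser name, so tool calls come back structured rather than as raw text. The snippet assumes a Dynamo frontend listening on `localhost:8000` and uses a placeholder model name; both are assumptions, not values taken from this diff.

```python
# Illustration only -- not part of the diff.
# Assumptions: frontend at http://localhost:8000/v1, worker registered with
# tool_call_parser="hermes", and a placeholder model name.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")

resp = client.chat.completions.create(
    model="Qwen/Qwen2.5-7B-Instruct",
    messages=[{"role": "user", "content": "What is the weather in Paris today?"}],
    tools=[{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }],
    stream=False,  # exercises the aggregator path patched above
)

# With a parser configured for the model, the folded response should carry
# structured tool calls instead of raw tool-call markup in `content`.
print(resp.choices[0].message.tool_calls)
```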