From d876f7a2874ad6f3e3d783e00e82e0d79beaa50a Mon Sep 17 00:00:00 2001 From: ayushag Date: Thu, 21 Aug 2025 20:50:08 +0000 Subject: [PATCH 1/9] chore: added reasoning and tool parser arg in vllm and mdc --- .../backends/vllm/src/dynamo/vllm/args.py | 19 ++++++++++++++++++- .../backends/vllm/src/dynamo/vllm/main.py | 2 ++ lib/llm/src/local_model.rs | 3 +++ lib/llm/src/local_model/runtime_config.rs | 4 ++++ lib/llm/src/model_card.rs | 12 ++++++++++++ 5 files changed, 39 insertions(+), 1 deletion(-) diff --git a/components/backends/vllm/src/dynamo/vllm/args.py b/components/backends/vllm/src/dynamo/vllm/args.py index 293275f046..9f9c1cf9b7 100644 --- a/components/backends/vllm/src/dynamo/vllm/args.py +++ b/components/backends/vllm/src/dynamo/vllm/args.py @@ -58,6 +58,10 @@ class Config: # Connector list from CLI connector_list: Optional[list] = None + # tool and reasoning parser info + tool_call_parser: Optional[str] = None + reasoning_parser: Optional[str] = None + def parse_args() -> Config: parser = FlexibleArgumentParser( @@ -102,6 +106,18 @@ def parse_args() -> Config: help="List of connectors to use in order (e.g., --connector nixl lmcache). " "Options: nixl, lmcache, kvbm, null, none. Default: nixl. Order will be preserved in MultiConnector.", ) + parser.add_argument( + "--tool-call-parser", + type=str, + default=None, + help="Tool call parser name for the model. Available options: 'hermes', 'nemotron_deci', 'llama3_json', 'mistral', 'phi4'.", + ) + parser.add_argument( + "--reasoning-parser", + type=str, + default=None, + help="Reasoning parser name for the model.", + ) parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() @@ -151,7 +167,8 @@ def parse_args() -> Config: config.port_range = DynamoPortRange( min=args.dynamo_port_min, max=args.dynamo_port_max ) - + config.tool_call_parser = args.tool_call_parser + config.reasoning_parser = args.reasoning_parser # Check for conflicting flags has_kv_transfer_config = ( hasattr(engine_args, "kv_transfer_config") diff --git a/components/backends/vllm/src/dynamo/vllm/main.py b/components/backends/vllm/src/dynamo/vllm/main.py index 7e0486c915..5be1830c99 100644 --- a/components/backends/vllm/src/dynamo/vllm/main.py +++ b/components/backends/vllm/src/dynamo/vllm/main.py @@ -234,6 +234,8 @@ async def init(runtime: DistributedRuntime, config: Config): runtime_config.total_kv_blocks = runtime_values["num_gpu_blocks"] runtime_config.max_num_seqs = runtime_values["max_num_seqs"] runtime_config.max_num_batched_tokens = runtime_values["max_num_batched_tokens"] + runtime_config.tool_call_parser = config.tool_call_parser + runtime_config.reasoning_parser = config.reasoning_parser await register_llm( ModelType.Backend, diff --git a/lib/llm/src/local_model.rs b/lib/llm/src/local_model.rs index b9c1f8085c..4d8c80e840 100644 --- a/lib/llm/src/local_model.rs +++ b/lib/llm/src/local_model.rs @@ -202,6 +202,9 @@ impl LocalModelBuilder { ); card.migration_limit = self.migration_limit; card.user_data = self.user_data.take(); + card.tool_call_parser = self.runtime_config.tool_call_parser.take(); + card.reasoning_parser = self.runtime_config.reasoning_parser.take(); + return Ok(LocalModel { card, full_path: PathBuf::new(), diff --git a/lib/llm/src/local_model/runtime_config.rs b/lib/llm/src/local_model/runtime_config.rs index 4421ff4022..8c5a6a434f 100644 --- a/lib/llm/src/local_model/runtime_config.rs +++ b/lib/llm/src/local_model/runtime_config.rs @@ -13,6 +13,10 @@ pub struct ModelRuntimeConfig { pub max_num_batched_tokens: Option, + pub 
tool_call_parser: Option<String>, + + pub reasoning_parser: Option<String>, + /// Mapping of engine-specific runtime configs #[serde(default, skip_serializing_if = "HashMap::is_empty")] pub runtime_data: HashMap<String, serde_json::Value>, diff --git a/lib/llm/src/model_card.rs b/lib/llm/src/model_card.rs index e4445c5512..a6badca8ba 100644 --- a/lib/llm/src/model_card.rs +++ b/lib/llm/src/model_card.rs @@ -137,6 +137,14 @@ pub struct ModelDeploymentCard { /// User-defined metadata for custom worker behavior #[serde(default, skip_serializing_if = "Option::is_none")] pub user_data: Option<serde_json::Value>, + + /// Tool call parser name for the model. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tool_call_parser: Option<String>, + + /// Reasoning parser name for the model. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub reasoning_parser: Option<String>, } impl ModelDeploymentCard { @@ -441,6 +449,8 @@ impl ModelDeploymentCard { kv_cache_block_size: 0, migration_limit: 0, user_data: None, + tool_call_parser: None, + reasoning_parser: None, }) } @@ -482,6 +492,8 @@ impl ModelDeploymentCard { kv_cache_block_size: 0, // set later migration_limit: 0, user_data: None, + tool_call_parser: None, + reasoning_parser: None, }) } } From 5a83e4973cb1501457f80e15697bbf1b9b9d8d91 Mon Sep 17 00:00:00 2001 From: ayushag Date: Thu, 21 Aug 2025 20:53:18 +0000 Subject: [PATCH 2/9] fix: lint --- components/backends/vllm/src/dynamo/vllm/args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/components/backends/vllm/src/dynamo/vllm/args.py b/components/backends/vllm/src/dynamo/vllm/args.py index 9f9c1cf9b7..e1ea2bf790 100644 --- a/components/backends/vllm/src/dynamo/vllm/args.py +++ b/components/backends/vllm/src/dynamo/vllm/args.py @@ -58,7 +58,7 @@ class Config: # Connector list from CLI connector_list: Optional[list] = None - # tool and reasoning parser info + # tool and reasoning parser info tool_call_parser: Optional[str] = None reasoning_parser: Optional[str] = None From 4be423606eb16f1a7b2cb198e23c4799d9f376b2 Mon Sep 17 00:00:00 2001 From: ayushag Date: Thu, 21 Aug 2025 22:54:03 +0000 Subject: [PATCH 3/9] chore: add debug statements tmp --- .../backends/vllm/src/dynamo/vllm/args.py | 9 +++++---- .../backends/vllm/src/dynamo/vllm/main.py | 3 +++ lib/bindings/python/rust/lib.rs | 2 ++ lib/bindings/python/rust/llm/local_model.rs | 20 +++++++++++++++++++ lib/llm/src/local_model.rs | 6 ++++-- lib/llm/src/preprocessor.rs | 3 ++- 6 files changed, 36 insertions(+), 7 deletions(-) diff --git a/components/backends/vllm/src/dynamo/vllm/args.py b/components/backends/vllm/src/dynamo/vllm/args.py index e1ea2bf790..a9ccc89f7a 100644 --- a/components/backends/vllm/src/dynamo/vllm/args.py +++ b/components/backends/vllm/src/dynamo/vllm/args.py @@ -106,14 +106,15 @@ def parse_args() -> Config: help="List of connectors to use in order (e.g., --connector nixl lmcache). " "Options: nixl, lmcache, kvbm, null, none. Default: nixl. Order will be preserved in MultiConnector.", ) + # To avoid name conflicts with different backends, adopted prefix "dyn-" for dynamo specific args parser.add_argument( - "--tool-call-parser", + "--dyn-tool-call-parser", type=str, default=None, help="Tool call parser name for the model. 
Available options: 'hermes', 'nemotron_deci', 'llama3_json', 'mistral', 'phi4'.", ) parser.add_argument( - "--reasoning-parser", + "--dyn-reasoning-parser", type=str, default=None, help="Reasoning parser name for the model.", @@ -167,8 +168,8 @@ def parse_args() -> Config: config.port_range = DynamoPortRange( min=args.dynamo_port_min, max=args.dynamo_port_max ) - config.tool_call_parser = args.tool_call_parser - config.reasoning_parser = args.reasoning_parser + config.tool_call_parser = args.dyn_tool_call_parser + config.reasoning_parser = args.dyn_reasoning_parser # Check for conflicting flags has_kv_transfer_config = ( hasattr(engine_args, "kv_transfer_config") diff --git a/components/backends/vllm/src/dynamo/vllm/main.py b/components/backends/vllm/src/dynamo/vllm/main.py index 5be1830c99..753cde93cf 100644 --- a/components/backends/vllm/src/dynamo/vllm/main.py +++ b/components/backends/vllm/src/dynamo/vllm/main.py @@ -237,6 +237,9 @@ async def init(runtime: DistributedRuntime, config: Config): runtime_config.tool_call_parser = config.tool_call_parser runtime_config.reasoning_parser = config.reasoning_parser + print("Registering LLM") + print("Tool Call Parser: ", runtime_config.tool_call_parser) + print("Reasoning Parser: ", runtime_config.reasoning_parser) await register_llm( ModelType.Backend, generate_endpoint, diff --git a/lib/bindings/python/rust/lib.rs b/lib/bindings/python/rust/lib.rs index 5989f076b1..aff6d6c796 100644 --- a/lib/bindings/python/rust/lib.rs +++ b/lib/bindings/python/rust/lib.rs @@ -183,6 +183,8 @@ fn register_llm<'p>( .user_data(user_data_json); // Download from HF, load the ModelDeploymentCard let mut local_model = builder.build().await.map_err(to_pyerr)?; + println!("MDC Card tool call parser: {:?}", local_model.card().tool_call_parser); + println!("MDC Card reasoning parser: {:?}", local_model.card().reasoning_parser); // Advertise ourself on etcd so ingress can find us local_model .attach(&endpoint.inner, model_type_obj) diff --git a/lib/bindings/python/rust/llm/local_model.rs b/lib/bindings/python/rust/llm/local_model.rs index 2fdc1a153b..fc1f365906 100644 --- a/lib/bindings/python/rust/llm/local_model.rs +++ b/lib/bindings/python/rust/llm/local_model.rs @@ -34,6 +34,16 @@ impl ModelRuntimeConfig { self.inner.max_num_batched_tokens = Some(max_num_batched_tokens); } + #[setter] + fn set_tool_call_parser(&mut self, tool_call_parser: Option) { + self.inner.tool_call_parser = tool_call_parser; + } + + #[setter] + fn set_reasoning_parser(&mut self, reasoning_parser: Option) { + self.inner.reasoning_parser = reasoning_parser; + } + fn set_engine_specific(&mut self, key: &str, value: String) -> PyResult<()> { let value: serde_json::Value = serde_json::from_str(&value).map_err(to_pyerr)?; self.inner @@ -57,6 +67,16 @@ impl ModelRuntimeConfig { self.inner.max_num_batched_tokens } + #[getter] + fn tool_call_parser(&self) -> Option { + self.inner.tool_call_parser.clone() + } + + #[getter] + fn reasoning_parser(&self) -> Option { + self.inner.reasoning_parser.clone() + } + #[getter] fn runtime_data(&self, py: Python<'_>) -> PyResult { let dict = PyDict::new(py); diff --git a/lib/llm/src/local_model.rs b/lib/llm/src/local_model.rs index 4d8c80e840..9db17713c0 100644 --- a/lib/llm/src/local_model.rs +++ b/lib/llm/src/local_model.rs @@ -202,8 +202,6 @@ impl LocalModelBuilder { ); card.migration_limit = self.migration_limit; card.user_data = self.user_data.take(); - card.tool_call_parser = self.runtime_config.tool_call_parser.take(); - card.reasoning_parser = 
self.runtime_config.reasoning_parser.take(); return Ok(LocalModel { card, @@ -280,6 +278,8 @@ impl LocalModelBuilder { card.migration_limit = self.migration_limit; card.user_data = self.user_data.take(); + card.tool_call_parser = self.runtime_config.tool_call_parser.take(); + card.reasoning_parser = self.runtime_config.reasoning_parser.take(); Ok(LocalModel { card, @@ -395,6 +395,8 @@ impl LocalModel { let kvstore: Box = Box::new(EtcdStorage::new(etcd_client.clone())); let card_store = Arc::new(KeyValueStoreManager::new(kvstore)); let key = self.card.slug().to_string(); + println!("Inside LocalModel attach"); + println!("Tool Call Parser: {:?}", self.card.tool_call_parser); card_store .publish(model_card::ROOT_PATH, None, &key, &mut self.card) .await?; diff --git a/lib/llm/src/preprocessor.rs b/lib/llm/src/preprocessor.rs index 917fcf0c50..7e42fceafb 100644 --- a/lib/llm/src/preprocessor.rs +++ b/lib/llm/src/preprocessor.rs @@ -101,7 +101,8 @@ impl OpenAIPreprocessor { let mdcsum = mdc.mdcsum(); let formatter = PromptFormatter::from_mdc(mdc.clone()).await?; let PromptFormatter::OAI(formatter) = formatter; - + println!("Inside OpenAIPreprocessor"); + println!("Tool Call Parser: {:?}", mdc.tool_call_parser); let tokenizer = match &mdc.tokenizer { Some(TokenizerKind::HfTokenizerJson(file)) => HuggingFaceTokenizer::from_file(file)?, Some(TokenizerKind::GGUF(tokenizer)) => { From 73f78d7ca6eeb067e1c276bcec785416f9deba7b Mon Sep 17 00:00:00 2001 From: ayushag Date: Fri, 22 Aug 2025 05:52:53 +0000 Subject: [PATCH 4/9] chore: cli for vllm works e2e --- .../backends/vllm/src/dynamo/vllm/main.py | 3 - lib/bindings/python/rust/lib.rs | 2 - lib/llm/src/discovery/model_manager.rs | 11 +++ lib/llm/src/http/service/openai.rs | 83 +++++++++++-------- lib/llm/src/local_model.rs | 5 +- lib/llm/src/model_card.rs | 12 --- lib/llm/src/preprocessor.rs | 2 - .../openai/chat_completions/aggregator.rs | 23 +++-- .../openai/completions/aggregator.rs | 16 ++-- lib/llm/tests/aggregators.rs | 10 +-- 10 files changed, 89 insertions(+), 78 deletions(-) diff --git a/components/backends/vllm/src/dynamo/vllm/main.py b/components/backends/vllm/src/dynamo/vllm/main.py index 753cde93cf..5be1830c99 100644 --- a/components/backends/vllm/src/dynamo/vllm/main.py +++ b/components/backends/vllm/src/dynamo/vllm/main.py @@ -237,9 +237,6 @@ async def init(runtime: DistributedRuntime, config: Config): runtime_config.tool_call_parser = config.tool_call_parser runtime_config.reasoning_parser = config.reasoning_parser - print("Registering LLM") - print("Tool Call Parser: ", runtime_config.tool_call_parser) - print("Reasoning Parser: ", runtime_config.reasoning_parser) await register_llm( ModelType.Backend, generate_endpoint, diff --git a/lib/bindings/python/rust/lib.rs b/lib/bindings/python/rust/lib.rs index aff6d6c796..5989f076b1 100644 --- a/lib/bindings/python/rust/lib.rs +++ b/lib/bindings/python/rust/lib.rs @@ -183,8 +183,6 @@ fn register_llm<'p>( .user_data(user_data_json); // Download from HF, load the ModelDeploymentCard let mut local_model = builder.build().await.map_err(to_pyerr)?; - println!("MDC Card tool call parser: {:?}", local_model.card().tool_call_parser); - println!("MDC Card reasoning parser: {:?}", local_model.card().reasoning_parser); // Advertise ourself on etcd so ingress can find us local_model .attach(&endpoint.inner, model_type_obj) diff --git a/lib/llm/src/discovery/model_manager.rs b/lib/llm/src/discovery/model_manager.rs index b934a75ccc..01d5d57cad 100644 --- a/lib/llm/src/discovery/model_manager.rs +++ 
b/lib/llm/src/discovery/model_manager.rs @@ -246,6 +246,17 @@ impl ModelManager { .insert(model_name.to_string(), new_kv_chooser.clone()); Ok(new_kv_chooser) } + + pub fn get_model_tool_call_parser(&self, model: &str) -> Option { + self.entries + .lock() + .unwrap() + .values() + .find(|entry| entry.name == model) + .and_then(|entry| entry.runtime_config.as_ref()) + .and_then(|config| config.tool_call_parser.clone()) + .map(|parser| parser.to_string()) + } } pub struct ModelEngines { diff --git a/lib/llm/src/http/service/openai.rs b/lib/llm/src/http/service/openai.rs index ee3691f09f..02b04942e5 100644 --- a/lib/llm/src/http/service/openai.rs +++ b/lib/llm/src/http/service/openai.rs @@ -267,6 +267,8 @@ async fn completions( .get_completions_engine(model) .map_err(|_| ErrorMessage::model_not_found())?; + let tool_call_parser = state.manager().get_model_tool_call_parser(model); + let mut inflight_guard = state .metrics_clone() @@ -325,16 +327,17 @@ async fn completions( process_metrics_only(response, &mut response_collector); }); - let response = NvCreateCompletionResponse::from_annotated_stream(stream) - .await - .map_err(|e| { - tracing::error!( - "Failed to fold completions stream for {}: {:?}", - request_id, - e - ); - ErrorMessage::internal_server_error("Failed to fold completions stream") - })?; + let response = + NvCreateCompletionResponse::from_annotated_stream(stream, tool_call_parser.clone()) + .await + .map_err(|e| { + tracing::error!( + "Failed to fold completions stream for {}: {:?}", + request_id, + e + ); + ErrorMessage::internal_server_error("Failed to fold completions stream") + })?; inflight_guard.mark_ok(); Ok(Json(response).into_response()) @@ -494,6 +497,9 @@ async fn chat_completions( .get_chat_completions_engine(model) .map_err(|_| ErrorMessage::model_not_found())?; + let tool_call_parser = state.manager().get_model_tool_call_parser(model); + println!("Tool Call Parser: {:?}", tool_call_parser); + let mut inflight_guard = state .metrics_clone() @@ -553,19 +559,20 @@ async fn chat_completions( process_metrics_only(response, &mut response_collector); }); - let response = NvCreateChatCompletionResponse::from_annotated_stream(stream) - .await - .map_err(|e| { - tracing::error!( - request_id, - "Failed to fold chat completions stream for: {:?}", - e - ); - ErrorMessage::internal_server_error(&format!( - "Failed to fold chat completions stream: {}", - e - )) - })?; + let response = + NvCreateChatCompletionResponse::from_annotated_stream(stream, tool_call_parser.clone()) + .await + .map_err(|e| { + tracing::error!( + request_id, + "Failed to fold chat completions stream for: {:?}", + e + ); + ErrorMessage::internal_server_error(&format!( + "Failed to fold chat completions stream: {}", + e + )) + })?; inflight_guard.mark_ok(); Ok(Json(response).into_response()) @@ -726,6 +733,9 @@ async fn responses( .get_chat_completions_engine(model) .map_err(|_| ErrorMessage::model_not_found())?; + let tool_call_parser = state.manager().get_model_tool_call_parser(model); + println!("Tool Call Parser: {:?}", tool_call_parser); + let mut inflight_guard = state .metrics_clone() @@ -742,19 +752,20 @@ async fn responses( .map_err(|e| ErrorMessage::from_anyhow(e, "Failed to generate completions"))?; // TODO: handle streaming, currently just unary - let response = NvCreateChatCompletionResponse::from_annotated_stream(stream) - .await - .map_err(|e| { - tracing::error!( - request_id, - "Failed to fold chat completions stream for: {:?}", - e - ); - ErrorMessage::internal_server_error(&format!( - 
"Failed to fold chat completions stream: {}", - e - )) - })?; + let response = + NvCreateChatCompletionResponse::from_annotated_stream(stream, tool_call_parser.clone()) + .await + .map_err(|e| { + tracing::error!( + request_id, + "Failed to fold chat completions stream for: {:?}", + e + ); + ErrorMessage::internal_server_error(&format!( + "Failed to fold chat completions stream: {}", + e + )) + })?; // Convert NvCreateChatCompletionResponse --> NvResponse let response: NvResponse = response.try_into().map_err(|e| { diff --git a/lib/llm/src/local_model.rs b/lib/llm/src/local_model.rs index 9db17713c0..24f8f6c536 100644 --- a/lib/llm/src/local_model.rs +++ b/lib/llm/src/local_model.rs @@ -278,8 +278,6 @@ impl LocalModelBuilder { card.migration_limit = self.migration_limit; card.user_data = self.user_data.take(); - card.tool_call_parser = self.runtime_config.tool_call_parser.take(); - card.reasoning_parser = self.runtime_config.reasoning_parser.take(); Ok(LocalModel { card, @@ -395,8 +393,7 @@ impl LocalModel { let kvstore: Box = Box::new(EtcdStorage::new(etcd_client.clone())); let card_store = Arc::new(KeyValueStoreManager::new(kvstore)); let key = self.card.slug().to_string(); - println!("Inside LocalModel attach"); - println!("Tool Call Parser: {:?}", self.card.tool_call_parser); + card_store .publish(model_card::ROOT_PATH, None, &key, &mut self.card) .await?; diff --git a/lib/llm/src/model_card.rs b/lib/llm/src/model_card.rs index a6badca8ba..e4445c5512 100644 --- a/lib/llm/src/model_card.rs +++ b/lib/llm/src/model_card.rs @@ -137,14 +137,6 @@ pub struct ModelDeploymentCard { /// User-defined metadata for custom worker behavior #[serde(default, skip_serializing_if = "Option::is_none")] pub user_data: Option, - - /// Tool call parser name for the model. - #[serde(default, skip_serializing_if = "Option::is_none")] - pub tool_call_parser: Option, - - /// Reasoning parser name for the model. - #[serde(default, skip_serializing_if = "Option::is_none")] - pub reasoning_parser: Option, } impl ModelDeploymentCard { @@ -449,8 +441,6 @@ impl ModelDeploymentCard { kv_cache_block_size: 0, migration_limit: 0, user_data: None, - tool_call_parser: None, - reasoning_parser: None, }) } @@ -492,8 +482,6 @@ impl ModelDeploymentCard { kv_cache_block_size: 0, // set later migration_limit: 0, user_data: None, - tool_call_parser: None, - reasoning_parser: None, }) } } diff --git a/lib/llm/src/preprocessor.rs b/lib/llm/src/preprocessor.rs index 7e42fceafb..f600d08c24 100644 --- a/lib/llm/src/preprocessor.rs +++ b/lib/llm/src/preprocessor.rs @@ -101,8 +101,6 @@ impl OpenAIPreprocessor { let mdcsum = mdc.mdcsum(); let formatter = PromptFormatter::from_mdc(mdc.clone()).await?; let PromptFormatter::OAI(formatter) = formatter; - println!("Inside OpenAIPreprocessor"); - println!("Tool Call Parser: {:?}", mdc.tool_call_parser); let tokenizer = match &mdc.tokenizer { Some(TokenizerKind::HfTokenizerJson(file)) => HuggingFaceTokenizer::from_file(file)?, Some(TokenizerKind::GGUF(tokenizer)) => { diff --git a/lib/llm/src/protocols/openai/chat_completions/aggregator.rs b/lib/llm/src/protocols/openai/chat_completions/aggregator.rs index ed15b7d69e..1ca33b4d7d 100644 --- a/lib/llm/src/protocols/openai/chat_completions/aggregator.rs +++ b/lib/llm/src/protocols/openai/chat_completions/aggregator.rs @@ -99,6 +99,7 @@ impl DeltaAggregator { /// * `Err(String)` if an error occurs during processing. 
pub async fn apply( stream: impl Stream>, + tool_call_parser: Option, ) -> Result { let aggregator = stream .fold(DeltaAggregator::new(), |mut aggregator, delta| async move { @@ -175,7 +176,9 @@ impl DeltaAggregator { // After aggregation, inspect each choice's text for tool call syntax for choice in aggregator.choices.values_mut() { if choice.tool_calls.is_none() { - if let Ok(tool_calls) = try_tool_call_parse_aggregate(&choice.text, None) { + if let Ok(tool_calls) = + try_tool_call_parse_aggregate(&choice.text, tool_call_parser.as_deref()) + { if tool_calls.is_empty() { continue; } @@ -262,6 +265,7 @@ pub trait ChatCompletionAggregator { /// * `Err(String)` if an error occurs. async fn from_annotated_stream( stream: impl Stream>, + tool_call_parser: Option, ) -> Result; /// Converts an SSE stream into a [`NvCreateChatCompletionResponse`]. @@ -274,21 +278,24 @@ pub trait ChatCompletionAggregator { /// * `Err(String)` if an error occurs. async fn from_sse_stream( stream: DataStream>, + tool_call_parser: Option, ) -> Result; } impl ChatCompletionAggregator for dynamo_async_openai::types::CreateChatCompletionResponse { async fn from_annotated_stream( stream: impl Stream>, + tool_call_parser: Option, ) -> Result { - DeltaAggregator::apply(stream).await + DeltaAggregator::apply(stream, tool_call_parser).await } async fn from_sse_stream( stream: DataStream>, + tool_call_parser: Option, ) -> Result { let stream = convert_sse_stream::(stream); - NvCreateChatCompletionResponse::from_annotated_stream(stream).await + NvCreateChatCompletionResponse::from_annotated_stream(stream, tool_call_parser).await } } @@ -347,7 +354,7 @@ mod tests { Box::pin(stream::empty()); // Call DeltaAggregator::apply - let result = DeltaAggregator::apply(stream).await; + let result = DeltaAggregator::apply(stream, None).await; // Check the result assert!(result.is_ok()); @@ -377,7 +384,7 @@ mod tests { let stream = Box::pin(stream::iter(vec![annotated_delta])); // Call DeltaAggregator::apply - let result = DeltaAggregator::apply(stream).await; + let result = DeltaAggregator::apply(stream, None).await; // Check the result assert!(result.is_ok()); @@ -421,7 +428,7 @@ mod tests { let stream = Box::pin(stream::iter(annotated_deltas)); // Call DeltaAggregator::apply - let result = DeltaAggregator::apply(stream).await; + let result = DeltaAggregator::apply(stream, None).await; // Check the result assert!(result.is_ok()); @@ -492,7 +499,7 @@ mod tests { let stream = Box::pin(stream::iter(vec![annotated_delta])); // Call DeltaAggregator::apply - let result = DeltaAggregator::apply(stream).await; + let result = DeltaAggregator::apply(stream, None).await; // Check the result assert!(result.is_ok()); @@ -550,7 +557,7 @@ mod tests { let stream = Box::pin(stream::iter(vec![annotated_delta])); // Call DeltaAggregator::apply - let result = DeltaAggregator::apply(stream).await; + let result = DeltaAggregator::apply(stream, None).await; // Check the result assert!(result.is_ok()); diff --git a/lib/llm/src/protocols/openai/completions/aggregator.rs b/lib/llm/src/protocols/openai/completions/aggregator.rs index 7bb0c44aea..05570759f2 100644 --- a/lib/llm/src/protocols/openai/completions/aggregator.rs +++ b/lib/llm/src/protocols/openai/completions/aggregator.rs @@ -65,7 +65,9 @@ impl DeltaAggregator { /// Aggregates a stream of [`Annotated`]s into a single [`CompletionResponse`]. 
pub async fn apply( stream: impl Stream>, + tool_call_parser: Option, ) -> Result { + println!("Tool Call Parser: {:?}", tool_call_parser); // TODO: remove this once completion has tool call support let aggregator = stream .fold(DeltaAggregator::new(), |mut aggregator, delta| async move { let delta = match delta.ok() { @@ -177,15 +179,17 @@ impl From for dynamo_async_openai::types::Choice { impl NvCreateCompletionResponse { pub async fn from_sse_stream( stream: DataStream>, + tool_call_parser: Option, ) -> Result { let stream = convert_sse_stream::(stream); - NvCreateCompletionResponse::from_annotated_stream(stream).await + NvCreateCompletionResponse::from_annotated_stream(stream, tool_call_parser).await } pub async fn from_annotated_stream( stream: impl Stream>, + tool_call_parser: Option, ) -> Result { - DeltaAggregator::apply(stream).await + DeltaAggregator::apply(stream, tool_call_parser).await } } @@ -241,7 +245,7 @@ mod tests { let stream: DataStream> = Box::pin(stream::empty()); // Call DeltaAggregator::apply - let result = DeltaAggregator::apply(stream).await; + let result = DeltaAggregator::apply(stream, None).await; // Check the result assert!(result.is_ok()); @@ -265,7 +269,7 @@ mod tests { let stream = Box::pin(stream::iter(vec![annotated_delta])); // Call DeltaAggregator::apply - let result = DeltaAggregator::apply(stream).await; + let result = DeltaAggregator::apply(stream, None).await; // Check the result assert!(result.is_ok()); @@ -305,7 +309,7 @@ mod tests { let stream = Box::pin(stream::iter(annotated_deltas)); // Call DeltaAggregator::apply - let result = DeltaAggregator::apply(stream).await; + let result = DeltaAggregator::apply(stream, None).await; // Check the result assert!(result.is_ok()); @@ -365,7 +369,7 @@ mod tests { let stream = Box::pin(stream::iter(vec![annotated_delta])); // Call DeltaAggregator::apply - let result = DeltaAggregator::apply(stream).await; + let result = DeltaAggregator::apply(stream, None).await; // Check the result assert!(result.is_ok()); diff --git a/lib/llm/tests/aggregators.rs b/lib/llm/tests/aggregators.rs index 5f16715e43..f6f9ee7818 100644 --- a/lib/llm/tests/aggregators.rs +++ b/lib/llm/tests/aggregators.rs @@ -37,7 +37,7 @@ async fn test_openai_chat_stream() { // note: we are only taking the first 16 messages to keep the size of the response small let stream = create_message_stream(&data).take(16); - let result = NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream)) + let result = NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream), None) .await .unwrap(); @@ -59,7 +59,7 @@ async fn test_openai_chat_stream() { #[tokio::test] async fn test_openai_chat_edge_case_multi_line_data() { let stream = create_stream(CHAT_ROOT_PATH, "edge_cases/valid-multi-line-data"); - let result = NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream)) + let result = NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream), None) .await .unwrap(); @@ -79,7 +79,7 @@ async fn test_openai_chat_edge_case_multi_line_data() { #[tokio::test] async fn test_openai_chat_edge_case_comments_per_response() { let stream = create_stream(CHAT_ROOT_PATH, "edge_cases/valid-comments_per_response"); - let result = NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream)) + let result = NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream), None) .await .unwrap(); @@ -99,7 +99,7 @@ async fn test_openai_chat_edge_case_comments_per_response() { #[tokio::test] async fn 
test_openai_chat_edge_case_invalid_deserialize_error() { let stream = create_stream(CHAT_ROOT_PATH, "edge_cases/invalid-deserialize_error"); - let result = NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream)).await; + let result = NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream), None).await; assert!(result.is_err()); // insta::assert_debug_snapshot!(result); @@ -112,7 +112,7 @@ async fn test_openai_chat_edge_case_invalid_deserialize_error() { #[tokio::test] async fn test_openai_cmpl_stream() { let stream = create_stream(CMPL_ROOT_PATH, "completion.streaming.1").take(16); - let result = NvCreateCompletionResponse::from_sse_stream(Box::pin(stream)) + let result = NvCreateCompletionResponse::from_sse_stream(Box::pin(stream), None) .await .unwrap(); From 211519c8aab9f8ad61ef09234e717c3037ac1eca Mon Sep 17 00:00:00 2001 From: ayushag Date: Fri, 22 Aug 2025 19:13:04 +0000 Subject: [PATCH 5/9] chore: added stream_args struct --- lib/llm/src/http/service/openai.rs | 69 ++++++++++--------- lib/llm/src/protocols/openai.rs | 16 +++++ .../openai/chat_completions/aggregator.rs | 35 +++++----- .../openai/completions/aggregator.rs | 24 ++++--- lib/llm/tests/aggregators.rs | 33 +++++---- 5 files changed, 106 insertions(+), 71 deletions(-) diff --git a/lib/llm/src/http/service/openai.rs b/lib/llm/src/http/service/openai.rs index 02b04942e5..c72b65536a 100644 --- a/lib/llm/src/http/service/openai.rs +++ b/lib/llm/src/http/service/openai.rs @@ -37,6 +37,7 @@ use crate::protocols::openai::{ completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse}, responses::{NvCreateResponse, NvResponse}, + StreamArgs, }; use crate::request_template::RequestTemplate; use crate::types::Annotated; @@ -194,6 +195,13 @@ fn get_or_create_request_id(primary: Option<&str>, headers: &HeaderMap) -> Strin uuid.to_string() } +fn get_stream_args(state: &Arc, model: &str) -> StreamArgs { + let tool_call_parser = state.manager().get_model_tool_call_parser(model); + let reasoning_parser = None; // TODO: Implement reasoning parser + + StreamArgs::new(tool_call_parser, reasoning_parser) +} + /// OpenAI Completions Request Handler /// /// This method will handle the incoming request for the `/v1/completions endpoint`. 
The endpoint is a "source" @@ -267,7 +275,7 @@ async fn completions( .get_completions_engine(model) .map_err(|_| ErrorMessage::model_not_found())?; - let tool_call_parser = state.manager().get_model_tool_call_parser(model); + let extra_stream_args = get_stream_args(&state, model); let mut inflight_guard = state @@ -327,17 +335,16 @@ async fn completions( process_metrics_only(response, &mut response_collector); }); - let response = - NvCreateCompletionResponse::from_annotated_stream(stream, tool_call_parser.clone()) - .await - .map_err(|e| { - tracing::error!( - "Failed to fold completions stream for {}: {:?}", - request_id, - e - ); - ErrorMessage::internal_server_error("Failed to fold completions stream") - })?; + let response = NvCreateCompletionResponse::from_annotated_stream(stream, extra_stream_args) + .await + .map_err(|e| { + tracing::error!( + "Failed to fold completions stream for {}: {:?}", + request_id, + e + ); + ErrorMessage::internal_server_error("Failed to fold completions stream") + })?; inflight_guard.mark_ok(); Ok(Json(response).into_response()) @@ -497,8 +504,7 @@ async fn chat_completions( .get_chat_completions_engine(model) .map_err(|_| ErrorMessage::model_not_found())?; - let tool_call_parser = state.manager().get_model_tool_call_parser(model); - println!("Tool Call Parser: {:?}", tool_call_parser); + let extra_stream_args = get_stream_args(&state, model); let mut inflight_guard = state @@ -559,20 +565,22 @@ async fn chat_completions( process_metrics_only(response, &mut response_collector); }); - let response = - NvCreateChatCompletionResponse::from_annotated_stream(stream, tool_call_parser.clone()) - .await - .map_err(|e| { - tracing::error!( - request_id, - "Failed to fold chat completions stream for: {:?}", - e - ); - ErrorMessage::internal_server_error(&format!( - "Failed to fold chat completions stream: {}", - e - )) - })?; + let response = NvCreateChatCompletionResponse::from_annotated_stream( + stream, + extra_stream_args.clone(), + ) + .await + .map_err(|e| { + tracing::error!( + request_id, + "Failed to fold chat completions stream for: {:?}", + e + ); + ErrorMessage::internal_server_error(&format!( + "Failed to fold chat completions stream: {}", + e + )) + })?; inflight_guard.mark_ok(); Ok(Json(response).into_response()) @@ -733,8 +741,7 @@ async fn responses( .get_chat_completions_engine(model) .map_err(|_| ErrorMessage::model_not_found())?; - let tool_call_parser = state.manager().get_model_tool_call_parser(model); - println!("Tool Call Parser: {:?}", tool_call_parser); + let extra_stream_args = get_stream_args(&state, model); let mut inflight_guard = state @@ -753,7 +760,7 @@ async fn responses( // TODO: handle streaming, currently just unary let response = - NvCreateChatCompletionResponse::from_annotated_stream(stream, tool_call_parser.clone()) + NvCreateChatCompletionResponse::from_annotated_stream(stream, extra_stream_args.clone()) .await .map_err(|e| { tracing::error!( diff --git a/lib/llm/src/protocols/openai.rs b/lib/llm/src/protocols/openai.rs index 7c3166dc4c..2f0a1e9634 100644 --- a/lib/llm/src/protocols/openai.rs +++ b/lib/llm/src/protocols/openai.rs @@ -193,3 +193,19 @@ pub trait DeltaGeneratorExt: /// Gets the current prompt token count (Input Sequence Length). 
fn get_isl(&self) -> Option; } + +#[derive(Clone, Debug, Serialize, Deserialize, Default)] +pub struct StreamArgs { + pub tool_call_parser: Option, + + pub reasoning_parser: Option, +} + +impl StreamArgs { + pub fn new(tool_call_parser: Option, reasoning_parser: Option) -> Self { + Self { + tool_call_parser, + reasoning_parser, + } + } +} diff --git a/lib/llm/src/protocols/openai/chat_completions/aggregator.rs b/lib/llm/src/protocols/openai/chat_completions/aggregator.rs index 1ca33b4d7d..96c59e992c 100644 --- a/lib/llm/src/protocols/openai/chat_completions/aggregator.rs +++ b/lib/llm/src/protocols/openai/chat_completions/aggregator.rs @@ -19,7 +19,9 @@ use std::collections::HashMap; use super::{NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse}; use crate::protocols::{ codec::{Message, SseCodecError}, - convert_sse_stream, Annotated, + convert_sse_stream, + openai::StreamArgs, + Annotated, }; use dynamo_parsers::tool_calling::try_tool_call_parse_aggregate; @@ -99,7 +101,7 @@ impl DeltaAggregator { /// * `Err(String)` if an error occurs during processing. pub async fn apply( stream: impl Stream>, - tool_call_parser: Option, + extra_stream_args: StreamArgs, ) -> Result { let aggregator = stream .fold(DeltaAggregator::new(), |mut aggregator, delta| async move { @@ -176,9 +178,10 @@ impl DeltaAggregator { // After aggregation, inspect each choice's text for tool call syntax for choice in aggregator.choices.values_mut() { if choice.tool_calls.is_none() { - if let Ok(tool_calls) = - try_tool_call_parse_aggregate(&choice.text, tool_call_parser.as_deref()) - { + if let Ok(tool_calls) = try_tool_call_parse_aggregate( + &choice.text, + extra_stream_args.tool_call_parser.as_deref(), + ) { if tool_calls.is_empty() { continue; } @@ -265,7 +268,7 @@ pub trait ChatCompletionAggregator { /// * `Err(String)` if an error occurs. async fn from_annotated_stream( stream: impl Stream>, - tool_call_parser: Option, + extra_stream_args: StreamArgs, ) -> Result; /// Converts an SSE stream into a [`NvCreateChatCompletionResponse`]. @@ -278,24 +281,24 @@ pub trait ChatCompletionAggregator { /// * `Err(String)` if an error occurs. 
async fn from_sse_stream( stream: DataStream>, - tool_call_parser: Option, + extra_stream_args: StreamArgs, ) -> Result; } impl ChatCompletionAggregator for dynamo_async_openai::types::CreateChatCompletionResponse { async fn from_annotated_stream( stream: impl Stream>, - tool_call_parser: Option, + extra_stream_args: StreamArgs, ) -> Result { - DeltaAggregator::apply(stream, tool_call_parser).await + DeltaAggregator::apply(stream, extra_stream_args).await } async fn from_sse_stream( stream: DataStream>, - tool_call_parser: Option, + extra_stream_args: StreamArgs, ) -> Result { let stream = convert_sse_stream::(stream); - NvCreateChatCompletionResponse::from_annotated_stream(stream, tool_call_parser).await + NvCreateChatCompletionResponse::from_annotated_stream(stream, extra_stream_args).await } } @@ -354,7 +357,7 @@ mod tests { Box::pin(stream::empty()); // Call DeltaAggregator::apply - let result = DeltaAggregator::apply(stream, None).await; + let result = DeltaAggregator::apply(stream, StreamArgs::default()).await; // Check the result assert!(result.is_ok()); @@ -384,7 +387,7 @@ mod tests { let stream = Box::pin(stream::iter(vec![annotated_delta])); // Call DeltaAggregator::apply - let result = DeltaAggregator::apply(stream, None).await; + let result = DeltaAggregator::apply(stream, StreamArgs::default()).await; // Check the result assert!(result.is_ok()); @@ -428,7 +431,7 @@ mod tests { let stream = Box::pin(stream::iter(annotated_deltas)); // Call DeltaAggregator::apply - let result = DeltaAggregator::apply(stream, None).await; + let result = DeltaAggregator::apply(stream, StreamArgs::default()).await; // Check the result assert!(result.is_ok()); @@ -499,7 +502,7 @@ mod tests { let stream = Box::pin(stream::iter(vec![annotated_delta])); // Call DeltaAggregator::apply - let result = DeltaAggregator::apply(stream, None).await; + let result = DeltaAggregator::apply(stream, StreamArgs::default()).await; // Check the result assert!(result.is_ok()); @@ -557,7 +560,7 @@ mod tests { let stream = Box::pin(stream::iter(vec![annotated_delta])); // Call DeltaAggregator::apply - let result = DeltaAggregator::apply(stream, None).await; + let result = DeltaAggregator::apply(stream, StreamArgs::default()).await; // Check the result assert!(result.is_ok()); diff --git a/lib/llm/src/protocols/openai/completions/aggregator.rs b/lib/llm/src/protocols/openai/completions/aggregator.rs index 05570759f2..b0dd7283fa 100644 --- a/lib/llm/src/protocols/openai/completions/aggregator.rs +++ b/lib/llm/src/protocols/openai/completions/aggregator.rs @@ -22,7 +22,9 @@ use super::NvCreateCompletionResponse; use crate::protocols::{ codec::{Message, SseCodecError}, common::FinishReason, - convert_sse_stream, Annotated, DataStream, + convert_sse_stream, + openai::StreamArgs, + Annotated, DataStream, }; /// Aggregates a stream of [`CompletionResponse`]s into a single [`CompletionResponse`]. @@ -65,9 +67,9 @@ impl DeltaAggregator { /// Aggregates a stream of [`Annotated`]s into a single [`CompletionResponse`]. 
pub async fn apply( stream: impl Stream>, - tool_call_parser: Option, + extra_stream_args: StreamArgs, ) -> Result { - println!("Tool Call Parser: {:?}", tool_call_parser); // TODO: remove this once completion has tool call support + println!("Tool Call Parser: {:?}", extra_stream_args.tool_call_parser); // TODO: remove this once completion has tool call support let aggregator = stream .fold(DeltaAggregator::new(), |mut aggregator, delta| async move { let delta = match delta.ok() { @@ -179,17 +181,17 @@ impl From for dynamo_async_openai::types::Choice { impl NvCreateCompletionResponse { pub async fn from_sse_stream( stream: DataStream>, - tool_call_parser: Option, + extra_stream_args: StreamArgs, ) -> Result { let stream = convert_sse_stream::(stream); - NvCreateCompletionResponse::from_annotated_stream(stream, tool_call_parser).await + NvCreateCompletionResponse::from_annotated_stream(stream, extra_stream_args).await } pub async fn from_annotated_stream( stream: impl Stream>, - tool_call_parser: Option, + extra_stream_args: StreamArgs, ) -> Result { - DeltaAggregator::apply(stream, tool_call_parser).await + DeltaAggregator::apply(stream, extra_stream_args).await } } @@ -245,7 +247,7 @@ mod tests { let stream: DataStream> = Box::pin(stream::empty()); // Call DeltaAggregator::apply - let result = DeltaAggregator::apply(stream, None).await; + let result = DeltaAggregator::apply(stream, StreamArgs::default()).await; // Check the result assert!(result.is_ok()); @@ -269,7 +271,7 @@ mod tests { let stream = Box::pin(stream::iter(vec![annotated_delta])); // Call DeltaAggregator::apply - let result = DeltaAggregator::apply(stream, None).await; + let result = DeltaAggregator::apply(stream, StreamArgs::default()).await; // Check the result assert!(result.is_ok()); @@ -309,7 +311,7 @@ mod tests { let stream = Box::pin(stream::iter(annotated_deltas)); // Call DeltaAggregator::apply - let result = DeltaAggregator::apply(stream, None).await; + let result = DeltaAggregator::apply(stream, StreamArgs::default()).await; // Check the result assert!(result.is_ok()); @@ -369,7 +371,7 @@ mod tests { let stream = Box::pin(stream::iter(vec![annotated_delta])); // Call DeltaAggregator::apply - let result = DeltaAggregator::apply(stream, None).await; + let result = DeltaAggregator::apply(stream, StreamArgs::default()).await; // Check the result assert!(result.is_ok()); diff --git a/lib/llm/tests/aggregators.rs b/lib/llm/tests/aggregators.rs index f6f9ee7818..794e6c996e 100644 --- a/lib/llm/tests/aggregators.rs +++ b/lib/llm/tests/aggregators.rs @@ -18,6 +18,7 @@ use dynamo_llm::protocols::{ openai::{ chat_completions::{aggregator::ChatCompletionAggregator, NvCreateChatCompletionResponse}, completions::NvCreateCompletionResponse, + StreamArgs, }, ContentProvider, DataStream, }; @@ -37,9 +38,10 @@ async fn test_openai_chat_stream() { // note: we are only taking the first 16 messages to keep the size of the response small let stream = create_message_stream(&data).take(16); - let result = NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream), None) - .await - .unwrap(); + let result = + NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream), StreamArgs::default()) + .await + .unwrap(); // todo: provide a cleaner way to extract the content from choices assert_eq!( @@ -59,9 +61,10 @@ async fn test_openai_chat_stream() { #[tokio::test] async fn test_openai_chat_edge_case_multi_line_data() { let stream = create_stream(CHAT_ROOT_PATH, "edge_cases/valid-multi-line-data"); - let result = 
NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream), None) - .await - .unwrap(); + let result = + NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream), StreamArgs::default()) + .await + .unwrap(); assert_eq!( result @@ -79,9 +82,10 @@ async fn test_openai_chat_edge_case_multi_line_data() { #[tokio::test] async fn test_openai_chat_edge_case_comments_per_response() { let stream = create_stream(CHAT_ROOT_PATH, "edge_cases/valid-comments_per_response"); - let result = NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream), None) - .await - .unwrap(); + let result = + NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream), StreamArgs::default()) + .await + .unwrap(); assert_eq!( result @@ -99,7 +103,9 @@ async fn test_openai_chat_edge_case_comments_per_response() { #[tokio::test] async fn test_openai_chat_edge_case_invalid_deserialize_error() { let stream = create_stream(CHAT_ROOT_PATH, "edge_cases/invalid-deserialize_error"); - let result = NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream), None).await; + let result = + NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream), StreamArgs::default()) + .await; assert!(result.is_err()); // insta::assert_debug_snapshot!(result); @@ -112,9 +118,10 @@ async fn test_openai_chat_edge_case_invalid_deserialize_error() { #[tokio::test] async fn test_openai_cmpl_stream() { let stream = create_stream(CMPL_ROOT_PATH, "completion.streaming.1").take(16); - let result = NvCreateCompletionResponse::from_sse_stream(Box::pin(stream), None) - .await - .unwrap(); + let result = + NvCreateCompletionResponse::from_sse_stream(Box::pin(stream), StreamArgs::default()) + .await + .unwrap(); // todo: provide a cleaner way to extract the content from choices assert_eq!( From 801b3fc3b1cf3121344b58853a98f9fef31e4ee4 Mon Sep 17 00:00:00 2001 From: ayushag Date: Fri, 22 Aug 2025 19:21:56 +0000 Subject: [PATCH 6/9] chore: add correct print stats --- lib/llm/src/protocols/openai/completions/aggregator.rs | 2 +- lib/parsers/src/tool_calling/tools.rs | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/lib/llm/src/protocols/openai/completions/aggregator.rs b/lib/llm/src/protocols/openai/completions/aggregator.rs index b0dd7283fa..dd3281398d 100644 --- a/lib/llm/src/protocols/openai/completions/aggregator.rs +++ b/lib/llm/src/protocols/openai/completions/aggregator.rs @@ -69,7 +69,7 @@ impl DeltaAggregator { stream: impl Stream>, extra_stream_args: StreamArgs, ) -> Result { - println!("Tool Call Parser: {:?}", extra_stream_args.tool_call_parser); // TODO: remove this once completion has tool call support + tracing::debug!("Tool Call Parser: {:?}", extra_stream_args.tool_call_parser); // TODO: remove this once completion has tool call support let aggregator = stream .fold(DeltaAggregator::new(), |mut aggregator, delta| async move { let delta = match delta.ok() { diff --git a/lib/parsers/src/tool_calling/tools.rs b/lib/parsers/src/tool_calling/tools.rs index 96b88a59b4..54176b81a2 100644 --- a/lib/parsers/src/tool_calling/tools.rs +++ b/lib/parsers/src/tool_calling/tools.rs @@ -14,6 +14,12 @@ pub fn try_tool_call_parse_aggregate( message: &str, parser_str: Option<&str>, ) -> anyhow::Result> { + if parser_str.is_none() { + tracing::info!("No tool parser provided. 
Trying parsing with default parser."); + } + else { + tracing::info!("Using tool parser: {:?}", parser_str); + } let parsed = detect_and_parse_tool_call(message, parser_str)?; if parsed.is_empty() { return Ok(vec![]); From ec5103508450568b96d5bbf9cc579f3e526e18be Mon Sep 17 00:00:00 2001 From: ayushag Date: Fri, 22 Aug 2025 19:29:30 +0000 Subject: [PATCH 7/9] fix: fmt --- lib/parsers/src/tool_calling/tools.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lib/parsers/src/tool_calling/tools.rs b/lib/parsers/src/tool_calling/tools.rs index 54176b81a2..7f326b46ad 100644 --- a/lib/parsers/src/tool_calling/tools.rs +++ b/lib/parsers/src/tool_calling/tools.rs @@ -16,8 +16,7 @@ pub fn try_tool_call_parse_aggregate( ) -> anyhow::Result> { if parser_str.is_none() { tracing::info!("No tool parser provided. Trying parsing with default parser."); - } - else { + } else { tracing::info!("Using tool parser: {:?}", parser_str); } let parsed = detect_and_parse_tool_call(message, parser_str)?; From 114ba0917327842a87eb06b968085af1e36366a6 Mon Sep 17 00:00:00 2001 From: ayushag Date: Fri, 22 Aug 2025 19:51:13 +0000 Subject: [PATCH 8/9] fix: replaced stream args with parsing opt --- lib/llm/src/http/service/openai.rs | 46 +++++++++---------- lib/llm/src/protocols/openai.rs | 4 +- .../openai/chat_completions/aggregator.rs | 28 +++++------ .../openai/completions/aggregator.rs | 22 ++++----- lib/llm/tests/aggregators.rs | 42 ++++++++++------- 5 files changed, 74 insertions(+), 68 deletions(-) diff --git a/lib/llm/src/http/service/openai.rs b/lib/llm/src/http/service/openai.rs index c72b65536a..8a2c9c53a0 100644 --- a/lib/llm/src/http/service/openai.rs +++ b/lib/llm/src/http/service/openai.rs @@ -37,7 +37,7 @@ use crate::protocols::openai::{ completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse}, responses::{NvCreateResponse, NvResponse}, - StreamArgs, + ParsingOptions, }; use crate::request_template::RequestTemplate; use crate::types::Annotated; @@ -195,11 +195,11 @@ fn get_or_create_request_id(primary: Option<&str>, headers: &HeaderMap) -> Strin uuid.to_string() } -fn get_stream_args(state: &Arc, model: &str) -> StreamArgs { +fn get_parsing_options(state: &Arc, model: &str) -> ParsingOptions { let tool_call_parser = state.manager().get_model_tool_call_parser(model); let reasoning_parser = None; // TODO: Implement reasoning parser - StreamArgs::new(tool_call_parser, reasoning_parser) + ParsingOptions::new(tool_call_parser, reasoning_parser) } /// OpenAI Completions Request Handler @@ -275,7 +275,7 @@ async fn completions( .get_completions_engine(model) .map_err(|_| ErrorMessage::model_not_found())?; - let extra_stream_args = get_stream_args(&state, model); + let parsing_options = get_parsing_options(&state, model); let mut inflight_guard = state @@ -335,7 +335,7 @@ async fn completions( process_metrics_only(response, &mut response_collector); }); - let response = NvCreateCompletionResponse::from_annotated_stream(stream, extra_stream_args) + let response = NvCreateCompletionResponse::from_annotated_stream(stream, parsing_options) .await .map_err(|e| { tracing::error!( @@ -504,7 +504,7 @@ async fn chat_completions( .get_chat_completions_engine(model) .map_err(|_| ErrorMessage::model_not_found())?; - let extra_stream_args = get_stream_args(&state, model); + let parsing_options = get_parsing_options(&state, model); let mut inflight_guard = state @@ -565,22 +565,20 @@ async fn chat_completions( 
process_metrics_only(response, &mut response_collector); }); - let response = NvCreateChatCompletionResponse::from_annotated_stream( - stream, - extra_stream_args.clone(), - ) - .await - .map_err(|e| { - tracing::error!( - request_id, - "Failed to fold chat completions stream for: {:?}", - e - ); - ErrorMessage::internal_server_error(&format!( - "Failed to fold chat completions stream: {}", - e - )) - })?; + let response = + NvCreateChatCompletionResponse::from_annotated_stream(stream, parsing_options.clone()) + .await + .map_err(|e| { + tracing::error!( + request_id, + "Failed to fold chat completions stream for: {:?}", + e + ); + ErrorMessage::internal_server_error(&format!( + "Failed to fold chat completions stream: {}", + e + )) + })?; inflight_guard.mark_ok(); Ok(Json(response).into_response()) @@ -741,7 +739,7 @@ async fn responses( .get_chat_completions_engine(model) .map_err(|_| ErrorMessage::model_not_found())?; - let extra_stream_args = get_stream_args(&state, model); + let parsing_options = get_parsing_options(&state, model); let mut inflight_guard = state @@ -760,7 +758,7 @@ async fn responses( // TODO: handle streaming, currently just unary let response = - NvCreateChatCompletionResponse::from_annotated_stream(stream, extra_stream_args.clone()) + NvCreateChatCompletionResponse::from_annotated_stream(stream, parsing_options.clone()) .await .map_err(|e| { tracing::error!( diff --git a/lib/llm/src/protocols/openai.rs b/lib/llm/src/protocols/openai.rs index 2f0a1e9634..668d8e6933 100644 --- a/lib/llm/src/protocols/openai.rs +++ b/lib/llm/src/protocols/openai.rs @@ -195,13 +195,13 @@ pub trait DeltaGeneratorExt: } #[derive(Clone, Debug, Serialize, Deserialize, Default)] -pub struct StreamArgs { +pub struct ParsingOptions { pub tool_call_parser: Option, pub reasoning_parser: Option, } -impl StreamArgs { +impl ParsingOptions { pub fn new(tool_call_parser: Option, reasoning_parser: Option) -> Self { Self { tool_call_parser, diff --git a/lib/llm/src/protocols/openai/chat_completions/aggregator.rs b/lib/llm/src/protocols/openai/chat_completions/aggregator.rs index 96c59e992c..a99b3e1dda 100644 --- a/lib/llm/src/protocols/openai/chat_completions/aggregator.rs +++ b/lib/llm/src/protocols/openai/chat_completions/aggregator.rs @@ -20,7 +20,7 @@ use super::{NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse use crate::protocols::{ codec::{Message, SseCodecError}, convert_sse_stream, - openai::StreamArgs, + openai::ParsingOptions, Annotated, }; @@ -101,7 +101,7 @@ impl DeltaAggregator { /// * `Err(String)` if an error occurs during processing. pub async fn apply( stream: impl Stream>, - extra_stream_args: StreamArgs, + parsing_options: ParsingOptions, ) -> Result { let aggregator = stream .fold(DeltaAggregator::new(), |mut aggregator, delta| async move { @@ -180,7 +180,7 @@ impl DeltaAggregator { if choice.tool_calls.is_none() { if let Ok(tool_calls) = try_tool_call_parse_aggregate( &choice.text, - extra_stream_args.tool_call_parser.as_deref(), + parsing_options.tool_call_parser.as_deref(), ) { if tool_calls.is_empty() { continue; @@ -268,7 +268,7 @@ pub trait ChatCompletionAggregator { /// * `Err(String)` if an error occurs. async fn from_annotated_stream( stream: impl Stream>, - extra_stream_args: StreamArgs, + parsing_options: ParsingOptions, ) -> Result; /// Converts an SSE stream into a [`NvCreateChatCompletionResponse`]. @@ -281,24 +281,24 @@ pub trait ChatCompletionAggregator { /// * `Err(String)` if an error occurs. 
async fn from_sse_stream( stream: DataStream>, - extra_stream_args: StreamArgs, + parsing_options: ParsingOptions, ) -> Result; } impl ChatCompletionAggregator for dynamo_async_openai::types::CreateChatCompletionResponse { async fn from_annotated_stream( stream: impl Stream>, - extra_stream_args: StreamArgs, + parsing_options: ParsingOptions, ) -> Result { - DeltaAggregator::apply(stream, extra_stream_args).await + DeltaAggregator::apply(stream, parsing_options).await } async fn from_sse_stream( stream: DataStream>, - extra_stream_args: StreamArgs, + parsing_options: ParsingOptions, ) -> Result { let stream = convert_sse_stream::(stream); - NvCreateChatCompletionResponse::from_annotated_stream(stream, extra_stream_args).await + NvCreateChatCompletionResponse::from_annotated_stream(stream, parsing_options).await } } @@ -357,7 +357,7 @@ mod tests { Box::pin(stream::empty()); // Call DeltaAggregator::apply - let result = DeltaAggregator::apply(stream, StreamArgs::default()).await; + let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await; // Check the result assert!(result.is_ok()); @@ -387,7 +387,7 @@ mod tests { let stream = Box::pin(stream::iter(vec![annotated_delta])); // Call DeltaAggregator::apply - let result = DeltaAggregator::apply(stream, StreamArgs::default()).await; + let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await; // Check the result assert!(result.is_ok()); @@ -431,7 +431,7 @@ mod tests { let stream = Box::pin(stream::iter(annotated_deltas)); // Call DeltaAggregator::apply - let result = DeltaAggregator::apply(stream, StreamArgs::default()).await; + let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await; // Check the result assert!(result.is_ok()); @@ -502,7 +502,7 @@ mod tests { let stream = Box::pin(stream::iter(vec![annotated_delta])); // Call DeltaAggregator::apply - let result = DeltaAggregator::apply(stream, StreamArgs::default()).await; + let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await; // Check the result assert!(result.is_ok()); @@ -560,7 +560,7 @@ mod tests { let stream = Box::pin(stream::iter(vec![annotated_delta])); // Call DeltaAggregator::apply - let result = DeltaAggregator::apply(stream, StreamArgs::default()).await; + let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await; // Check the result assert!(result.is_ok()); diff --git a/lib/llm/src/protocols/openai/completions/aggregator.rs b/lib/llm/src/protocols/openai/completions/aggregator.rs index dd3281398d..e72fc072c2 100644 --- a/lib/llm/src/protocols/openai/completions/aggregator.rs +++ b/lib/llm/src/protocols/openai/completions/aggregator.rs @@ -23,7 +23,7 @@ use crate::protocols::{ codec::{Message, SseCodecError}, common::FinishReason, convert_sse_stream, - openai::StreamArgs, + openai::ParsingOptions, Annotated, DataStream, }; @@ -67,9 +67,9 @@ impl DeltaAggregator { /// Aggregates a stream of [`Annotated`]s into a single [`CompletionResponse`]. 
pub async fn apply( stream: impl Stream>, - extra_stream_args: StreamArgs, + parsing_options: ParsingOptions, ) -> Result { - tracing::debug!("Tool Call Parser: {:?}", extra_stream_args.tool_call_parser); // TODO: remove this once completion has tool call support + tracing::debug!("Tool Call Parser: {:?}", parsing_options.tool_call_parser); // TODO: remove this once completion has tool call support let aggregator = stream .fold(DeltaAggregator::new(), |mut aggregator, delta| async move { let delta = match delta.ok() { @@ -181,17 +181,17 @@ impl From for dynamo_async_openai::types::Choice { impl NvCreateCompletionResponse { pub async fn from_sse_stream( stream: DataStream>, - extra_stream_args: StreamArgs, + parsing_options: ParsingOptions, ) -> Result { let stream = convert_sse_stream::(stream); - NvCreateCompletionResponse::from_annotated_stream(stream, extra_stream_args).await + NvCreateCompletionResponse::from_annotated_stream(stream, parsing_options).await } pub async fn from_annotated_stream( stream: impl Stream>, - extra_stream_args: StreamArgs, + parsing_options: ParsingOptions, ) -> Result { - DeltaAggregator::apply(stream, extra_stream_args).await + DeltaAggregator::apply(stream, parsing_options).await } } @@ -247,7 +247,7 @@ mod tests { let stream: DataStream> = Box::pin(stream::empty()); // Call DeltaAggregator::apply - let result = DeltaAggregator::apply(stream, StreamArgs::default()).await; + let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await; // Check the result assert!(result.is_ok()); @@ -271,7 +271,7 @@ mod tests { let stream = Box::pin(stream::iter(vec![annotated_delta])); // Call DeltaAggregator::apply - let result = DeltaAggregator::apply(stream, StreamArgs::default()).await; + let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await; // Check the result assert!(result.is_ok()); @@ -311,7 +311,7 @@ mod tests { let stream = Box::pin(stream::iter(annotated_deltas)); // Call DeltaAggregator::apply - let result = DeltaAggregator::apply(stream, StreamArgs::default()).await; + let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await; // Check the result assert!(result.is_ok()); @@ -371,7 +371,7 @@ mod tests { let stream = Box::pin(stream::iter(vec![annotated_delta])); // Call DeltaAggregator::apply - let result = DeltaAggregator::apply(stream, StreamArgs::default()).await; + let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await; // Check the result assert!(result.is_ok()); diff --git a/lib/llm/tests/aggregators.rs b/lib/llm/tests/aggregators.rs index 794e6c996e..c6ad39dfa9 100644 --- a/lib/llm/tests/aggregators.rs +++ b/lib/llm/tests/aggregators.rs @@ -18,7 +18,7 @@ use dynamo_llm::protocols::{ openai::{ chat_completions::{aggregator::ChatCompletionAggregator, NvCreateChatCompletionResponse}, completions::NvCreateCompletionResponse, - StreamArgs, + ParsingOptions, }, ContentProvider, DataStream, }; @@ -38,10 +38,12 @@ async fn test_openai_chat_stream() { // note: we are only taking the first 16 messages to keep the size of the response small let stream = create_message_stream(&data).take(16); - let result = - NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream), StreamArgs::default()) - .await - .unwrap(); + let result = NvCreateChatCompletionResponse::from_sse_stream( + Box::pin(stream), + ParsingOptions::default(), + ) + .await + .unwrap(); // todo: provide a cleaner way to extract the content from choices assert_eq!( @@ -61,10 +63,12 @@ async fn test_openai_chat_stream() { 
#[tokio::test] async fn test_openai_chat_edge_case_multi_line_data() { let stream = create_stream(CHAT_ROOT_PATH, "edge_cases/valid-multi-line-data"); - let result = - NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream), StreamArgs::default()) - .await - .unwrap(); + let result = NvCreateChatCompletionResponse::from_sse_stream( + Box::pin(stream), + ParsingOptions::default(), + ) + .await + .unwrap(); assert_eq!( result @@ -82,10 +86,12 @@ async fn test_openai_chat_edge_case_multi_line_data() { #[tokio::test] async fn test_openai_chat_edge_case_comments_per_response() { let stream = create_stream(CHAT_ROOT_PATH, "edge_cases/valid-comments_per_response"); - let result = - NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream), StreamArgs::default()) - .await - .unwrap(); + let result = NvCreateChatCompletionResponse::from_sse_stream( + Box::pin(stream), + ParsingOptions::default(), + ) + .await + .unwrap(); assert_eq!( result @@ -103,9 +109,11 @@ async fn test_openai_chat_edge_case_comments_per_response() { #[tokio::test] async fn test_openai_chat_edge_case_invalid_deserialize_error() { let stream = create_stream(CHAT_ROOT_PATH, "edge_cases/invalid-deserialize_error"); - let result = - NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream), StreamArgs::default()) - .await; + let result = NvCreateChatCompletionResponse::from_sse_stream( + Box::pin(stream), + ParsingOptions::default(), + ) + .await; assert!(result.is_err()); // insta::assert_debug_snapshot!(result); @@ -119,7 +127,7 @@ async fn test_openai_chat_edge_case_invalid_deserialize_error() { async fn test_openai_cmpl_stream() { let stream = create_stream(CMPL_ROOT_PATH, "completion.streaming.1").take(16); let result = - NvCreateCompletionResponse::from_sse_stream(Box::pin(stream), StreamArgs::default()) + NvCreateCompletionResponse::from_sse_stream(Box::pin(stream), ParsingOptions::default()) .await .unwrap(); From 174df2e2ee46e2312a118e29e825c46a48c43e67 Mon Sep 17 00:00:00 2001 From: ayushag Date: Fri, 22 Aug 2025 20:02:43 +0000 Subject: [PATCH 9/9] fix: better error handling --- lib/llm/src/discovery/model_manager.rs | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/lib/llm/src/discovery/model_manager.rs b/lib/llm/src/discovery/model_manager.rs index 01d5d57cad..95e05baca0 100644 --- a/lib/llm/src/discovery/model_manager.rs +++ b/lib/llm/src/discovery/model_manager.rs @@ -248,14 +248,15 @@ impl ModelManager { } pub fn get_model_tool_call_parser(&self, model: &str) -> Option { - self.entries - .lock() - .unwrap() - .values() - .find(|entry| entry.name == model) - .and_then(|entry| entry.runtime_config.as_ref()) - .and_then(|config| config.tool_call_parser.clone()) - .map(|parser| parser.to_string()) + match self.entries.lock() { + Ok(entries) => entries + .values() + .find(|entry| entry.name == model) + .and_then(|entry| entry.runtime_config.as_ref()) + .and_then(|config| config.tool_call_parser.clone()) + .map(|parser| parser.to_string()), + Err(_) => None, + } } }
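
A quick way to sanity-check the new flags without standing up a worker: the snippet below is a minimal, standalone sketch (standard library only) of how the "dyn-"-prefixed arguments added in args.py land on the config fields. The ArgumentParser and Config here are simplified stand-ins for dynamo's FlexibleArgumentParser and Config rather than the real classes; in the actual series the parsed values then ride on ModelRuntimeConfig, are published with the model deployment card, and are read back by the HTTP service via get_model_tool_call_parser when folding completion streams.

# Standalone illustration (stdlib only). The real implementation lives in
# components/backends/vllm/src/dynamo/vllm/args.py; Config and parse_parser_flags
# below are simplified stand-ins, not the actual dynamo classes.
import argparse
from dataclasses import dataclass
from typing import Optional


@dataclass
class Config:
    tool_call_parser: Optional[str] = None
    reasoning_parser: Optional[str] = None


def parse_parser_flags(argv: Optional[list] = None) -> Config:
    parser = argparse.ArgumentParser()
    # "dyn-" prefix keeps dynamo-specific flags from colliding with backend-native vLLM args
    parser.add_argument(
        "--dyn-tool-call-parser",
        type=str,
        default=None,
        help="Tool call parser name, e.g. 'hermes', 'nemotron_deci', 'llama3_json', 'mistral', 'phi4'.",
    )
    parser.add_argument(
        "--dyn-reasoning-parser",
        type=str,
        default=None,
        help="Reasoning parser name for the model.",
    )
    # parse_known_args so unrelated backend flags pass through untouched
    args, _ = parser.parse_known_args(argv)

    config = Config()
    config.tool_call_parser = args.dyn_tool_call_parser
    config.reasoning_parser = args.dyn_reasoning_parser
    return config


if __name__ == "__main__":
    cfg = parse_parser_flags(["--dyn-tool-call-parser", "hermes"])
    print(cfg.tool_call_parser, cfg.reasoning_parser)  # -> hermes None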