From b1e6eb4e0eb39a5bade514cbadd67ec1d49bf5d9 Mon Sep 17 00:00:00 2001 From: Jorge Antonio Date: Tue, 5 Aug 2025 13:28:36 +0100 Subject: [PATCH 01/31] first commit --- lib/llm/src/model_card.rs | 1 + lib/llm/src/model_card/create.rs | 4 ++- lib/llm/src/model_card/model.rs | 39 +++++++++++++++++++++++- lib/llm/src/model_card/runtime_config.rs | 27 ++++++++++++++++ 4 files changed, 69 insertions(+), 2 deletions(-) create mode 100644 lib/llm/src/model_card/runtime_config.rs diff --git a/lib/llm/src/model_card.rs b/lib/llm/src/model_card.rs index 10a3e05176..5022c1ad15 100644 --- a/lib/llm/src/model_card.rs +++ b/lib/llm/src/model_card.rs @@ -3,6 +3,7 @@ pub mod create; pub mod model; +pub mod runtime_config; pub use model::ModelDeploymentCard; /// Identify model deployment cards in the key-value store diff --git a/lib/llm/src/model_card/create.rs b/lib/llm/src/model_card/create.rs index d628a0178b..07ffb40bff 100644 --- a/lib/llm/src/model_card/create.rs +++ b/lib/llm/src/model_card/create.rs @@ -3,7 +3,7 @@ use crate::model_card::model::ModelDeploymentCard; use anyhow::{Context, Result}; -use std::path::{Path, PathBuf}; +use std::{collections::HashMap, path::{Path, PathBuf}}; use crate::model_card::model::{ModelInfoType, PromptFormatterArtifact, TokenizerKind}; @@ -93,6 +93,7 @@ impl ModelDeploymentCard { context_length, kv_cache_block_size: 0, migration_limit: 0, + runtime_data: HashMap::new(), }) } @@ -133,6 +134,7 @@ impl ModelDeploymentCard { context_length, kv_cache_block_size: 0, // set later migration_limit: 0, + runtime_data: HashMap::new(), }) } } diff --git a/lib/llm/src/model_card/model.rs b/lib/llm/src/model_card/model.rs index 08efc79484..6db597cdd6 100644 --- a/lib/llm/src/model_card/model.rs +++ b/lib/llm/src/model_card/model.rs @@ -13,6 +13,7 @@ //! - Prompt formatter settings (PromptFormatterArtifact) //! - Various metadata like revision, publish time, etc. +use std::collections::HashMap; use std::fmt; use std::fs::File; use std::path::{Path, PathBuf}; @@ -22,7 +23,7 @@ use std::time::Duration; use anyhow::{Context, Result}; use derive_builder::Builder; use dynamo_runtime::{slug::Slug, storage::key_value_store::Versioned, transports::nats}; -use serde::{Deserialize, Serialize}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; use tokenizers::Tokenizer as HfTokenizer; use url::Url; @@ -131,6 +132,11 @@ pub struct ModelDeploymentCard { /// How many times a request can be migrated to another worker if the HTTP server lost /// connection to the current worker. pub migration_limit: u32, + + /// Runtime-initialized configuration data that is known during model initialization + /// and does not change over the runtime. Examples: total_kv_blocks, model parameters, etc. + #[serde(default, skip_serializing_if = "HashMap::is_empty")] + pub runtime_data: HashMap, } impl ModelDeploymentCard { @@ -234,6 +240,37 @@ impl ModelDeploymentCard { } } + /// Register runtime data for the model + /// + /// This is used to store configuration data that is known during model initialization + /// and does not change over the runtime. Examples: total_kv_blocks, model parameters, etc. + pub fn register_runtime_data(&mut self, key: &str, value: T) { + self.runtime_data + .insert(key.to_string(), serde_json::to_value(value).unwrap()); + } + + /// Get runtime data for the model + pub fn get_runtime_data(&self, key: &str) -> Option { + self.runtime_data + .get(key) + .and_then(|v| serde_json::from_value(v.clone()).ok()) + } + + /// Check if runtime data exists for a given key + pub fn has_runtime_data(&self, key: &str) -> bool { + self.runtime_data.contains_key(key) + } + + /// Get the total number of KV blocks + pub fn total_kv_blocks(&self) -> Option { + self.get_runtime_data("total_kv_blocks") + } + + /// Register the total number of KV blocks + pub fn register_total_kv_blocks(&mut self, total_kv_blocks: u64) { + self.register_runtime_data("total_kv_blocks", total_kv_blocks); + } + /// Move the files this MDC uses into the NATS object store. /// Updates the URI's to point to NATS. pub async fn move_to_nats(&mut self, nats_client: nats::Client) -> Result<()> { diff --git a/lib/llm/src/model_card/runtime_config.rs b/lib/llm/src/model_card/runtime_config.rs new file mode 100644 index 0000000000..53e6ac1756 --- /dev/null +++ b/lib/llm/src/model_card/runtime_config.rs @@ -0,0 +1,27 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::model::ModelDeploymentCard; + +/// Helper struct for backend engines to register runtime data during initialization +pub struct RuntimeConfigBuilder { + card: ModelDeploymentCard, +} + +impl RuntimeConfigBuilder { + /// Create a new RuntimeConfigBuilder from a ModelDeploymentCard + pub fn new(card: ModelDeploymentCard) -> Self { + Self { card } + } + + /// Register the total number of KV blocks + pub fn with_total_kv_blocks(mut self, total_blocks: u64) -> Self { + self.card.register_total_kv_blocks(total_blocks); + self + } + + /// Build the final ModelDeploymentCard with all registered runtime data + pub fn build(self) -> ModelDeploymentCard { + self.card + } +} From 8ffe717f1307d7b74a89b4b4e1ec7321b211d59d Mon Sep 17 00:00:00 2001 From: Jorge Antonio Date: Tue, 5 Aug 2025 17:43:15 +0100 Subject: [PATCH 02/31] register runtime config after engine initialization --- .../backends/vllm/src/dynamo/vllm/main.py | 21 +++++ lib/bindings/python/rust/lib.rs | 62 +++++++++++- lib/bindings/python/rust/llm/model_card.rs | 94 +++++++++++++++++++ lib/bindings/python/src/dynamo/_core.pyi | 10 ++ .../python/src/dynamo/llm/__init__.py | 1 + lib/llm/src/local_model.rs | 27 ++++++ lib/llm/src/model_card.rs | 1 + lib/llm/src/model_card/create.rs | 6 +- lib/llm/src/model_card/model.rs | 55 ++++++----- lib/llm/src/model_card/runtime_config.rs | 57 ++++++++--- 10 files changed, 289 insertions(+), 45 deletions(-) diff --git a/components/backends/vllm/src/dynamo/vllm/main.py b/components/backends/vllm/src/dynamo/vllm/main.py index 5365623866..784000b9b7 100644 --- a/components/backends/vllm/src/dynamo/vllm/main.py +++ b/components/backends/vllm/src/dynamo/vllm/main.py @@ -16,7 +16,9 @@ ZmqKvEventPublisher, ZmqKvEventPublisherConfig, register_llm, + register_runtime_config, ) +from dynamo.llm.model_card import ModelRuntimeConfig from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime.logging import configure_dynamo_logging @@ -167,6 +169,25 @@ async def init(runtime: DistributedRuntime, config: Config): component, engine_client, default_sampling_params, prefill_worker_client ) + # Create a runtime config and register it with the MDC + if not config.engine_args.data_parallel_rank: + runtime_config = ModelRuntimeConfig() + # NOTE: This number needs to be queried directly from the engine, + # since this will compute it if no value was set by the user + runtime_config.with_total_kv_blocks( + engine_client.engine.cache_config.num_gpu_blocks + ) + runtime_config.with_max_num_seqs( + engine_client.vllm_config.scheduler_config.max_num_seqs + ) + + gpu_mem_integer = int( + engine_client.engine.cache_config.gpu_memory_utilization * 100 + ) + runtime_config.with_gpu_memory_utilization(gpu_mem_integer) + + await register_runtime_config(generate_endpoint, config.model, runtime_config) + if config.engine_args.enable_prefix_caching: # TODO: We start off with a valid endpoint, then we increment it by dp_rank # May no longer be valid. Lets remove the increment behavior from vLLM and here diff --git a/lib/bindings/python/rust/lib.rs b/lib/bindings/python/rust/lib.rs index 5b548352f8..ddf3fe3dc8 100644 --- a/lib/bindings/python/rust/lib.rs +++ b/lib/bindings/python/rust/lib.rs @@ -19,12 +19,16 @@ use dynamo_runtime::{ network::egress::push_router::RouterMode as RsRouterMode, EngineStream, ManyOut, SingleIn, }, protocols::annotated::Annotated as RsAnnotated, + slug::Slug, + storage::key_value_store::{EtcdStorage, KeyValueStoreManager}, traits::DistributedRuntimeProvider, }; -use dynamo_llm::{self as llm_rs}; +use dynamo_llm::{self as llm_rs, model_card}; use dynamo_llm::{entrypoint::RouterConfig, kv_router::KvRouterConfig}; +use crate::llm::model_card::ModelRuntimeConfig; + #[pyclass(eq, eq_int)] #[derive(Clone, Debug, PartialEq)] pub enum RouterMode { @@ -63,6 +67,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(llm::kv::compute_block_hash_for_seq_py, m)?)?; m.add_function(wrap_pyfunction!(log_message, m)?)?; m.add_function(wrap_pyfunction!(register_llm, m)?)?; + m.add_function(wrap_pyfunction!(register_runtime_config, m)?)?; m.add_function(wrap_pyfunction!(llm::entrypoint::make_engine, m)?)?; m.add_function(wrap_pyfunction!(llm::entrypoint::run_input, m)?)?; @@ -82,6 +87,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -177,6 +183,60 @@ fn register_llm<'p>( }) } +#[pyfunction] +fn register_runtime_config<'p>( + py: Python<'p>, + endpoint: Endpoint, + model_identifier: &str, + runtime_config: ModelRuntimeConfig, +) -> PyResult> { + let model_identifier = model_identifier.to_string(); + let runtime_config = runtime_config.inner.clone(); + let endpoint = endpoint.inner.clone(); + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + // Get etcd client from endpoint + let Some(etcd_client) = endpoint.drt().etcd_client() else { + return Err(anyhow::anyhow!( + "Cannot update runtime config on static endpoint" + )) + .map_err(to_pyerr); + }; + + // Create storage manager + let kvstore = EtcdStorage::new(etcd_client.clone()); + let card_store = KeyValueStoreManager::new(Box::new(kvstore)); + + // Generate the model slug - this should match what register_llm used + // The register_llm function uses the model_name (or model_path if no model_name) + // and then calls card.set_name() which sets both display_name and service_name + let model_slug = Slug::slugify(&model_identifier); + + // Get existing card + let mut card = card_store + .load::(model_card::ROOT_PATH, &model_slug) + .await + .map_err(to_pyerr)? + .ok_or_else(|| anyhow::anyhow!("Cannot find model card")) + .map_err(to_pyerr)?; + + // Update the card + card.register_runtime_config(runtime_config); + + // Publish the card + card_store + .publish(model_card::ROOT_PATH, None, &model_slug.to_string(), &mut card) + .await + .map_err(to_pyerr)?; + + tracing::info!( + "Successfully updated runtime config for model card: {}", + model_slug + ); + Ok(()) + }) +} + #[pyclass] #[derive(Clone)] struct EtcdKvCache { diff --git a/lib/bindings/python/rust/llm/model_card.rs b/lib/bindings/python/rust/llm/model_card.rs index 64f2a55c1c..f52f742718 100644 --- a/lib/bindings/python/rust/llm/model_card.rs +++ b/lib/bindings/python/rust/llm/model_card.rs @@ -15,6 +15,7 @@ use super::*; use llm_rs::model_card::model::ModelDeploymentCard as RsModelDeploymentCard; +use llm_rs::model_card::runtime_config::ModelRuntimeConfig as RsModelRuntimeConfig; #[pyclass] #[derive(Clone)] @@ -46,4 +47,97 @@ impl ModelDeploymentCard { let json = self.inner.to_json().map_err(to_pyerr)?; Ok(json) } + + fn register_runtime_config(&mut self, runtime_config: ModelRuntimeConfig) { + self.inner.register_runtime_config(runtime_config.inner); + } + + #[getter] + fn runtime_config(&self) -> Option { + self.inner + .runtime_config() + .map(|config| ModelRuntimeConfig { + inner: config.clone(), + }) + } + + #[getter] + fn total_kv_blocks(&self) -> Option { + self.inner.total_kv_blocks() + } + + #[getter] + fn max_num_seqs(&self) -> Option { + self.inner.max_num_seqs() + } + + #[getter] + fn gpu_memory_utilization(&self) -> Option { + self.inner.gpu_memory_utilization() + } +} + +#[pyclass] +#[derive(Clone)] +pub struct ModelRuntimeConfig { + pub(crate) inner: RsModelRuntimeConfig, +} + +#[pymethods] +impl ModelRuntimeConfig { + #[new] + fn new() -> Self { + Self { + inner: RsModelRuntimeConfig::new(), + } + } + + fn with_total_kv_blocks(&mut self, total_kv_blocks: u64) { + self.inner.with_total_kv_blocks(total_kv_blocks); + } + + fn with_max_num_seqs(&mut self, max_num_seqs: u64) { + self.inner.with_max_num_seqs(max_num_seqs); + } + + fn with_gpu_memory_utilization(&mut self, gpu_memory_utilization: u64) { + self.inner + .with_gpu_memory_utilization(gpu_memory_utilization); + } + + fn set_engine_specific(&mut self, key: &str, value: String) -> PyResult<()> { + let value: serde_json::Value = serde_json::from_str(&value).map_err(to_pyerr)?; + self.inner + .set_engine_specific(key, value) + .map_err(to_pyerr)?; + Ok(()) + } + + #[getter] + fn total_kv_blocks(&self) -> Option { + self.inner.total_kv_blocks + } + + #[getter] + fn max_num_seqs(&self) -> Option { + self.inner.max_num_seqs + } + + #[getter] + fn gpu_memory_utilization(&self) -> Option { + self.inner.gpu_memory_utilization + } + + #[getter] + fn runtime_data(&self, py: Python<'_>) -> PyResult { + let dict = PyDict::new(py); + for (key, value) in self.inner.runtime_data.clone() { + dict.set_item(key, value.to_string())?; + } + Ok(dict.into()) + } + + fn get_engine_specific(&self, key: &str) -> PyResult> { + Ok(self.inner.get_engine_specific(key).map_err(to_pyerr)?) + } } diff --git a/lib/bindings/python/src/dynamo/_core.pyi b/lib/bindings/python/src/dynamo/_core.pyi index a32aaf4d84..baffa1e645 100644 --- a/lib/bindings/python/src/dynamo/_core.pyi +++ b/lib/bindings/python/src/dynamo/_core.pyi @@ -438,6 +438,12 @@ class ModelDeploymentCard: ... +class ModelRuntimeConfig: + """ + A model runtime configuration is a collection of runtime information + """ + ... + class OAIChatPreprocessor: """ A preprocessor for OpenAI chat completions @@ -833,6 +839,10 @@ async def register_llm(model_type: ModelType, endpoint: Endpoint, model_path: st """Attach the model at path to the given endpoint, and advertise it as model_type""" ... +async def register_runtime_config(endpoint: Endpoint, model_identifier: str, runtime_config: ModelRuntimeConfig) -> None: + """Register runtime configuration with the model card""" + ... + class EngineConfig: """Holds internal configuration for a Dynamo engine.""" ... diff --git a/lib/bindings/python/src/dynamo/llm/__init__.py b/lib/bindings/python/src/dynamo/llm/__init__.py index 053fe4c69c..6c9443a97a 100644 --- a/lib/bindings/python/src/dynamo/llm/__init__.py +++ b/lib/bindings/python/src/dynamo/llm/__init__.py @@ -40,4 +40,5 @@ from dynamo._core import compute_block_hash_for_seq_py as compute_block_hash_for_seq_py from dynamo._core import make_engine from dynamo._core import register_llm as register_llm +from dynamo._core import register_runtime_config as register_runtime_config from dynamo._core import run_input diff --git a/lib/llm/src/local_model.rs b/lib/llm/src/local_model.rs index e12ea3c730..3a4b62f85e 100644 --- a/lib/llm/src/local_model.rs +++ b/lib/llm/src/local_model.rs @@ -16,6 +16,7 @@ use dynamo_runtime::{ use crate::discovery::ModelEntry; use crate::entrypoint::RouterConfig; +use crate::model_card::runtime_config::ModelRuntimeConfig; use crate::model_card::{self, ModelDeploymentCard}; use crate::model_type::ModelType; use crate::request_template::RequestTemplate; @@ -48,6 +49,7 @@ pub struct LocalModelBuilder { http_port: u16, migration_limit: u32, is_mocker: bool, + runtime_config: ModelRuntimeConfig, } impl Default for LocalModelBuilder { @@ -64,6 +66,7 @@ impl Default for LocalModelBuilder { router_config: Default::default(), migration_limit: Default::default(), is_mocker: Default::default(), + runtime_config: Default::default(), } } } @@ -126,6 +129,11 @@ impl LocalModelBuilder { self } + pub fn runtime_config(&mut self, runtime_config: ModelRuntimeConfig) -> &mut Self { + self.runtime_config = runtime_config; + self + } + /// Make an LLM ready for use: /// - Download it from Hugging Face (and NGC in future) if necessary /// - Resolve the path @@ -323,6 +331,25 @@ impl LocalModel { .await } + pub async fn register_runtime_config( + &mut self, + endpoint: &Endpoint, + runtime_config: ModelRuntimeConfig, + ) -> anyhow::Result<()> { + self.card.register_runtime_config(runtime_config); + + // Update the card in etcd if we're attached + if let Some(etcd_client) = endpoint.drt().etcd_client() { + let kvstore: Box = Box::new(EtcdStorage::new(etcd_client.clone())); + let card_store = Arc::new(KeyValueStoreManager::new(kvstore)); + let key = self.card.slug().to_string(); + card_store + .publish(model_card::ROOT_PATH, None, &key, &mut self.card) + .await?; + } + Ok(()) + } + /// Ensure that each component serves only one model. /// We can have multiple instances of the same model running using the same component name /// (they get load balanced, and are differentiated in etcd by their lease_id). diff --git a/lib/llm/src/model_card.rs b/lib/llm/src/model_card.rs index 5022c1ad15..bbff28aeae 100644 --- a/lib/llm/src/model_card.rs +++ b/lib/llm/src/model_card.rs @@ -5,6 +5,7 @@ pub mod create; pub mod model; pub mod runtime_config; pub use model::ModelDeploymentCard; +pub use runtime_config::ModelRuntimeConfig; /// Identify model deployment cards in the key-value store pub const ROOT_PATH: &str = "mdc"; diff --git a/lib/llm/src/model_card/create.rs b/lib/llm/src/model_card/create.rs index 07ffb40bff..a50e912329 100644 --- a/lib/llm/src/model_card/create.rs +++ b/lib/llm/src/model_card/create.rs @@ -3,7 +3,7 @@ use crate::model_card::model::ModelDeploymentCard; use anyhow::{Context, Result}; -use std::{collections::HashMap, path::{Path, PathBuf}}; +use std::path::{Path, PathBuf}; use crate::model_card::model::{ModelInfoType, PromptFormatterArtifact, TokenizerKind}; @@ -93,7 +93,7 @@ impl ModelDeploymentCard { context_length, kv_cache_block_size: 0, migration_limit: 0, - runtime_data: HashMap::new(), + runtime_config: None, }) } @@ -134,7 +134,7 @@ impl ModelDeploymentCard { context_length, kv_cache_block_size: 0, // set later migration_limit: 0, - runtime_data: HashMap::new(), + runtime_config: None, }) } } diff --git a/lib/llm/src/model_card/model.rs b/lib/llm/src/model_card/model.rs index 6db597cdd6..d131a948b4 100644 --- a/lib/llm/src/model_card/model.rs +++ b/lib/llm/src/model_card/model.rs @@ -13,7 +13,6 @@ //! - Prompt formatter settings (PromptFormatterArtifact) //! - Various metadata like revision, publish time, etc. -use std::collections::HashMap; use std::fmt; use std::fs::File; use std::path::{Path, PathBuf}; @@ -23,11 +22,12 @@ use std::time::Duration; use anyhow::{Context, Result}; use derive_builder::Builder; use dynamo_runtime::{slug::Slug, storage::key_value_store::Versioned, transports::nats}; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use serde::{Deserialize, Serialize}; use tokenizers::Tokenizer as HfTokenizer; use url::Url; use crate::gguf::{Content, ContentConfig, ModelConfigLike}; +use crate::model_card::runtime_config::ModelRuntimeConfig; use crate::protocols::TokenIdType; /// If a model deployment card hasn't been refreshed in this much time the worker is likely gone @@ -133,10 +133,10 @@ pub struct ModelDeploymentCard { /// connection to the current worker. pub migration_limit: u32, - /// Runtime-initialized configuration data that is known during model initialization - /// and does not change over the runtime. Examples: total_kv_blocks, model parameters, etc. - #[serde(default, skip_serializing_if = "HashMap::is_empty")] - pub runtime_data: HashMap, + /// Runtime configuration data that is computed during initialization + /// but remains constant during the worker session + #[serde(default, skip_serializing_if = "Option::is_none")] + pub runtime_config: Option, } impl ModelDeploymentCard { @@ -240,35 +240,38 @@ impl ModelDeploymentCard { } } - /// Register runtime data for the model - /// - /// This is used to store configuration data that is known during model initialization - /// and does not change over the runtime. Examples: total_kv_blocks, model parameters, etc. - pub fn register_runtime_data(&mut self, key: &str, value: T) { - self.runtime_data - .insert(key.to_string(), serde_json::to_value(value).unwrap()); + /// Register runtime configuration data that was computed during engine initialization + pub fn register_runtime_config(&mut self, runtime_config: ModelRuntimeConfig) { + self.runtime_config = Some(runtime_config); } - /// Get runtime data for the model - pub fn get_runtime_data(&self, key: &str) -> Option { - self.runtime_data - .get(key) - .and_then(|v| serde_json::from_value(v.clone()).ok()) + /// Update an existing runtime config or create a new one if none exists + pub fn update_runtime_config(&mut self, updater: F) + where + F: FnOnce(&mut ModelRuntimeConfig), + { + let mut config = self.runtime_config.take().unwrap_or_default(); + updater(&mut config); + self.runtime_config = Some(config); } - /// Check if runtime data exists for a given key - pub fn has_runtime_data(&self, key: &str) -> bool { - self.runtime_data.contains_key(key) + pub fn runtime_config(&self) -> Option<&ModelRuntimeConfig> { + self.runtime_config.as_ref() } - /// Get the total number of KV blocks + /// Get total number of KV blocks pub fn total_kv_blocks(&self) -> Option { - self.get_runtime_data("total_kv_blocks") + self.runtime_config.as_ref()?.total_kv_blocks + } + + /// Get maximum number of sequences that can be batched together + pub fn max_num_seqs(&self) -> Option { + self.runtime_config.as_ref()?.max_num_seqs } - /// Register the total number of KV blocks - pub fn register_total_kv_blocks(&mut self, total_kv_blocks: u64) { - self.register_runtime_data("total_kv_blocks", total_kv_blocks); + /// Get GPU memory utilization percentage configured + pub fn gpu_memory_utilization(&self) -> Option { + self.runtime_config.as_ref()?.gpu_memory_utilization } /// Move the files this MDC uses into the NATS object store. diff --git a/lib/llm/src/model_card/runtime_config.rs b/lib/llm/src/model_card/runtime_config.rs index 53e6ac1756..bdf63ed7a5 100644 --- a/lib/llm/src/model_card/runtime_config.rs +++ b/lib/llm/src/model_card/runtime_config.rs @@ -1,27 +1,54 @@ // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -use super::model::ModelDeploymentCard; +use std::collections::HashMap; -/// Helper struct for backend engines to register runtime data during initialization -pub struct RuntimeConfigBuilder { - card: ModelDeploymentCard, +use serde::{de::DeserializeOwned, Deserialize, Serialize}; + +#[derive(Debug, Default, Clone, Serialize, Deserialize)] +pub struct ModelRuntimeConfig { + /// The total number of KV blocks available + pub total_kv_blocks: Option, + + /// The maximum number of sequences that can be batched together + pub max_num_seqs: Option, + + /// GPU memory utilization percentage configured + pub gpu_memory_utilization: Option, + + /// Mapping of engine-specific runtime configs + #[serde(default, skip_serializing_if = "HashMap::is_empty")] + pub runtime_data: HashMap, } -impl RuntimeConfigBuilder { - /// Create a new RuntimeConfigBuilder from a ModelDeploymentCard - pub fn new(card: ModelDeploymentCard) -> Self { - Self { card } +impl ModelRuntimeConfig { + pub fn new() -> Self { + Self::default() + } + + pub fn with_total_kv_blocks(&mut self, total_kv_blocks: u64) { + self.total_kv_blocks = Some(total_kv_blocks); + } + + pub fn with_max_num_seqs(&mut self, max_num_seqs: u64) { + self.max_num_seqs = Some(max_num_seqs); + } + + pub fn with_gpu_memory_utilization(&mut self, gpu_memory_utilization: u64) { + self.gpu_memory_utilization = Some(gpu_memory_utilization); } - /// Register the total number of KV blocks - pub fn with_total_kv_blocks(mut self, total_blocks: u64) -> Self { - self.card.register_total_kv_blocks(total_blocks); - self + pub fn set_engine_specific(&mut self, key: &str, value: T) -> anyhow::Result<()> { + self.runtime_data + .insert(key.to_string(), serde_json::to_value(value)?); + Ok(()) } - /// Build the final ModelDeploymentCard with all registered runtime data - pub fn build(self) -> ModelDeploymentCard { - self.card + pub fn get_engine_specific(&self, key: &str) -> anyhow::Result> { + if let Some(value) = self.runtime_data.get(key) { + Ok(Some(serde_json::from_value(value.clone())?)) + } else { + Ok(None) + } } } From 58d73d24b3bdc106b9d7679936ecdb5b13299768 Mon Sep 17 00:00:00 2001 From: Jorge Antonio Date: Tue, 5 Aug 2025 20:14:48 +0100 Subject: [PATCH 03/31] add sglang runtime config values retrieval --- .../sglang/src/dynamo/sglang/worker/main.py | 53 ++++++++++++++++++- lib/bindings/python/rust/lib.rs | 2 +- lib/bindings/python/rust/llm/model_card.rs | 2 +- 3 files changed, 54 insertions(+), 3 deletions(-) diff --git a/components/backends/sglang/src/dynamo/sglang/worker/main.py b/components/backends/sglang/src/dynamo/sglang/worker/main.py index 6688428b5c..8c15dee1c6 100644 --- a/components/backends/sglang/src/dynamo/sglang/worker/main.py +++ b/components/backends/sglang/src/dynamo/sglang/worker/main.py @@ -15,6 +15,7 @@ from sglang.srt.server_args import ServerArgs from sglang.srt.utils import get_ip, get_zmq_socket +from dynamo._core import Endpoint from dynamo.llm import ( ForwardPassMetrics, KvStats, @@ -24,7 +25,9 @@ ZmqKvEventPublisher, ZmqKvEventPublisherConfig, register_llm, + register_runtime_config, ) +from dynamo.llm.model_card import ModelRuntimeConfig from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime.logging import configure_dynamo_logging from dynamo.sglang.common import ( @@ -364,12 +367,60 @@ async def init( _ = ZmqKvEventPublisher(component=component, config=zmq_config) tasks = [endpoint.serve_endpoint(handler.generate)] - + tasks.append( + register_runtime_config_once_engine_ready( + endpoint, engine, server_args.served_model_name + ) + ) tasks.extend(setup_native_endpoints(server_args, component, handler)) await asyncio.gather(*tasks) +async def register_runtime_config_once_engine_ready( + endpoint: Endpoint, engine: sgl.Engine, model_name: str +): + max_retries = 3 + retry_delay = 2 + + for attempt in range(max_retries): + try: + server_info = engine.get_server_info() + + runtime_config = ModelRuntimeConfig() + + # Use actual computed values from SGLang engine + if server_info.get("max_total_num_tokens") is not None: + runtime_config.with_total_kv_blocks(server_info["max_total_num_tokens"]) + + if server_info.get("max_running_requests") is not None: + runtime_config.with_max_num_seqs(server_info["max_running_requests"]) + + if server_info.get("mem_fraction_static") is not None: + gpu_mem_percentage = int(server_info["mem_fraction_static"] * 100) + runtime_config.with_gpu_memory_utilization(gpu_mem_percentage) + + # Register the runtime config + await register_runtime_config(endpoint, model_name, runtime_config) + + logging.info( + f"Published runtime config for SGLang decode worker: {model_name}" + ) + return # Successfully published runtime config, exit loop + + except Exception as e: + logging.warning( + f"Attempt {attempt + 1}/{max_retries} failed to publish runtime config: {e}" + ) + if attempt < max_retries - 1: + await asyncio.sleep(retry_delay) + retry_delay *= 2 + else: + logging.error( + f"Failed to publish runtime config after {max_retries} attempts" + ) + + def main(): uvloop.install() asyncio.run(worker()) diff --git a/lib/bindings/python/rust/lib.rs b/lib/bindings/python/rust/lib.rs index ddf3fe3dc8..56412582a3 100644 --- a/lib/bindings/python/rust/lib.rs +++ b/lib/bindings/python/rust/lib.rs @@ -225,7 +225,7 @@ fn register_runtime_config<'p>( // Publish the card card_store - .publish(model_card::ROOT_PATH, None, &model_slug.to_string(), &mut card) + .publish(model_card::ROOT_PATH, None, model_slug.as_ref(), &mut card) .await .map_err(to_pyerr)?; diff --git a/lib/bindings/python/rust/llm/model_card.rs b/lib/bindings/python/rust/llm/model_card.rs index f52f742718..3c9bceb602 100644 --- a/lib/bindings/python/rust/llm/model_card.rs +++ b/lib/bindings/python/rust/llm/model_card.rs @@ -138,6 +138,6 @@ impl ModelRuntimeConfig { } fn get_engine_specific(&self, key: &str) -> PyResult> { - Ok(self.inner.get_engine_specific(key).map_err(to_pyerr)?) + self.inner.get_engine_specific(key).map_err(to_pyerr) } } From 57078901feaa913a3ab2ef58767b482ce83ee8fe Mon Sep 17 00:00:00 2001 From: Jorge Antonio Date: Thu, 7 Aug 2025 10:13:09 +0100 Subject: [PATCH 04/31] address comments in the PR --- .../sglang/src/dynamo/sglang/worker/main.py | 6 +-- .../backends/vllm/src/dynamo/vllm/main.py | 10 ++--- lib/bindings/python/rust/lib.rs | 2 +- lib/bindings/python/rust/llm/model_card.rs | 40 +++++++++++++------ lib/llm/src/local_model.rs | 2 +- lib/llm/src/model_card/model.rs | 24 ----------- lib/llm/src/model_card/runtime_config.rs | 15 ------- 7 files changed, 36 insertions(+), 63 deletions(-) diff --git a/components/backends/sglang/src/dynamo/sglang/worker/main.py b/components/backends/sglang/src/dynamo/sglang/worker/main.py index 8c15dee1c6..524f217440 100644 --- a/components/backends/sglang/src/dynamo/sglang/worker/main.py +++ b/components/backends/sglang/src/dynamo/sglang/worker/main.py @@ -391,14 +391,14 @@ async def register_runtime_config_once_engine_ready( # Use actual computed values from SGLang engine if server_info.get("max_total_num_tokens") is not None: - runtime_config.with_total_kv_blocks(server_info["max_total_num_tokens"]) + runtime_config.total_kv_blocks = server_info["max_total_num_tokens"] if server_info.get("max_running_requests") is not None: - runtime_config.with_max_num_seqs(server_info["max_running_requests"]) + runtime_config.max_num_seqs = server_info["max_running_requests"] if server_info.get("mem_fraction_static") is not None: gpu_mem_percentage = int(server_info["mem_fraction_static"] * 100) - runtime_config.with_gpu_memory_utilization(gpu_mem_percentage) + runtime_config.gpu_memory_utilization = gpu_mem_percentage # Register the runtime config await register_runtime_config(endpoint, model_name, runtime_config) diff --git a/components/backends/vllm/src/dynamo/vllm/main.py b/components/backends/vllm/src/dynamo/vllm/main.py index 3b7a1937ad..141332e304 100644 --- a/components/backends/vllm/src/dynamo/vllm/main.py +++ b/components/backends/vllm/src/dynamo/vllm/main.py @@ -168,17 +168,13 @@ async def init(runtime: DistributedRuntime, config: Config): runtime_config = ModelRuntimeConfig() # NOTE: This number needs to be queried directly from the engine, # since this will compute it if no value was set by the user - runtime_config.with_total_kv_blocks( - engine_client.engine.cache_config.num_gpu_blocks - ) - runtime_config.with_max_num_seqs( - engine_client.vllm_config.scheduler_config.max_num_seqs - ) + runtime_config.total_kv_blocks = engine_client.engine.cache_config.num_gpu_blocks + runtime_config.max_num_seqs = engine_client.vllm_config.scheduler_config.max_num_seqs gpu_mem_integer = int( engine_client.engine.cache_config.gpu_memory_utilization * 100 ) - runtime_config.with_gpu_memory_utilization(gpu_mem_integer) + runtime_config.gpu_memory_utilization = gpu_mem_integer await register_runtime_config(generate_endpoint, config.model, runtime_config) diff --git a/lib/bindings/python/rust/lib.rs b/lib/bindings/python/rust/lib.rs index 94cacd9bfe..27c66f0e54 100644 --- a/lib/bindings/python/rust/lib.rs +++ b/lib/bindings/python/rust/lib.rs @@ -230,7 +230,7 @@ fn register_runtime_config<'p>( .map_err(to_pyerr)?; // Update the card - card.register_runtime_config(runtime_config); + card.runtime_config = Some(runtime_config); // Publish the card card_store diff --git a/lib/bindings/python/rust/llm/model_card.rs b/lib/bindings/python/rust/llm/model_card.rs index 3c9bceb602..16da8a0afe 100644 --- a/lib/bindings/python/rust/llm/model_card.rs +++ b/lib/bindings/python/rust/llm/model_card.rs @@ -48,14 +48,16 @@ impl ModelDeploymentCard { Ok(json) } + #[setter] fn register_runtime_config(&mut self, runtime_config: ModelRuntimeConfig) { - self.inner.register_runtime_config(runtime_config.inner); + self.inner.runtime_config = Some(runtime_config.inner); } #[getter] fn runtime_config(&self) -> Option { self.inner - .runtime_config() + .runtime_config + .as_ref() .map(|config| ModelRuntimeConfig { inner: config.clone(), }) @@ -63,17 +65,29 @@ impl ModelDeploymentCard { #[getter] fn total_kv_blocks(&self) -> Option { - self.inner.total_kv_blocks() + self.inner + .runtime_config + .as_ref() + .map(|config| config.total_kv_blocks) + .flatten() } #[getter] fn max_num_seqs(&self) -> Option { - self.inner.max_num_seqs() + self.inner + .runtime_config + .as_ref() + .map(|config| config.max_num_seqs) + .flatten() } #[getter] fn gpu_memory_utilization(&self) -> Option { - self.inner.gpu_memory_utilization() + self.inner + .runtime_config + .as_ref() + .map(|config| config.gpu_memory_utilization) + .flatten() } } @@ -92,17 +106,19 @@ impl ModelRuntimeConfig { } } - fn with_total_kv_blocks(&mut self, total_kv_blocks: u64) { - self.inner.with_total_kv_blocks(total_kv_blocks); + #[setter] + fn set_total_kv_blocks(&mut self, total_kv_blocks: u64) { + self.inner.total_kv_blocks = Some(total_kv_blocks); } - fn with_max_num_seqs(&mut self, max_num_seqs: u64) { - self.inner.with_max_num_seqs(max_num_seqs); + #[setter] + fn set_max_num_seqs(&mut self, max_num_seqs: u64) { + self.inner.max_num_seqs = Some(max_num_seqs); } - fn with_gpu_memory_utilization(&mut self, gpu_memory_utilization: u64) { - self.inner - .with_gpu_memory_utilization(gpu_memory_utilization); + #[setter] + fn set_gpu_memory_utilization(&mut self, gpu_memory_utilization: u64) { + self.inner.gpu_memory_utilization = Some(gpu_memory_utilization); } fn set_engine_specific(&mut self, key: &str, value: String) -> PyResult<()> { diff --git a/lib/llm/src/local_model.rs b/lib/llm/src/local_model.rs index d332572412..48231ccc56 100644 --- a/lib/llm/src/local_model.rs +++ b/lib/llm/src/local_model.rs @@ -345,7 +345,7 @@ impl LocalModel { endpoint: &Endpoint, runtime_config: ModelRuntimeConfig, ) -> anyhow::Result<()> { - self.card.register_runtime_config(runtime_config); + self.card.runtime_config = Some(runtime_config); // Update the card in etcd if we're attached if let Some(etcd_client) = endpoint.drt().etcd_client() { diff --git a/lib/llm/src/model_card/model.rs b/lib/llm/src/model_card/model.rs index c42829cd2d..af14249ac5 100644 --- a/lib/llm/src/model_card/model.rs +++ b/lib/llm/src/model_card/model.rs @@ -244,11 +244,6 @@ impl ModelDeploymentCard { } } - /// Register runtime configuration data that was computed during engine initialization - pub fn register_runtime_config(&mut self, runtime_config: ModelRuntimeConfig) { - self.runtime_config = Some(runtime_config); - } - /// Update an existing runtime config or create a new one if none exists pub fn update_runtime_config(&mut self, updater: F) where @@ -259,25 +254,6 @@ impl ModelDeploymentCard { self.runtime_config = Some(config); } - pub fn runtime_config(&self) -> Option<&ModelRuntimeConfig> { - self.runtime_config.as_ref() - } - - /// Get total number of KV blocks - pub fn total_kv_blocks(&self) -> Option { - self.runtime_config.as_ref()?.total_kv_blocks - } - - /// Get maximum number of sequences that can be batched together - pub fn max_num_seqs(&self) -> Option { - self.runtime_config.as_ref()?.max_num_seqs - } - - /// Get GPU memory utilization percentage configured - pub fn gpu_memory_utilization(&self) -> Option { - self.runtime_config.as_ref()?.gpu_memory_utilization - } - /// Move the files this MDC uses into the NATS object store. /// Updates the URI's to point to NATS. pub async fn move_to_nats(&mut self, nats_client: nats::Client) -> Result<()> { diff --git a/lib/llm/src/model_card/runtime_config.rs b/lib/llm/src/model_card/runtime_config.rs index bdf63ed7a5..fd6e406a02 100644 --- a/lib/llm/src/model_card/runtime_config.rs +++ b/lib/llm/src/model_card/runtime_config.rs @@ -7,13 +7,10 @@ use serde::{de::DeserializeOwned, Deserialize, Serialize}; #[derive(Debug, Default, Clone, Serialize, Deserialize)] pub struct ModelRuntimeConfig { - /// The total number of KV blocks available pub total_kv_blocks: Option, - /// The maximum number of sequences that can be batched together pub max_num_seqs: Option, - /// GPU memory utilization percentage configured pub gpu_memory_utilization: Option, /// Mapping of engine-specific runtime configs @@ -26,18 +23,6 @@ impl ModelRuntimeConfig { Self::default() } - pub fn with_total_kv_blocks(&mut self, total_kv_blocks: u64) { - self.total_kv_blocks = Some(total_kv_blocks); - } - - pub fn with_max_num_seqs(&mut self, max_num_seqs: u64) { - self.max_num_seqs = Some(max_num_seqs); - } - - pub fn with_gpu_memory_utilization(&mut self, gpu_memory_utilization: u64) { - self.gpu_memory_utilization = Some(gpu_memory_utilization); - } - pub fn set_engine_specific(&mut self, key: &str, value: T) -> anyhow::Result<()> { self.runtime_data .insert(key.to_string(), serde_json::to_value(value)?); From 61f64246ab1225be453744071165ff9868532fe9 Mon Sep 17 00:00:00 2001 From: Jorge Antonio Date: Thu, 7 Aug 2025 10:59:41 +0100 Subject: [PATCH 05/31] refactor logic to pass in engine initialization runtime args directly via register_llm --- Cargo.lock | 4 +- .../sglang/src/dynamo/sglang/worker/main.py | 79 +++++++++++-------- .../backends/vllm/src/dynamo/vllm/main.py | 33 ++++---- lib/bindings/python/rust/lib.rs | 63 +-------------- lib/bindings/python/rust/llm/model_card.rs | 44 +---------- lib/bindings/python/src/dynamo/_core.pyi | 2 +- lib/llm/src/local_model.rs | 29 ++----- lib/llm/src/model_card/create.rs | 2 - lib/llm/src/model_card/model.rs | 16 ---- 9 files changed, 80 insertions(+), 192 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a5e4b8e2d3..0c2c01a457 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3823,9 +3823,9 @@ checksum = "23fb14cb19457329c82206317a5663005a4d404783dc74f4252769b0d5f42856" [[package]] name = "llama-cpp-2" -version = "0.1.108" +version = "0.1.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1881e45e7306b2d2fdb2b322619ce1dba7a4873a4d358f815976d7b4540952b" +checksum = "c149c78a04d5b733f610388641dfe84ca80806c3fdb153980df20f83511d693e" dependencies = [ "enumflags2", "llama-cpp-sys-2", diff --git a/components/backends/sglang/src/dynamo/sglang/worker/main.py b/components/backends/sglang/src/dynamo/sglang/worker/main.py index 524f217440..45c3ba8889 100644 --- a/components/backends/sglang/src/dynamo/sglang/worker/main.py +++ b/components/backends/sglang/src/dynamo/sglang/worker/main.py @@ -25,7 +25,6 @@ ZmqKvEventPublisher, ZmqKvEventPublisherConfig, register_llm, - register_runtime_config, ) from dynamo.llm.model_card import ModelRuntimeConfig from dynamo.runtime import DistributedRuntime, dynamo_worker @@ -336,13 +335,8 @@ async def init( await component.create_service() endpoint = component.endpoint("generate") - await register_llm( - ModelType.Backend, - endpoint, - server_args.model_path, - server_args.served_model_name, - kv_cache_block_size=server_args.page_size, - migration_limit=migration_limit, + await register_llm_with_runtime_config( + engine, endpoint, server_args, migration_limit ) if server_args.disaggregation_mode != "null": @@ -367,57 +361,78 @@ async def init( _ = ZmqKvEventPublisher(component=component, config=zmq_config) tasks = [endpoint.serve_endpoint(handler.generate)] - tasks.append( - register_runtime_config_once_engine_ready( - endpoint, engine, server_args.served_model_name - ) - ) tasks.extend(setup_native_endpoints(server_args, component, handler)) await asyncio.gather(*tasks) -async def register_runtime_config_once_engine_ready( - endpoint: Endpoint, engine: sgl.Engine, model_name: str +async def register_llm_with_runtime_config( + engine: sgl.Engine, + endpoint: Endpoint, + server_args: ServerArgs, + migration_limit: int, ): - max_retries = 3 - retry_delay = 2 + """Register LLM with runtime config""" + runtime_config = await _get_runtime_config(engine) + try: + await register_llm( + ModelType.Backend, + endpoint, + server_args.model_path, + server_args.served_model_name, + kv_cache_block_size=server_args.page_size, + migration_limit=migration_limit, + runtime_config=runtime_config, + ) + except Exception as e: + logging.error(f"Failed to register with runtime config: {e}") + return None + + +async def _get_runtime_config(engine: sgl.Engine) -> Optional[ModelRuntimeConfig]: + """Get runtime config from SGLang engine""" + MAX_RETRIES = 3 + RETRY_DELAY = 2 - for attempt in range(max_retries): + for attempt in range(MAX_RETRIES): try: server_info = engine.get_server_info() + if not server_info: + logging.warning("No server info from SGLang engine") + return None runtime_config = ModelRuntimeConfig() - - # Use actual computed values from SGLang engine if server_info.get("max_total_num_tokens") is not None: runtime_config.total_kv_blocks = server_info["max_total_num_tokens"] + logging.info( + f"Set model runtime config total KV blocks: {runtime_config.total_kv_blocks}" + ) if server_info.get("max_running_requests") is not None: runtime_config.max_num_seqs = server_info["max_running_requests"] + logging.info( + f"Set model runtime config max num seqs: {runtime_config.max_num_seqs}" + ) if server_info.get("mem_fraction_static") is not None: gpu_mem_percentage = int(server_info["mem_fraction_static"] * 100) runtime_config.gpu_memory_utilization = gpu_mem_percentage + logging.info( + f"Set model runtime config GPU memory utilization: {gpu_mem_percentage}%" + ) - # Register the runtime config - await register_runtime_config(endpoint, model_name, runtime_config) - - logging.info( - f"Published runtime config for SGLang decode worker: {model_name}" - ) - return # Successfully published runtime config, exit loop + return runtime_config except Exception as e: logging.warning( - f"Attempt {attempt + 1}/{max_retries} failed to publish runtime config: {e}" + f"Attempt {attempt + 1}/{MAX_RETRIES} failed to publish runtime config: {e}" ) - if attempt < max_retries - 1: - await asyncio.sleep(retry_delay) - retry_delay *= 2 + if attempt < MAX_RETRIES - 1: + await asyncio.sleep(RETRY_DELAY) + RETRY_DELAY *= 2 else: logging.error( - f"Failed to publish runtime config after {max_retries} attempts" + f"Failed to publish runtime config after {MAX_RETRIES} attempts" ) diff --git a/components/backends/vllm/src/dynamo/vllm/main.py b/components/backends/vllm/src/dynamo/vllm/main.py index 7e4c50b85f..007367e045 100644 --- a/components/backends/vllm/src/dynamo/vllm/main.py +++ b/components/backends/vllm/src/dynamo/vllm/main.py @@ -16,7 +16,6 @@ ZmqKvEventPublisher, ZmqKvEventPublisherConfig, register_llm, - register_runtime_config, ) from dynamo.llm.model_card import ModelRuntimeConfig from dynamo.runtime import DistributedRuntime, dynamo_worker @@ -195,21 +194,6 @@ async def init(runtime: DistributedRuntime, config: Config): component, engine_client, default_sampling_params, prefill_worker_client ) - # Create a runtime config and register it with the MDC - if not config.engine_args.data_parallel_rank: - runtime_config = ModelRuntimeConfig() - # NOTE: This number needs to be queried directly from the engine, - # since this will compute it if no value was set by the user - runtime_config.total_kv_blocks = engine_client.engine.cache_config.num_gpu_blocks - runtime_config.max_num_seqs = engine_client.vllm_config.scheduler_config.max_num_seqs - - gpu_mem_integer = int( - engine_client.engine.cache_config.gpu_memory_utilization * 100 - ) - runtime_config.gpu_memory_utilization = gpu_mem_integer - - await register_runtime_config(generate_endpoint, config.model, runtime_config) - if config.engine_args.enable_prefix_caching: # TODO: We start off with a valid endpoint, then we increment it by dp_rank # May no longer be valid. Lets remove the increment behavior from vLLM and here @@ -230,6 +214,22 @@ async def init(runtime: DistributedRuntime, config: Config): handler.kv_publisher = kv_publisher if not config.engine_args.data_parallel_rank: # if rank is 0 or None then register + runtime_config = ModelRuntimeConfig() + + # NOTE: These values need to be queried directly from the engine, + # since this will compute it if no value was set by the user + runtime_config.total_kv_blocks = ( + engine_client.engine.cache_config.num_gpu_blocks + ) + runtime_config.max_num_seqs = ( + engine_client.vllm_config.scheduler_config.max_num_seqs + ) + + gpu_mem_integer = int( + engine_client.engine.cache_config.gpu_memory_utilization * 100 + ) + runtime_config.gpu_memory_utilization = gpu_mem_integer + await register_llm( ModelType.Backend, generate_endpoint, @@ -237,6 +237,7 @@ async def init(runtime: DistributedRuntime, config: Config): config.served_model_name, kv_cache_block_size=config.engine_args.block_size, migration_limit=config.migration_limit, + runtime_config=runtime_config, ) try: diff --git a/lib/bindings/python/rust/lib.rs b/lib/bindings/python/rust/lib.rs index 27c66f0e54..27b1e8726e 100644 --- a/lib/bindings/python/rust/lib.rs +++ b/lib/bindings/python/rust/lib.rs @@ -19,12 +19,10 @@ use dynamo_runtime::{ network::egress::push_router::RouterMode as RsRouterMode, EngineStream, ManyOut, SingleIn, }, protocols::annotated::Annotated as RsAnnotated, - slug::Slug, - storage::key_value_store::{EtcdStorage, KeyValueStoreManager}, traits::DistributedRuntimeProvider, }; -use dynamo_llm::{self as llm_rs, model_card}; +use dynamo_llm::{self as llm_rs}; use dynamo_llm::{entrypoint::RouterConfig, kv_router::KvRouterConfig}; use crate::llm::model_card::ModelRuntimeConfig; @@ -67,7 +65,6 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(llm::kv::compute_block_hash_for_seq_py, m)?)?; m.add_function(wrap_pyfunction!(log_message, m)?)?; m.add_function(wrap_pyfunction!(register_llm, m)?)?; - m.add_function(wrap_pyfunction!(register_runtime_config, m)?)?; m.add_function(wrap_pyfunction!(llm::entrypoint::make_engine, m)?)?; m.add_function(wrap_pyfunction!(llm::entrypoint::run_input, m)?)?; @@ -137,7 +134,7 @@ fn log_message(level: &str, message: &str, module: &str, file: &str, line: u32) } #[pyfunction] -#[pyo3(signature = (model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None, migration_limit=0, user_data=None))] +#[pyo3(signature = (model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None, migration_limit=0, runtime_config=None, user_data=None))] #[allow(clippy::too_many_arguments)] fn register_llm<'p>( py: Python<'p>, @@ -149,6 +146,7 @@ fn register_llm<'p>( kv_cache_block_size: Option, router_mode: Option, migration_limit: u32, + runtime_config: Option, user_data: Option<&Bound<'p, PyDict>>, ) -> PyResult> { let model_type_obj = match model_type { @@ -179,6 +177,7 @@ fn register_llm<'p>( .kv_cache_block_size(kv_cache_block_size) .router_config(Some(router_config)) .migration_limit(Some(migration_limit)) + .runtime_config(runtime_config.unwrap_or_default().inner) .user_data(user_data_json); // Download from HF, load the ModelDeploymentCard let mut local_model = builder.build().await.map_err(to_pyerr)?; @@ -192,60 +191,6 @@ fn register_llm<'p>( }) } -#[pyfunction] -fn register_runtime_config<'p>( - py: Python<'p>, - endpoint: Endpoint, - model_identifier: &str, - runtime_config: ModelRuntimeConfig, -) -> PyResult> { - let model_identifier = model_identifier.to_string(); - let runtime_config = runtime_config.inner.clone(); - let endpoint = endpoint.inner.clone(); - - pyo3_async_runtimes::tokio::future_into_py(py, async move { - // Get etcd client from endpoint - let Some(etcd_client) = endpoint.drt().etcd_client() else { - return Err(anyhow::anyhow!( - "Cannot update runtime config on static endpoint" - )) - .map_err(to_pyerr); - }; - - // Create storage manager - let kvstore = EtcdStorage::new(etcd_client.clone()); - let card_store = KeyValueStoreManager::new(Box::new(kvstore)); - - // Generate the model slug - this should match what register_llm used - // The register_llm function uses the model_name (or model_path if no model_name) - // and then calls card.set_name() which sets both display_name and service_name - let model_slug = Slug::slugify(&model_identifier); - - // Get existing card - let mut card = card_store - .load::(model_card::ROOT_PATH, &model_slug) - .await - .map_err(to_pyerr)? - .ok_or_else(|| anyhow::anyhow!("Cannot find model card")) - .map_err(to_pyerr)?; - - // Update the card - card.runtime_config = Some(runtime_config); - - // Publish the card - card_store - .publish(model_card::ROOT_PATH, None, model_slug.as_ref(), &mut card) - .await - .map_err(to_pyerr)?; - - tracing::info!( - "Successfully updated runtime config for model card: {}", - model_slug - ); - Ok(()) - }) -} - #[pyclass] #[derive(Clone)] struct EtcdKvCache { diff --git a/lib/bindings/python/rust/llm/model_card.rs b/lib/bindings/python/rust/llm/model_card.rs index 16da8a0afe..6703cc9318 100644 --- a/lib/bindings/python/rust/llm/model_card.rs +++ b/lib/bindings/python/rust/llm/model_card.rs @@ -47,52 +47,10 @@ impl ModelDeploymentCard { let json = self.inner.to_json().map_err(to_pyerr)?; Ok(json) } - - #[setter] - fn register_runtime_config(&mut self, runtime_config: ModelRuntimeConfig) { - self.inner.runtime_config = Some(runtime_config.inner); - } - - #[getter] - fn runtime_config(&self) -> Option { - self.inner - .runtime_config - .as_ref() - .map(|config| ModelRuntimeConfig { - inner: config.clone(), - }) - } - - #[getter] - fn total_kv_blocks(&self) -> Option { - self.inner - .runtime_config - .as_ref() - .map(|config| config.total_kv_blocks) - .flatten() - } - - #[getter] - fn max_num_seqs(&self) -> Option { - self.inner - .runtime_config - .as_ref() - .map(|config| config.max_num_seqs) - .flatten() - } - - #[getter] - fn gpu_memory_utilization(&self) -> Option { - self.inner - .runtime_config - .as_ref() - .map(|config| config.gpu_memory_utilization) - .flatten() - } } #[pyclass] -#[derive(Clone)] +#[derive(Clone, Default)] pub struct ModelRuntimeConfig { pub(crate) inner: RsModelRuntimeConfig, } diff --git a/lib/bindings/python/src/dynamo/_core.pyi b/lib/bindings/python/src/dynamo/_core.pyi index fee96a86fd..06119ae718 100644 --- a/lib/bindings/python/src/dynamo/_core.pyi +++ b/lib/bindings/python/src/dynamo/_core.pyi @@ -843,7 +843,7 @@ async def register_llm(model_type: ModelType, endpoint: Endpoint, model_path: st """Attach the model at path to the given endpoint, and advertise it as model_type""" ... -async def register_runtime_config(endpoint: Endpoint, model_identifier: str, runtime_config: ModelRuntimeConfig) -> None: +async def register_runtime_config(endpoint: Endpoint, runtime_config: ModelRuntimeConfig) -> None: """Register runtime configuration with the model card""" ... diff --git a/lib/llm/src/local_model.rs b/lib/llm/src/local_model.rs index 48231ccc56..fe501c74ee 100644 --- a/lib/llm/src/local_model.rs +++ b/lib/llm/src/local_model.rs @@ -16,8 +16,7 @@ use dynamo_runtime::{ use crate::discovery::ModelEntry; use crate::entrypoint::RouterConfig; -use crate::model_card::runtime_config::ModelRuntimeConfig; -use crate::model_card::{self, ModelDeploymentCard}; +use crate::model_card::{self, ModelDeploymentCard, ModelRuntimeConfig}; use crate::model_type::ModelType; use crate::request_template::RequestTemplate; @@ -178,6 +177,7 @@ impl LocalModelBuilder { template, http_port: self.http_port, router_config: self.router_config.take().unwrap_or_default(), + runtime_config: self.runtime_config.clone(), }); } @@ -236,6 +236,7 @@ impl LocalModelBuilder { template, http_port: self.http_port, router_config: self.router_config.take().unwrap_or_default(), + runtime_config: self.runtime_config.clone(), }) } } @@ -248,6 +249,7 @@ pub struct LocalModel { template: Option, http_port: u16, // Only used if input is HTTP server router_config: RouterConfig, + runtime_config: ModelRuntimeConfig, } impl LocalModel { @@ -279,6 +281,10 @@ impl LocalModel { &self.router_config } + pub fn runtime_config(&self) -> &ModelRuntimeConfig { + &self.runtime_config + } + pub fn is_gguf(&self) -> bool { // GGUF is the only file (not-folder) we accept, so we don't need to check the extension // We will error when we come to parse it @@ -340,25 +346,6 @@ impl LocalModel { .await } - pub async fn register_runtime_config( - &mut self, - endpoint: &Endpoint, - runtime_config: ModelRuntimeConfig, - ) -> anyhow::Result<()> { - self.card.runtime_config = Some(runtime_config); - - // Update the card in etcd if we're attached - if let Some(etcd_client) = endpoint.drt().etcd_client() { - let kvstore: Box = Box::new(EtcdStorage::new(etcd_client.clone())); - let card_store = Arc::new(KeyValueStoreManager::new(kvstore)); - let key = self.card.slug().to_string(); - card_store - .publish(model_card::ROOT_PATH, None, &key, &mut self.card) - .await?; - } - Ok(()) - } - /// Ensure that each component serves only one model. /// We can have multiple instances of the same model running using the same component name /// (they get load balanced, and are differentiated in etcd by their lease_id). diff --git a/lib/llm/src/model_card/create.rs b/lib/llm/src/model_card/create.rs index f3c0e27555..237968c17c 100644 --- a/lib/llm/src/model_card/create.rs +++ b/lib/llm/src/model_card/create.rs @@ -93,7 +93,6 @@ impl ModelDeploymentCard { context_length, kv_cache_block_size: 0, migration_limit: 0, - runtime_config: None, user_data: None, }) } @@ -135,7 +134,6 @@ impl ModelDeploymentCard { context_length, kv_cache_block_size: 0, // set later migration_limit: 0, - runtime_config: None, user_data: None, }) } diff --git a/lib/llm/src/model_card/model.rs b/lib/llm/src/model_card/model.rs index af14249ac5..eb42949811 100644 --- a/lib/llm/src/model_card/model.rs +++ b/lib/llm/src/model_card/model.rs @@ -27,7 +27,6 @@ use tokenizers::Tokenizer as HfTokenizer; use url::Url; use crate::gguf::{Content, ContentConfig, ModelConfigLike}; -use crate::model_card::runtime_config::ModelRuntimeConfig; use crate::protocols::TokenIdType; /// If a model deployment card hasn't been refreshed in this much time the worker is likely gone @@ -133,11 +132,6 @@ pub struct ModelDeploymentCard { /// connection to the current worker. pub migration_limit: u32, - /// Runtime configuration data that is computed during initialization - /// but remains constant during the worker session - #[serde(default, skip_serializing_if = "Option::is_none")] - pub runtime_config: Option, - /// User-defined metadata for custom worker behavior #[serde(default, skip_serializing_if = "Option::is_none")] pub user_data: Option, @@ -244,16 +238,6 @@ impl ModelDeploymentCard { } } - /// Update an existing runtime config or create a new one if none exists - pub fn update_runtime_config(&mut self, updater: F) - where - F: FnOnce(&mut ModelRuntimeConfig), - { - let mut config = self.runtime_config.take().unwrap_or_default(); - updater(&mut config); - self.runtime_config = Some(config); - } - /// Move the files this MDC uses into the NATS object store. /// Updates the URI's to point to NATS. pub async fn move_to_nats(&mut self, nats_client: nats::Client) -> Result<()> { From b376cfbe5ff0cb069505f0a86bef92618c2672fa Mon Sep 17 00:00:00 2001 From: Jorge Antonio Date: Mon, 11 Aug 2025 14:48:28 +0100 Subject: [PATCH 06/31] resolve _core.py import issues --- .../backends/vllm/src/dynamo/vllm/main.py | 39 +++++++++++++------ lib/bindings/python/src/dynamo/_core.pyi | 4 -- .../python/src/dynamo/llm/__init__.py | 1 - 3 files changed, 27 insertions(+), 17 deletions(-) diff --git a/components/backends/vllm/src/dynamo/vllm/main.py b/components/backends/vllm/src/dynamo/vllm/main.py index 007367e045..e3fff655f3 100644 --- a/components/backends/vllm/src/dynamo/vllm/main.py +++ b/components/backends/vllm/src/dynamo/vllm/main.py @@ -216,18 +216,11 @@ async def init(runtime: DistributedRuntime, config: Config): if not config.engine_args.data_parallel_rank: # if rank is 0 or None then register runtime_config = ModelRuntimeConfig() - # NOTE: These values need to be queried directly from the engine, - # since this will compute it if no value was set by the user - runtime_config.total_kv_blocks = ( - engine_client.engine.cache_config.num_gpu_blocks - ) - runtime_config.max_num_seqs = ( - engine_client.vllm_config.scheduler_config.max_num_seqs - ) - - gpu_mem_integer = int( - engine_client.engine.cache_config.gpu_memory_utilization * 100 - ) + # make a `collective_rpc` call to get runtime configuration values + runtime_values = await get_engine_cache_info(engine_client) + runtime_config.total_kv_blocks = runtime_values["num_gpu_blocks"] + runtime_config.max_num_seqs = runtime_values["max_num_seqs"] + gpu_mem_integer = runtime_values["gpu_memory_utilization"] runtime_config.gpu_memory_utilization = gpu_mem_integer await register_llm( @@ -255,6 +248,28 @@ async def init(runtime: DistributedRuntime, config: Config): handler.cleanup() +async def get_engine_cache_info(engine: AsyncLLM): + """Retrieve cache configuration information from [`AsyncLLM`] engine.""" + cache_values = await engine.collective_rpc( + lambda worker: { + "num_gpu_blocks": worker.cache_config.num_gpu_blocks, + "gpu_memory_utilization": worker.cache_config.gpu_memory_utilization, + } + ) + + scheduler_values = await engine.collective_rpc( + lambda worker: { + "max_num_seqs": worker.scheduler_config.max_num_seqs, + } + ) + + return { + "num_gpu_blocks": cache_values["num_gpu_blocks"], + "max_num_seqs": scheduler_values["max_num_seqs"], + "gpu_memory_utilization": cache_values["gpu_memory_utilization"], + } + + def main(): uvloop.run(worker()) diff --git a/lib/bindings/python/src/dynamo/_core.pyi b/lib/bindings/python/src/dynamo/_core.pyi index 06119ae718..479b20552c 100644 --- a/lib/bindings/python/src/dynamo/_core.pyi +++ b/lib/bindings/python/src/dynamo/_core.pyi @@ -843,10 +843,6 @@ async def register_llm(model_type: ModelType, endpoint: Endpoint, model_path: st """Attach the model at path to the given endpoint, and advertise it as model_type""" ... -async def register_runtime_config(endpoint: Endpoint, runtime_config: ModelRuntimeConfig) -> None: - """Register runtime configuration with the model card""" - ... - class EngineConfig: """Holds internal configuration for a Dynamo engine.""" ... diff --git a/lib/bindings/python/src/dynamo/llm/__init__.py b/lib/bindings/python/src/dynamo/llm/__init__.py index 6c9443a97a..053fe4c69c 100644 --- a/lib/bindings/python/src/dynamo/llm/__init__.py +++ b/lib/bindings/python/src/dynamo/llm/__init__.py @@ -40,5 +40,4 @@ from dynamo._core import compute_block_hash_for_seq_py as compute_block_hash_for_seq_py from dynamo._core import make_engine from dynamo._core import register_llm as register_llm -from dynamo._core import register_runtime_config as register_runtime_config from dynamo._core import run_input From 9d3cbb110101f4f0ef355d8399847565156916a4 Mon Sep 17 00:00:00 2001 From: Jorge Antonio Date: Mon, 11 Aug 2025 16:43:31 +0100 Subject: [PATCH 07/31] resolve runtime issues --- .../sglang/src/dynamo/sglang/worker/main.py | 2 + .../backends/vllm/src/dynamo/vllm/main.py | 47 ++++++++++++------- 2 files changed, 33 insertions(+), 16 deletions(-) diff --git a/components/backends/sglang/src/dynamo/sglang/worker/main.py b/components/backends/sglang/src/dynamo/sglang/worker/main.py index 45c3ba8889..943e958303 100644 --- a/components/backends/sglang/src/dynamo/sglang/worker/main.py +++ b/components/backends/sglang/src/dynamo/sglang/worker/main.py @@ -435,6 +435,8 @@ async def _get_runtime_config(engine: sgl.Engine) -> Optional[ModelRuntimeConfig f"Failed to publish runtime config after {MAX_RETRIES} attempts" ) + return None + def main(): uvloop.install() diff --git a/components/backends/vllm/src/dynamo/vllm/main.py b/components/backends/vllm/src/dynamo/vllm/main.py index e3fff655f3..bdcb498f36 100644 --- a/components/backends/vllm/src/dynamo/vllm/main.py +++ b/components/backends/vllm/src/dynamo/vllm/main.py @@ -217,12 +217,19 @@ async def init(runtime: DistributedRuntime, config: Config): runtime_config = ModelRuntimeConfig() # make a `collective_rpc` call to get runtime configuration values + logging.info( + "Getting engine runtime configuration metadata from vLLM engine..." + ) runtime_values = await get_engine_cache_info(engine_client) runtime_config.total_kv_blocks = runtime_values["num_gpu_blocks"] runtime_config.max_num_seqs = runtime_values["max_num_seqs"] gpu_mem_integer = runtime_values["gpu_memory_utilization"] runtime_config.gpu_memory_utilization = gpu_mem_integer + logging.info( + f"Registering model {config.model} with runtime config: {runtime_config}" + ) + await register_llm( ModelType.Backend, generate_endpoint, @@ -250,24 +257,32 @@ async def init(runtime: DistributedRuntime, config: Config): async def get_engine_cache_info(engine: AsyncLLM): """Retrieve cache configuration information from [`AsyncLLM`] engine.""" - cache_values = await engine.collective_rpc( - lambda worker: { - "num_gpu_blocks": worker.cache_config.num_gpu_blocks, - "gpu_memory_utilization": worker.cache_config.gpu_memory_utilization, - } - ) - scheduler_values = await engine.collective_rpc( - lambda worker: { - "max_num_seqs": worker.scheduler_config.max_num_seqs, - } - ) + try: + cache_values = await engine.collective_rpc( + lambda worker: { + "num_gpu_blocks": worker.cache_config.num_gpu_blocks, + "gpu_memory_utilization": worker.cache_config.gpu_memory_utilization, + } + ) - return { - "num_gpu_blocks": cache_values["num_gpu_blocks"], - "max_num_seqs": scheduler_values["max_num_seqs"], - "gpu_memory_utilization": cache_values["gpu_memory_utilization"], - } + scheduler_values = await engine.collective_rpc( + lambda worker: { + "max_num_seqs": worker.scheduler_config.max_num_seqs, + } + ) + logging.info(f"Collective RPC cache values: {cache_values}") + logging.info(f"Collective RPC scheduler values: {scheduler_values}") + return { + "num_gpu_blocks": cache_values["num_gpu_blocks"], + "max_num_seqs": scheduler_values["max_num_seqs"], + "gpu_memory_utilization": cache_values["gpu_memory_utilization"], + } + except Exception as e: + logging.error( + f"Failed to get collective RPC configuration values from vLLM engine: {e}" + ) + raise def main(): From d1b87f538ed093e39d99adc74832e33a41623f4c Mon Sep 17 00:00:00 2001 From: Jorge Antonio Date: Mon, 11 Aug 2025 19:48:25 +0100 Subject: [PATCH 08/31] resolve import issues --- components/backends/sglang/src/dynamo/sglang/worker/main.py | 2 +- components/backends/vllm/src/dynamo/vllm/main.py | 2 +- lib/bindings/python/src/dynamo/llm/__init__.py | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/components/backends/sglang/src/dynamo/sglang/worker/main.py b/components/backends/sglang/src/dynamo/sglang/worker/main.py index 943e958303..df31a58923 100644 --- a/components/backends/sglang/src/dynamo/sglang/worker/main.py +++ b/components/backends/sglang/src/dynamo/sglang/worker/main.py @@ -19,6 +19,7 @@ from dynamo.llm import ( ForwardPassMetrics, KvStats, + ModelRuntimeConfig, ModelType, WorkerMetricsPublisher, WorkerStats, @@ -26,7 +27,6 @@ ZmqKvEventPublisherConfig, register_llm, ) -from dynamo.llm.model_card import ModelRuntimeConfig from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime.logging import configure_dynamo_logging from dynamo.sglang.common import ( diff --git a/components/backends/vllm/src/dynamo/vllm/main.py b/components/backends/vllm/src/dynamo/vllm/main.py index bdcb498f36..69419b23db 100644 --- a/components/backends/vllm/src/dynamo/vllm/main.py +++ b/components/backends/vllm/src/dynamo/vllm/main.py @@ -12,12 +12,12 @@ from vllm.v1.engine.async_llm import AsyncLLM from dynamo.llm import ( + ModelRuntimeConfig, ModelType, ZmqKvEventPublisher, ZmqKvEventPublisherConfig, register_llm, ) -from dynamo.llm.model_card import ModelRuntimeConfig from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime.logging import configure_dynamo_logging diff --git a/lib/bindings/python/src/dynamo/llm/__init__.py b/lib/bindings/python/src/dynamo/llm/__init__.py index 053fe4c69c..80a698339b 100644 --- a/lib/bindings/python/src/dynamo/llm/__init__.py +++ b/lib/bindings/python/src/dynamo/llm/__init__.py @@ -26,6 +26,7 @@ from dynamo._core import KvRecorder as KvRecorder from dynamo._core import KvRouterConfig as KvRouterConfig from dynamo._core import KvStats as KvStats +from dynamo._core import ModelRuntimeConfig as ModelRuntimeConfig from dynamo._core import ModelType as ModelType from dynamo._core import OverlapScores as OverlapScores from dynamo._core import RadixTree as RadixTree From d18881b65986966e4dd89ee21f9bf3989495dc6d Mon Sep 17 00:00:00 2001 From: Jorge Antonio Date: Mon, 11 Aug 2025 21:10:16 +0100 Subject: [PATCH 09/31] resolve import issues --- .../backends/vllm/src/dynamo/vllm/main.py | 30 +++++++++---------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/components/backends/vllm/src/dynamo/vllm/main.py b/components/backends/vllm/src/dynamo/vllm/main.py index 69419b23db..16ffd150e3 100644 --- a/components/backends/vllm/src/dynamo/vllm/main.py +++ b/components/backends/vllm/src/dynamo/vllm/main.py @@ -12,12 +12,12 @@ from vllm.v1.engine.async_llm import AsyncLLM from dynamo.llm import ( - ModelRuntimeConfig, ModelType, ZmqKvEventPublisher, ZmqKvEventPublisherConfig, register_llm, ) +from dynamo.llm import ModelRuntimeConfig from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime.logging import configure_dynamo_logging @@ -259,20 +259,18 @@ async def get_engine_cache_info(engine: AsyncLLM): """Retrieve cache configuration information from [`AsyncLLM`] engine.""" try: - cache_values = await engine.collective_rpc( - lambda worker: { - "num_gpu_blocks": worker.cache_config.num_gpu_blocks, - "gpu_memory_utilization": worker.cache_config.gpu_memory_utilization, - } - ) - - scheduler_values = await engine.collective_rpc( - lambda worker: { - "max_num_seqs": worker.scheduler_config.max_num_seqs, - } - ) - logging.info(f"Collective RPC cache values: {cache_values}") - logging.info(f"Collective RPC scheduler values: {scheduler_values}") + # Get values directly from vllm_config instead of collective_rpc + cache_values = { + "num_gpu_blocks": engine.vllm_config.cache_config.num_gpu_blocks, + "gpu_memory_utilization": engine.vllm_config.cache_config.gpu_memory_utilization, + } + + scheduler_values = { + "max_num_seqs": engine.vllm_config.scheduler_config.max_num_seqs, + } + + logging.info(f"Cache config values: {cache_values}") + logging.info(f"Scheduler config values: {scheduler_values}") return { "num_gpu_blocks": cache_values["num_gpu_blocks"], "max_num_seqs": scheduler_values["max_num_seqs"], @@ -280,7 +278,7 @@ async def get_engine_cache_info(engine: AsyncLLM): } except Exception as e: logging.error( - f"Failed to get collective RPC configuration values from vLLM engine: {e}" + f"Failed to get configuration values from vLLM config: {e}" ) raise From 24712cb0c56d8337b983abe28b3c1fe5364872b2 Mon Sep 17 00:00:00 2001 From: Jorge Antonio Date: Mon, 11 Aug 2025 21:10:39 +0100 Subject: [PATCH 10/31] resolve vllm cache config issues --- components/backends/vllm/src/dynamo/vllm/main.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/components/backends/vllm/src/dynamo/vllm/main.py b/components/backends/vllm/src/dynamo/vllm/main.py index 16ffd150e3..7271ef434e 100644 --- a/components/backends/vllm/src/dynamo/vllm/main.py +++ b/components/backends/vllm/src/dynamo/vllm/main.py @@ -12,12 +12,12 @@ from vllm.v1.engine.async_llm import AsyncLLM from dynamo.llm import ( + ModelRuntimeConfig, ModelType, ZmqKvEventPublisher, ZmqKvEventPublisherConfig, register_llm, ) -from dynamo.llm import ModelRuntimeConfig from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime.logging import configure_dynamo_logging @@ -264,11 +264,11 @@ async def get_engine_cache_info(engine: AsyncLLM): "num_gpu_blocks": engine.vllm_config.cache_config.num_gpu_blocks, "gpu_memory_utilization": engine.vllm_config.cache_config.gpu_memory_utilization, } - + scheduler_values = { "max_num_seqs": engine.vllm_config.scheduler_config.max_num_seqs, } - + logging.info(f"Cache config values: {cache_values}") logging.info(f"Scheduler config values: {scheduler_values}") return { @@ -277,9 +277,7 @@ async def get_engine_cache_info(engine: AsyncLLM): "gpu_memory_utilization": cache_values["gpu_memory_utilization"], } except Exception as e: - logging.error( - f"Failed to get configuration values from vLLM config: {e}" - ) + logging.error(f"Failed to get configuration values from vLLM config: {e}") raise From c20f0e169b6722feb1fd27e6414c5da486cbd853 Mon Sep 17 00:00:00 2001 From: Jorge Antonio Date: Mon, 11 Aug 2025 21:13:43 +0100 Subject: [PATCH 11/31] resolve non-int gpu_mem_integer issue --- components/backends/vllm/src/dynamo/vllm/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/components/backends/vllm/src/dynamo/vllm/main.py b/components/backends/vllm/src/dynamo/vllm/main.py index 7271ef434e..6682690e4a 100644 --- a/components/backends/vllm/src/dynamo/vllm/main.py +++ b/components/backends/vllm/src/dynamo/vllm/main.py @@ -224,7 +224,7 @@ async def init(runtime: DistributedRuntime, config: Config): runtime_config.total_kv_blocks = runtime_values["num_gpu_blocks"] runtime_config.max_num_seqs = runtime_values["max_num_seqs"] gpu_mem_integer = runtime_values["gpu_memory_utilization"] - runtime_config.gpu_memory_utilization = gpu_mem_integer + runtime_config.gpu_memory_utilization = (gpu_mem_integer * 100) logging.info( f"Registering model {config.model} with runtime config: {runtime_config}" From af94e4b0c3553ffe22941c4e85ccfc548b8d658a Mon Sep 17 00:00:00 2001 From: Jorge Antonio Date: Mon, 11 Aug 2025 21:15:33 +0100 Subject: [PATCH 12/31] resolve non-int gpu_mem_integer issue --- components/backends/vllm/src/dynamo/vllm/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/components/backends/vllm/src/dynamo/vllm/main.py b/components/backends/vllm/src/dynamo/vllm/main.py index 6682690e4a..2d98b0bca9 100644 --- a/components/backends/vllm/src/dynamo/vllm/main.py +++ b/components/backends/vllm/src/dynamo/vllm/main.py @@ -224,7 +224,7 @@ async def init(runtime: DistributedRuntime, config: Config): runtime_config.total_kv_blocks = runtime_values["num_gpu_blocks"] runtime_config.max_num_seqs = runtime_values["max_num_seqs"] gpu_mem_integer = runtime_values["gpu_memory_utilization"] - runtime_config.gpu_memory_utilization = (gpu_mem_integer * 100) + runtime_config.gpu_memory_utilization = int(gpu_mem_integer * 100) logging.info( f"Registering model {config.model} with runtime config: {runtime_config}" From 57e12c2a127de21100a24ee5ad444ecf769ecac9 Mon Sep 17 00:00:00 2001 From: Jorge Antonio Date: Tue, 12 Aug 2025 22:52:19 +0100 Subject: [PATCH 13/31] remove uneeded async in python code --- components/backends/vllm/src/dynamo/vllm/main.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/components/backends/vllm/src/dynamo/vllm/main.py b/components/backends/vllm/src/dynamo/vllm/main.py index 2d98b0bca9..ca5f67698f 100644 --- a/components/backends/vllm/src/dynamo/vllm/main.py +++ b/components/backends/vllm/src/dynamo/vllm/main.py @@ -220,16 +220,12 @@ async def init(runtime: DistributedRuntime, config: Config): logging.info( "Getting engine runtime configuration metadata from vLLM engine..." ) - runtime_values = await get_engine_cache_info(engine_client) + runtime_values = get_engine_cache_info(engine_client) runtime_config.total_kv_blocks = runtime_values["num_gpu_blocks"] runtime_config.max_num_seqs = runtime_values["max_num_seqs"] gpu_mem_integer = runtime_values["gpu_memory_utilization"] runtime_config.gpu_memory_utilization = int(gpu_mem_integer * 100) - logging.info( - f"Registering model {config.model} with runtime config: {runtime_config}" - ) - await register_llm( ModelType.Backend, generate_endpoint, @@ -255,7 +251,7 @@ async def init(runtime: DistributedRuntime, config: Config): handler.cleanup() -async def get_engine_cache_info(engine: AsyncLLM): +def get_engine_cache_info(engine: AsyncLLM): """Retrieve cache configuration information from [`AsyncLLM`] engine.""" try: From d4b1edfb7a74a2412b96033cf118f5093bcfcef9 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Tue, 12 Aug 2025 16:20:38 -0700 Subject: [PATCH 14/31] revert llama-cpp version in Cargo.lock --- Cargo.lock | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 028cfd8cb4..0cc1f5e248 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3838,9 +3838,9 @@ checksum = "23fb14cb19457329c82206317a5663005a4d404783dc74f4252769b0d5f42856" [[package]] name = "llama-cpp-2" -version = "0.1.114" +version = "0.1.108" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c149c78a04d5b733f610388641dfe84ca80806c3fdb153980df20f83511d693e" +checksum = "a1881e45e7306b2d2fdb2b322619ce1dba7a4873a4d358f815976d7b4540952b" dependencies = [ "enumflags2", "llama-cpp-sys-2", From b7ca2f502c86623cac6a59ca4634b989e5c8b3ff Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Tue, 12 Aug 2025 16:27:00 -0700 Subject: [PATCH 15/31] move runtime config into local_model --- lib/bindings/python/rust/llm/model_card.rs | 2 +- lib/llm/src/lib.rs | 1 - lib/llm/src/local_model.rs | 4 +++- lib/llm/src/{ => local_model}/runtime_config.rs | 0 4 files changed, 4 insertions(+), 3 deletions(-) rename lib/llm/src/{ => local_model}/runtime_config.rs (100%) diff --git a/lib/bindings/python/rust/llm/model_card.rs b/lib/bindings/python/rust/llm/model_card.rs index a789cb0927..ef83c4b98c 100644 --- a/lib/bindings/python/rust/llm/model_card.rs +++ b/lib/bindings/python/rust/llm/model_card.rs @@ -3,7 +3,7 @@ use super::*; use llm_rs::model_card::ModelDeploymentCard as RsModelDeploymentCard; -use llm_rs::runtime_config::ModelRuntimeConfig as RsModelRuntimeConfig; +use llm_rs::local_model::runtime_config::ModelRuntimeConfig as RsModelRuntimeConfig; #[pyclass] #[derive(Clone)] diff --git a/lib/llm/src/lib.rs b/lib/llm/src/lib.rs index d5a45ebd13..a65e8f8efa 100644 --- a/lib/llm/src/lib.rs +++ b/lib/llm/src/lib.rs @@ -31,7 +31,6 @@ pub mod preprocessor; pub mod protocols; pub mod recorder; pub mod request_template; -pub mod runtime_config; pub mod tokenizers; pub mod tokens; pub mod types; diff --git a/lib/llm/src/local_model.rs b/lib/llm/src/local_model.rs index 5ace4672b6..4063329f56 100644 --- a/lib/llm/src/local_model.rs +++ b/lib/llm/src/local_model.rs @@ -19,10 +19,12 @@ use crate::entrypoint::RouterConfig; use crate::model_card::{self, ModelDeploymentCard}; use crate::model_type::ModelType; use crate::request_template::RequestTemplate; -use crate::runtime_config::ModelRuntimeConfig; mod network_name; pub use network_name::ModelNetworkName; +pub mod runtime_config; + +use runtime_config::ModelRuntimeConfig; /// Prefix for Hugging Face model repository const HF_SCHEME: &str = "hf://"; diff --git a/lib/llm/src/runtime_config.rs b/lib/llm/src/local_model/runtime_config.rs similarity index 100% rename from lib/llm/src/runtime_config.rs rename to lib/llm/src/local_model/runtime_config.rs From becb754ee94b80130c1e36b80ea68a3864699b14 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Tue, 12 Aug 2025 16:48:24 -0700 Subject: [PATCH 16/31] put runtime config in ModelEntry so it gets registered to etcd --- lib/llm/src/discovery/model_entry.rs | 5 +++++ lib/llm/src/local_model.rs | 1 + lib/llm/src/local_model/runtime_config.rs | 2 +- 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/lib/llm/src/discovery/model_entry.rs b/lib/llm/src/discovery/model_entry.rs index 8801598026..6b4c1c0be8 100644 --- a/lib/llm/src/discovery/model_entry.rs +++ b/lib/llm/src/discovery/model_entry.rs @@ -12,6 +12,7 @@ use dynamo_runtime::{ use serde::{Deserialize, Serialize}; use crate::{ + local_model::runtime_config::ModelRuntimeConfig, model_card::{self, ModelDeploymentCard}, model_type::ModelType, }; @@ -28,6 +29,10 @@ pub struct ModelEntry { /// Specifies whether the model is a chat, completions, etc model. pub model_type: ModelType, + + /// Runtime configuration specific to this model instance + #[serde(default, skip_serializing_if = "Option::is_none")] + pub runtime_config: Option, } impl ModelEntry { diff --git a/lib/llm/src/local_model.rs b/lib/llm/src/local_model.rs index 4063329f56..59f1d1ce62 100644 --- a/lib/llm/src/local_model.rs +++ b/lib/llm/src/local_model.rs @@ -342,6 +342,7 @@ impl LocalModel { name: self.display_name().to_string(), endpoint: endpoint.id(), model_type, + runtime_config: Some(self.runtime_config.clone()), }; etcd_client .kv_create( diff --git a/lib/llm/src/local_model/runtime_config.rs b/lib/llm/src/local_model/runtime_config.rs index fd6e406a02..f83f5d6335 100644 --- a/lib/llm/src/local_model/runtime_config.rs +++ b/lib/llm/src/local_model/runtime_config.rs @@ -5,7 +5,7 @@ use std::collections::HashMap; use serde::{de::DeserializeOwned, Deserialize, Serialize}; -#[derive(Debug, Default, Clone, Serialize, Deserialize)] +#[derive(Debug, Default, Clone, Serialize, Deserialize, Eq, PartialEq)] pub struct ModelRuntimeConfig { pub total_kv_blocks: Option, From 5adaeb1ac0be487648d2a6feedc7cc3f3ac7314d Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Tue, 12 Aug 2025 16:49:15 -0700 Subject: [PATCH 17/31] fmt --- lib/bindings/python/rust/llm/model_card.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/bindings/python/rust/llm/model_card.rs b/lib/bindings/python/rust/llm/model_card.rs index ef83c4b98c..d6f76d06db 100644 --- a/lib/bindings/python/rust/llm/model_card.rs +++ b/lib/bindings/python/rust/llm/model_card.rs @@ -2,8 +2,8 @@ // SPDX-License-Identifier: Apache-2.0 use super::*; -use llm_rs::model_card::ModelDeploymentCard as RsModelDeploymentCard; use llm_rs::local_model::runtime_config::ModelRuntimeConfig as RsModelRuntimeConfig; +use llm_rs::model_card::ModelDeploymentCard as RsModelDeploymentCard; #[pyclass] #[derive(Clone)] From 950e6a4f8529b099bd92e77659763c565043ebf4 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Tue, 12 Aug 2025 17:27:08 -0700 Subject: [PATCH 18/31] if mocker, override runtime configs --- lib/bindings/python/rust/llm/entrypoint.rs | 3 ++- lib/llm/src/local_model.rs | 20 ++++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/lib/bindings/python/rust/llm/entrypoint.rs b/lib/bindings/python/rust/llm/entrypoint.rs index c2843bac01..dc59806676 100644 --- a/lib/bindings/python/rust/llm/entrypoint.rs +++ b/lib/bindings/python/rust/llm/entrypoint.rs @@ -164,7 +164,8 @@ pub fn make_engine<'p>( .kv_cache_block_size(args.kv_cache_block_size) .router_config(args.router_config.clone().map(|rc| rc.into())) .http_port(args.http_port) - .is_mocker(matches!(args.engine_type, EngineType::Mocker)); + .is_mocker(matches!(args.engine_type, EngineType::Mocker)) + .extra_engine_args(args.extra_engine_args.clone()); pyo3_async_runtimes::tokio::future_into_py(py, async move { let local_model = builder.build().await.map_err(to_pyerr)?; let inner = select_engine(distributed_runtime, args, local_model) diff --git a/lib/llm/src/local_model.rs b/lib/llm/src/local_model.rs index 59f1d1ce62..5884d5eee8 100644 --- a/lib/llm/src/local_model.rs +++ b/lib/llm/src/local_model.rs @@ -16,6 +16,7 @@ use dynamo_runtime::{ use crate::discovery::ModelEntry; use crate::entrypoint::RouterConfig; +use crate::mocker::protocols::MockEngineArgs; use crate::model_card::{self, ModelDeploymentCard}; use crate::model_type::ModelType; use crate::request_template::RequestTemplate; @@ -51,6 +52,7 @@ pub struct LocalModelBuilder { http_port: u16, migration_limit: u32, is_mocker: bool, + extra_engine_args: Option, runtime_config: ModelRuntimeConfig, user_data: Option, } @@ -69,6 +71,7 @@ impl Default for LocalModelBuilder { router_config: Default::default(), migration_limit: Default::default(), is_mocker: Default::default(), + extra_engine_args: Default::default(), runtime_config: Default::default(), user_data: Default::default(), } @@ -133,6 +136,11 @@ impl LocalModelBuilder { self } + pub fn extra_engine_args(&mut self, extra_engine_args: Option) -> &mut Self { + self.extra_engine_args = extra_engine_args; + self + } + pub fn runtime_config(&mut self, runtime_config: ModelRuntimeConfig) -> &mut Self { self.runtime_config = runtime_config; self @@ -229,6 +237,18 @@ impl LocalModelBuilder { card.context_length = context_length; } + // Override runtime configs with mocker engine args + if self.is_mocker { + if let Some(path) = &self.extra_engine_args { + let mocker_engine_args = MockEngineArgs::from_json_file(path) + .expect("Failed to load mocker engine args for runtime config overriding."); + self.runtime_config.total_kv_blocks = + Some(mocker_engine_args.num_gpu_blocks as u64); + self.runtime_config.max_num_seqs = + mocker_engine_args.max_num_seqs.map(|v| v as u64); + } + } + card.migration_limit = self.migration_limit; card.user_data = self.user_data.take(); From cbbd03b6edd535ae12ebf521c5cb550bc244a3a3 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Tue, 12 Aug 2025 18:28:12 -0700 Subject: [PATCH 19/31] router listens to runtime configs (kv total blocks) --- components/router/src/main.rs | 5 +- lib/llm/src/kv_router.rs | 17 ++++- lib/llm/src/kv_router/metrics_aggregator.rs | 72 +++++++++++++++++++ lib/llm/src/kv_router/scheduler.rs | 78 ++++++++++++++++----- 4 files changed, 150 insertions(+), 22 deletions(-) diff --git a/components/router/src/main.rs b/components/router/src/main.rs index 1ee7fb64fd..8961a5fac9 100644 --- a/components/router/src/main.rs +++ b/components/router/src/main.rs @@ -20,6 +20,7 @@ // 2. Update the backend component to produce a config in a standard location. // 3. Update the KvRouter to read the config from the backend component. +use std::collections::HashMap; use std::sync::Arc; use clap::Parser; @@ -29,7 +30,7 @@ use dynamo_llm::kv_router::{ scheduler::{DefaultWorkerSelector, KvSchedulerError, SchedulingRequest}, KvRouter, WorkerSelector, }; -use dynamo_runtime::component::Instance; +use dynamo_llm::local_model::runtime_config::ModelRuntimeConfig; use dynamo_runtime::{ logging, pipeline::network::Ingress, DistributedRuntime, Result, Runtime, Worker, }; @@ -86,7 +87,7 @@ pub struct CustomWorkerSelector(DefaultWorkerSelector); impl WorkerSelector for CustomWorkerSelector { fn select_worker( &self, - workers: &[Instance], + workers: &HashMap>, request: &SchedulingRequest, block_size: u32, ) -> Result { diff --git a/lib/llm/src/kv_router.rs b/lib/llm/src/kv_router.rs index 1e488872e1..4bee158ea9 100644 --- a/lib/llm/src/kv_router.rs +++ b/lib/llm/src/kv_router.rs @@ -1,6 +1,7 @@ // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 +use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; @@ -34,16 +35,16 @@ use crate::{ compute_block_hash_for_seq, compute_seq_hash_for_block, KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent, }, - // metrics_aggregator::EndpointCollector, + metrics_aggregator::watch_model_runtime_configs, protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult}, scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest}, scoring::ProcessedEndpoints, }, + local_model::runtime_config::ModelRuntimeConfig, preprocessor::PreprocessedRequest, protocols::common::llm_backend::LLMEngineOutput, }; -use dynamo_runtime::component::Instance; use dynamo_runtime::traits::events::EventSubscriber; // [gluo TODO] shouldn't need to be public @@ -65,7 +66,7 @@ pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events"; pub trait WorkerSelector { fn select_worker( &self, - workers: &[Instance], + workers: &HashMap>, request: &SchedulingRequest, block_size: u32, ) -> Result; @@ -176,6 +177,15 @@ impl KvRouter { } }; + // Create runtime config watcher + // TODO: Migrate to discovery_client() once it exposes kv_get_and_watch_prefix functionality + let etcd_client = component + .drt() + .etcd_client() + .expect("Cannot KV route without etcd client"); + let runtime_configs_rx = + watch_model_runtime_configs(etcd_client, cancellation_token.clone()).await?; + let indexer = if kv_router_config.use_kv_events { Indexer::KvIndexer(KvIndexer::new(cancellation_token.clone(), block_size)) } else { @@ -191,6 +201,7 @@ impl KvRouter { component.clone(), block_size, instances_rx, + runtime_configs_rx, selector, kv_router_config.router_replica_sync, ) diff --git a/lib/llm/src/kv_router/metrics_aggregator.rs b/lib/llm/src/kv_router/metrics_aggregator.rs index 7ab4e1372c..9fa3f11865 100644 --- a/lib/llm/src/kv_router/metrics_aggregator.rs +++ b/lib/llm/src/kv_router/metrics_aggregator.rs @@ -18,10 +18,14 @@ use std::sync::Once; pub use crate::kv_router::protocols::{ForwardPassMetrics, LoadMetrics, PredictiveLoadMetrics}; use crate::kv_router::KV_METRICS_ENDPOINT; +use crate::discovery::{ModelEntry, MODEL_ROOT_PATH}; use crate::kv_router::scoring::Endpoint; use crate::kv_router::ProcessedEndpoints; +use crate::local_model::runtime_config::ModelRuntimeConfig; use dynamo_runtime::component::Component; +use dynamo_runtime::transports::etcd::{Client as EtcdClient, WatchEvent}; use dynamo_runtime::{service::EndpointInfo, utils::Duration, Result}; +use std::collections::HashMap; use tokio::sync::watch; use tokio_util::sync::CancellationToken; @@ -208,3 +212,71 @@ pub async fn collect_endpoints_task( } } } + +pub async fn watch_model_runtime_configs( + etcd_client: EtcdClient, + cancellation_token: CancellationToken, +) -> Result>> { + let (watch_tx, watch_rx) = watch::channel(HashMap::new()); + + let prefix_watcher = etcd_client.kv_get_and_watch_prefix(MODEL_ROOT_PATH).await?; + let (_prefix, _watcher, mut events_rx) = prefix_watcher.dissolve(); + + tokio::spawn(async move { + let mut runtime_configs: HashMap = HashMap::new(); + + loop { + tokio::select! { + _ = cancellation_token.cancelled() => { + tracing::debug!("Runtime config watcher cancelled"); + break; + } + event = events_rx.recv() => { + let Some(event) = event else { + tracing::debug!("Runtime config watch stream closed"); + break; + }; + + match event { + WatchEvent::Put(kv) => { + let Ok(model_entry) = serde_json::from_slice::(kv.value()) else { + tracing::warn!( + "Failed to parse ModelEntry from etcd. Key: {}", + kv.key_str().unwrap_or("") + ); + continue; + }; + + let lease_id = kv.lease(); + + if let Some(runtime_config) = model_entry.runtime_config { + runtime_configs.insert(lease_id, runtime_config); + tracing::trace!("Updated runtime config for lease_id: {}", lease_id); + } else { + runtime_configs.remove(&lease_id); + tracing::trace!("Removed runtime config (no config in ModelEntry)"); + } + + if watch_tx.send(runtime_configs.clone()).is_err() { + tracing::error!("Failed to send runtime configs update; receiver dropped"); + break; + } + } + WatchEvent::Delete(kv) => { + let lease_id = kv.lease(); + runtime_configs.remove(&lease_id); + tracing::trace!("Removed runtime config for deleted entry"); + + if watch_tx.send(runtime_configs.clone()).is_err() { + tracing::error!("Failed to send runtime configs update; receiver dropped"); + break; + } + } + } + } + } + } + }); + + Ok(watch_rx) +} diff --git a/lib/llm/src/kv_router/scheduler.rs b/lib/llm/src/kv_router/scheduler.rs index 6603b0e906..4fadb952dd 100644 --- a/lib/llm/src/kv_router/scheduler.rs +++ b/lib/llm/src/kv_router/scheduler.rs @@ -1,6 +1,7 @@ // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 +use crate::local_model::runtime_config::ModelRuntimeConfig; use dynamo_runtime::component::{Component, Instance}; use dynamo_runtime::traits::events::EventPublisher; use rand::Rng; @@ -8,6 +9,7 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; +use tokio::sync::watch; use super::indexer::OverlapScores; use super::protocols::WorkerSelectionResult; @@ -77,12 +79,15 @@ impl KvScheduler { pub async fn start( component: Component, block_size: u32, - mut instances_rx: tokio::sync::watch::Receiver>, // Changed from ProcessedEndpoints + mut instances_rx: watch::Receiver>, + mut runtime_configs_rx: watch::Receiver>, selector: Option>, replica_sync: bool, ) -> Result { let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default())); let mut instances: Vec = instances_rx.borrow_and_update().clone(); + let mut runtime_configs: HashMap = + runtime_configs_rx.borrow_and_update().clone(); let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::(); let ns_clone = component.namespace().clone(); @@ -112,10 +117,15 @@ impl KvScheduler { tokio::spawn(async move { let mut request_rx = request_rx; tracing::trace!("scheduler background task started"); + let mut workers_with_configs: HashMap> = HashMap::new(); + let mut needs_rebuild = true; loop { - // First, check for instance updates (non-blocking) - match instances_rx.has_changed() { + // Check for instance updates (non-blocking) + let instances_changed = instances_rx.has_changed(); + let configs_changed = runtime_configs_rx.has_changed(); + + match instances_changed { Ok(true) => { instances = instances_rx.borrow_and_update().clone(); let worker_ids: Vec = instances @@ -123,17 +133,42 @@ impl KvScheduler { .map(|instance| instance.instance_id) .collect(); slots_clone.update_workers(worker_ids); + needs_rebuild = true; } - Ok(false) => { - // No changes, continue. This is the happy path. - } + Ok(false) => {} Err(_) => { tracing::warn!("endpoint watch sender shutdown"); break; } } - // Then, wait for a new request + // Check for runtime config updates + match configs_changed { + Ok(true) => { + runtime_configs = runtime_configs_rx.borrow_and_update().clone(); + needs_rebuild = true; + } + Ok(false) => {} + Err(_) => { + tracing::warn!("runtime configs watch sender shutdown"); + } + } + + // Rebuild workers hashmap only when needed + if needs_rebuild { + workers_with_configs.clear(); + for instance in &instances { + let worker_id = instance.instance_id; + let config = runtime_configs.get(&worker_id).cloned(); + if config.is_none() { + tracing::warn!("Runtime config not found for worker_id: {}", worker_id); + } + workers_with_configs.insert(worker_id, config); + } + needs_rebuild = false; + } + + // Wait for a new request let Some(mut request) = request_rx.recv().await else { tracing::warn!("scheduler shutdown"); break; @@ -150,7 +185,7 @@ impl KvScheduler { request.decode_blocks = decode_blocks; request.prefill_tokens = prefill_tokens; - match selector.select_worker(&instances, &request, block_size) { + match selector.select_worker(&workers_with_configs, &request, block_size) { Ok(selection) => { if let Err(e) = event_tx.send(KVHitRateEvent { worker_id: selection.worker_id, @@ -333,7 +368,7 @@ impl DefaultWorkerSelector { impl WorkerSelector for DefaultWorkerSelector { fn select_worker( &self, - workers: &[Instance], + workers: &HashMap>, request: &SchedulingRequest, block_size: u32, ) -> Result { @@ -354,17 +389,16 @@ impl WorkerSelector for DefaultWorkerSelector { let mut max_logit = f64::NEG_INFINITY; // Calculate logits for each worker - for instance in workers.iter() { - let worker_id = instance.instance_id; - let overlap = *overlaps.get(&worker_id).unwrap_or(&0); + for (worker_id, runtime_config) in workers.iter() { + let overlap = *overlaps.get(worker_id).unwrap_or(&0); // this is the number of prefill tokens the worker would have if the request were scheduled there - let prefill_token = *prefill_tokens.get(&worker_id).unwrap_or(&isl); + let prefill_token = *prefill_tokens.get(worker_id).unwrap_or(&isl); let potential_prefill_block = (prefill_token as f64) / (block_size as f64); // this is the number of decode blocks the worker would have if the request were scheduled there let decode_block = *decode_blocks - .get(&worker_id) + .get(worker_id) .unwrap_or(&(potential_prefill_block.floor() as usize)) as f64; @@ -373,7 +407,7 @@ impl WorkerSelector for DefaultWorkerSelector { self.kv_router_config.overlap_score_weight * potential_prefill_block + decode_block; max_logit = max_logit.max(logit); - worker_logits.insert(worker_id, logit); + worker_logits.insert(*worker_id, logit); let overlap_weight = self.kv_router_config.overlap_score_weight; tracing::info!( @@ -388,10 +422,20 @@ impl WorkerSelector for DefaultWorkerSelector { let best_worker_id = softmax_sample(&worker_logits, temperature); let best_logit = worker_logits[&best_worker_id]; + let best_overlap = *overlaps.get(&best_worker_id).unwrap_or(&0); + let total_blocks_info = workers + .get(&best_worker_id) + .and_then(|cfg| cfg.as_ref()) + .and_then(|cfg| cfg.total_kv_blocks) + .map(|blocks| format!(", total blocks: {}", blocks)) + .unwrap_or_default(); + tracing::info!( - "Selected worker: {}, logit: {:.3}", + "Selected worker: {}, logit: {:.3}, cached blocks: {}{}", best_worker_id, - best_logit + best_logit, + best_overlap, + total_blocks_info ); Ok(WorkerSelectionResult { From e697253c9857c5dc5b2ff9963641b6521f6a5b22 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Tue, 12 Aug 2025 18:29:17 -0700 Subject: [PATCH 20/31] clippy --- lib/llm/src/kv_router/scheduler.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/llm/src/kv_router/scheduler.rs b/lib/llm/src/kv_router/scheduler.rs index 4fadb952dd..f722442fe0 100644 --- a/lib/llm/src/kv_router/scheduler.rs +++ b/lib/llm/src/kv_router/scheduler.rs @@ -389,7 +389,7 @@ impl WorkerSelector for DefaultWorkerSelector { let mut max_logit = f64::NEG_INFINITY; // Calculate logits for each worker - for (worker_id, runtime_config) in workers.iter() { + for worker_id in workers.keys() { let overlap = *overlaps.get(worker_id).unwrap_or(&0); // this is the number of prefill tokens the worker would have if the request were scheduled there From b0dc6f38506f351b1c205aac90946abc0c6a3ccf Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Tue, 12 Aug 2025 18:39:34 -0700 Subject: [PATCH 21/31] mv runtime config bindings to new file local_model.rs --- lib/bindings/python/rust/lib.rs | 4 +- lib/bindings/python/rust/llm.rs | 1 + lib/bindings/python/rust/llm/local_model.rs | 72 +++++++++++++++++++++ lib/bindings/python/rust/llm/model_card.rs | 68 ------------------- 4 files changed, 75 insertions(+), 70 deletions(-) create mode 100644 lib/bindings/python/rust/llm/local_model.rs diff --git a/lib/bindings/python/rust/lib.rs b/lib/bindings/python/rust/lib.rs index 27b1e8726e..0c1e498301 100644 --- a/lib/bindings/python/rust/lib.rs +++ b/lib/bindings/python/rust/lib.rs @@ -25,7 +25,7 @@ use dynamo_runtime::{ use dynamo_llm::{self as llm_rs}; use dynamo_llm::{entrypoint::RouterConfig, kv_router::KvRouterConfig}; -use crate::llm::model_card::ModelRuntimeConfig; +use crate::llm::local_model::ModelRuntimeConfig; #[pyclass(eq, eq_int)] #[derive(Clone, Debug, PartialEq)] @@ -84,7 +84,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; - m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/lib/bindings/python/rust/llm.rs b/lib/bindings/python/rust/llm.rs index 7e3cfc947f..9a2c859228 100644 --- a/lib/bindings/python/rust/llm.rs +++ b/lib/bindings/python/rust/llm.rs @@ -31,6 +31,7 @@ pub mod block_manager; pub mod disagg_router; pub mod entrypoint; pub mod kv; +pub mod local_model; pub mod model_card; pub mod nats; pub mod preprocessor; diff --git a/lib/bindings/python/rust/llm/local_model.rs b/lib/bindings/python/rust/llm/local_model.rs new file mode 100644 index 0000000000..d10167f9f0 --- /dev/null +++ b/lib/bindings/python/rust/llm/local_model.rs @@ -0,0 +1,72 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use super::*; +use llm_rs::local_model::runtime_config::ModelRuntimeConfig as RsModelRuntimeConfig; + +#[pyclass] +#[derive(Clone, Default)] +pub struct ModelRuntimeConfig { + pub(crate) inner: RsModelRuntimeConfig, +} + +#[pymethods] +impl ModelRuntimeConfig { + #[new] + fn new() -> Self { + Self { + inner: RsModelRuntimeConfig::new(), + } + } + + #[setter] + fn set_total_kv_blocks(&mut self, total_kv_blocks: u64) { + self.inner.total_kv_blocks = Some(total_kv_blocks); + } + + #[setter] + fn set_max_num_seqs(&mut self, max_num_seqs: u64) { + self.inner.max_num_seqs = Some(max_num_seqs); + } + + #[setter] + fn set_gpu_memory_utilization(&mut self, gpu_memory_utilization: u64) { + self.inner.gpu_memory_utilization = Some(gpu_memory_utilization); + } + + fn set_engine_specific(&mut self, key: &str, value: String) -> PyResult<()> { + let value: serde_json::Value = serde_json::from_str(&value).map_err(to_pyerr)?; + self.inner + .set_engine_specific(key, value) + .map_err(to_pyerr)?; + Ok(()) + } + + #[getter] + fn total_kv_blocks(&self) -> Option { + self.inner.total_kv_blocks + } + + #[getter] + fn max_num_seqs(&self) -> Option { + self.inner.max_num_seqs + } + + #[getter] + fn gpu_memory_utilization(&self) -> Option { + self.inner.gpu_memory_utilization + } + + #[getter] + fn runtime_data(&self, py: Python<'_>) -> PyResult { + let dict = PyDict::new(py); + for (key, value) in self.inner.runtime_data.clone() { + dict.set_item(key, value.to_string())?; + } + Ok(dict.into()) + } + + fn get_engine_specific(&self, key: &str) -> PyResult> { + self.inner.get_engine_specific(key).map_err(to_pyerr) + } +} diff --git a/lib/bindings/python/rust/llm/model_card.rs b/lib/bindings/python/rust/llm/model_card.rs index d6f76d06db..b60fc655ce 100644 --- a/lib/bindings/python/rust/llm/model_card.rs +++ b/lib/bindings/python/rust/llm/model_card.rs @@ -2,7 +2,6 @@ // SPDX-License-Identifier: Apache-2.0 use super::*; -use llm_rs::local_model::runtime_config::ModelRuntimeConfig as RsModelRuntimeConfig; use llm_rs::model_card::ModelDeploymentCard as RsModelDeploymentCard; #[pyclass] @@ -36,70 +35,3 @@ impl ModelDeploymentCard { Ok(json) } } - -#[pyclass] -#[derive(Clone, Default)] -pub struct ModelRuntimeConfig { - pub(crate) inner: RsModelRuntimeConfig, -} - -#[pymethods] -impl ModelRuntimeConfig { - #[new] - fn new() -> Self { - Self { - inner: RsModelRuntimeConfig::new(), - } - } - - #[setter] - fn set_total_kv_blocks(&mut self, total_kv_blocks: u64) { - self.inner.total_kv_blocks = Some(total_kv_blocks); - } - - #[setter] - fn set_max_num_seqs(&mut self, max_num_seqs: u64) { - self.inner.max_num_seqs = Some(max_num_seqs); - } - - #[setter] - fn set_gpu_memory_utilization(&mut self, gpu_memory_utilization: u64) { - self.inner.gpu_memory_utilization = Some(gpu_memory_utilization); - } - - fn set_engine_specific(&mut self, key: &str, value: String) -> PyResult<()> { - let value: serde_json::Value = serde_json::from_str(&value).map_err(to_pyerr)?; - self.inner - .set_engine_specific(key, value) - .map_err(to_pyerr)?; - Ok(()) - } - - #[getter] - fn total_kv_blocks(&self) -> Option { - self.inner.total_kv_blocks - } - - #[getter] - fn max_num_seqs(&self) -> Option { - self.inner.max_num_seqs - } - - #[getter] - fn gpu_memory_utilization(&self) -> Option { - self.inner.gpu_memory_utilization - } - - #[getter] - fn runtime_data(&self, py: Python<'_>) -> PyResult { - let dict = PyDict::new(py); - for (key, value) in self.inner.runtime_data.clone() { - dict.set_item(key, value.to_string())?; - } - Ok(dict.into()) - } - - fn get_engine_specific(&self, key: &str) -> PyResult> { - self.inner.get_engine_specific(key).map_err(to_pyerr) - } -} From 8004bbdea59e74adda394fa60038690815158240 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Tue, 12 Aug 2025 20:39:55 -0700 Subject: [PATCH 22/31] tensorrtllm support (vibe coded) --- .../backends/trtllm/src/dynamo/trtllm/main.py | 53 +++++++++++++++++-- 1 file changed, 50 insertions(+), 3 deletions(-) diff --git a/components/backends/trtllm/src/dynamo/trtllm/main.py b/components/backends/trtllm/src/dynamo/trtllm/main.py index 8e54f00df2..c3f7edc9f1 100644 --- a/components/backends/trtllm/src/dynamo/trtllm/main.py +++ b/components/backends/trtllm/src/dynamo/trtllm/main.py @@ -20,10 +20,10 @@ from torch.cuda import device_count from transformers import AutoConfig -from dynamo.llm import ModelType, register_llm +from dynamo.llm import ModelRuntimeConfig, ModelType, register_llm from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime.logging import configure_dynamo_logging -from dynamo.trtllm.engine import get_llm_engine +from dynamo.trtllm.engine import TensorRTLLMEngine, get_llm_engine from dynamo.trtllm.multimodal_processor import MultimodalRequestProcessor from dynamo.trtllm.publisher import get_publisher from dynamo.trtllm.request_handlers.handlers import ( @@ -49,6 +49,49 @@ async def graceful_shutdown(runtime): logging.info("DistributedRuntime shutdown complete") +async def get_engine_runtime_config( + engine: TensorRTLLMEngine, config: Config +) -> ModelRuntimeConfig: + """Retrieve runtime configuration from TensorRT-LLM engine.""" + runtime_config = ModelRuntimeConfig() + + try: + # Get runtime stats from the engine + stats_generator = engine.llm.get_stats_async(timeout=5) + async for stat in stats_generator: + # Extract KV cache configuration + if "kvCacheStats" in stat and "maxNumBlocks" in stat["kvCacheStats"]: + runtime_config.total_kv_blocks = stat["kvCacheStats"]["maxNumBlocks"] + logging.info( + f"Set runtime config total KV blocks: {runtime_config.total_kv_blocks}" + ) + + # Extract max number of sequences + if "maxNumActiveRequests" in stat: + runtime_config.max_num_seqs = stat["maxNumActiveRequests"] + logging.info( + f"Set runtime config max num seqs: {runtime_config.max_num_seqs}" + ) + + # Get GPU memory utilization from the config + # Convert free_gpu_memory_fraction to utilization percentage + gpu_mem_percentage = int(config.free_gpu_memory_fraction * 100) + runtime_config.gpu_memory_utilization = gpu_mem_percentage + logging.info( + f"Set runtime config GPU memory utilization: {gpu_mem_percentage}%" + ) + + # Only need the first stat result + break + + return runtime_config + + except Exception as e: + logging.error(f"Failed to get runtime config from TensorRT-LLM engine: {e}") + # Return config with default/None values if retrieval fails + return runtime_config + + @dynamo_worker(static=False) async def worker(runtime: DistributedRuntime): # Set up signal handler for graceful shutdown @@ -196,7 +239,10 @@ async def init(runtime: DistributedRuntime, config: Config): endpoint = component.endpoint(config.endpoint) if is_first_worker(config): - # Register the model with the endpoint if only the worker is first in the disaggregation chain. + # Get runtime configuration from the engine + runtime_config = await get_engine_runtime_config(engine, config) + + # Register the model with runtime config await register_llm( modelType, endpoint, @@ -204,6 +250,7 @@ async def init(runtime: DistributedRuntime, config: Config): config.served_model_name, kv_cache_block_size=config.kv_block_size, migration_limit=config.migration_limit, + runtime_config=runtime_config, # Add runtime config here ) # publisher will be set later if publishing is enabled. handler_config = RequestHandlerConfig( From ef3d419ebd31544125b0e5beec28fe180a0509ec Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Tue, 12 Aug 2025 20:57:59 -0700 Subject: [PATCH 23/31] max_num_batched_tokens instead --- .../backends/sglang/src/dynamo/sglang/worker/main.py | 9 +++++---- components/backends/trtllm/src/dynamo/trtllm/main.py | 8 +++----- components/backends/vllm/src/dynamo/vllm/main.py | 7 +++---- lib/bindings/python/rust/llm/local_model.rs | 8 ++++---- lib/llm/src/local_model.rs | 2 ++ lib/llm/src/local_model/runtime_config.rs | 2 +- 6 files changed, 18 insertions(+), 18 deletions(-) diff --git a/components/backends/sglang/src/dynamo/sglang/worker/main.py b/components/backends/sglang/src/dynamo/sglang/worker/main.py index df31a58923..8f155de5ea 100644 --- a/components/backends/sglang/src/dynamo/sglang/worker/main.py +++ b/components/backends/sglang/src/dynamo/sglang/worker/main.py @@ -414,11 +414,12 @@ async def _get_runtime_config(engine: sgl.Engine) -> Optional[ModelRuntimeConfig f"Set model runtime config max num seqs: {runtime_config.max_num_seqs}" ) - if server_info.get("mem_fraction_static") is not None: - gpu_mem_percentage = int(server_info["mem_fraction_static"] * 100) - runtime_config.gpu_memory_utilization = gpu_mem_percentage + if server_info.get("max_num_batched_tokens") is not None: + runtime_config.max_num_batched_tokens = server_info[ + "max_num_batched_tokens" + ] logging.info( - f"Set model runtime config GPU memory utilization: {gpu_mem_percentage}%" + f"Set model runtime config max num batched tokens: {runtime_config.max_num_batched_tokens}" ) return runtime_config diff --git a/components/backends/trtllm/src/dynamo/trtllm/main.py b/components/backends/trtllm/src/dynamo/trtllm/main.py index c3f7edc9f1..7c5486d225 100644 --- a/components/backends/trtllm/src/dynamo/trtllm/main.py +++ b/components/backends/trtllm/src/dynamo/trtllm/main.py @@ -73,12 +73,10 @@ async def get_engine_runtime_config( f"Set runtime config max num seqs: {runtime_config.max_num_seqs}" ) - # Get GPU memory utilization from the config - # Convert free_gpu_memory_fraction to utilization percentage - gpu_mem_percentage = int(config.free_gpu_memory_fraction * 100) - runtime_config.gpu_memory_utilization = gpu_mem_percentage + # Get max_num_batched_tokens from config + runtime_config.max_num_batched_tokens = config.max_num_tokens logging.info( - f"Set runtime config GPU memory utilization: {gpu_mem_percentage}%" + f"Set runtime config max num batched tokens: {runtime_config.max_num_batched_tokens}" ) # Only need the first stat result diff --git a/components/backends/vllm/src/dynamo/vllm/main.py b/components/backends/vllm/src/dynamo/vllm/main.py index ca5f67698f..68302285b0 100644 --- a/components/backends/vllm/src/dynamo/vllm/main.py +++ b/components/backends/vllm/src/dynamo/vllm/main.py @@ -223,8 +223,7 @@ async def init(runtime: DistributedRuntime, config: Config): runtime_values = get_engine_cache_info(engine_client) runtime_config.total_kv_blocks = runtime_values["num_gpu_blocks"] runtime_config.max_num_seqs = runtime_values["max_num_seqs"] - gpu_mem_integer = runtime_values["gpu_memory_utilization"] - runtime_config.gpu_memory_utilization = int(gpu_mem_integer * 100) + runtime_config.max_num_batched_tokens = runtime_values["max_num_batched_tokens"] await register_llm( ModelType.Backend, @@ -258,11 +257,11 @@ def get_engine_cache_info(engine: AsyncLLM): # Get values directly from vllm_config instead of collective_rpc cache_values = { "num_gpu_blocks": engine.vllm_config.cache_config.num_gpu_blocks, - "gpu_memory_utilization": engine.vllm_config.cache_config.gpu_memory_utilization, } scheduler_values = { "max_num_seqs": engine.vllm_config.scheduler_config.max_num_seqs, + "max_num_batched_tokens": engine.vllm_config.scheduler_config.max_num_batched_tokens, } logging.info(f"Cache config values: {cache_values}") @@ -270,7 +269,7 @@ def get_engine_cache_info(engine: AsyncLLM): return { "num_gpu_blocks": cache_values["num_gpu_blocks"], "max_num_seqs": scheduler_values["max_num_seqs"], - "gpu_memory_utilization": cache_values["gpu_memory_utilization"], + "max_num_batched_tokens": scheduler_values["max_num_batched_tokens"], } except Exception as e: logging.error(f"Failed to get configuration values from vLLM config: {e}") diff --git a/lib/bindings/python/rust/llm/local_model.rs b/lib/bindings/python/rust/llm/local_model.rs index d10167f9f0..2fdc1a153b 100644 --- a/lib/bindings/python/rust/llm/local_model.rs +++ b/lib/bindings/python/rust/llm/local_model.rs @@ -30,8 +30,8 @@ impl ModelRuntimeConfig { } #[setter] - fn set_gpu_memory_utilization(&mut self, gpu_memory_utilization: u64) { - self.inner.gpu_memory_utilization = Some(gpu_memory_utilization); + fn set_max_num_batched_tokens(&mut self, max_num_batched_tokens: u64) { + self.inner.max_num_batched_tokens = Some(max_num_batched_tokens); } fn set_engine_specific(&mut self, key: &str, value: String) -> PyResult<()> { @@ -53,8 +53,8 @@ impl ModelRuntimeConfig { } #[getter] - fn gpu_memory_utilization(&self) -> Option { - self.inner.gpu_memory_utilization + fn max_num_batched_tokens(&self) -> Option { + self.inner.max_num_batched_tokens } #[getter] diff --git a/lib/llm/src/local_model.rs b/lib/llm/src/local_model.rs index 5884d5eee8..7172b2d3d8 100644 --- a/lib/llm/src/local_model.rs +++ b/lib/llm/src/local_model.rs @@ -246,6 +246,8 @@ impl LocalModelBuilder { Some(mocker_engine_args.num_gpu_blocks as u64); self.runtime_config.max_num_seqs = mocker_engine_args.max_num_seqs.map(|v| v as u64); + self.runtime_config.max_num_batched_tokens = + mocker_engine_args.max_num_batched_tokens.map(|v| v as u64); } } diff --git a/lib/llm/src/local_model/runtime_config.rs b/lib/llm/src/local_model/runtime_config.rs index f83f5d6335..4421ff4022 100644 --- a/lib/llm/src/local_model/runtime_config.rs +++ b/lib/llm/src/local_model/runtime_config.rs @@ -11,7 +11,7 @@ pub struct ModelRuntimeConfig { pub max_num_seqs: Option, - pub gpu_memory_utilization: Option, + pub max_num_batched_tokens: Option, /// Mapping of engine-specific runtime configs #[serde(default, skip_serializing_if = "HashMap::is_empty")] From 6842494d7498009fe3f70c66a08e255d571ffeed Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Tue, 12 Aug 2025 22:53:53 -0700 Subject: [PATCH 24/31] fix sglang server_info args --- .../sglang/src/dynamo/sglang/worker/main.py | 37 +++++++++++++++---- 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/components/backends/sglang/src/dynamo/sglang/worker/main.py b/components/backends/sglang/src/dynamo/sglang/worker/main.py index 8f155de5ea..2d2484296a 100644 --- a/components/backends/sglang/src/dynamo/sglang/worker/main.py +++ b/components/backends/sglang/src/dynamo/sglang/worker/main.py @@ -396,17 +396,29 @@ async def _get_runtime_config(engine: sgl.Engine) -> Optional[ModelRuntimeConfig for attempt in range(MAX_RETRIES): try: - server_info = engine.get_server_info() + # Run the synchronous get_server_info in a thread executor to avoid event loop conflict + loop = asyncio.get_event_loop() + server_info = await loop.run_in_executor(None, engine.get_server_info) + if not server_info: logging.warning("No server info from SGLang engine") return None runtime_config = ModelRuntimeConfig() - if server_info.get("max_total_num_tokens") is not None: - runtime_config.total_kv_blocks = server_info["max_total_num_tokens"] - logging.info( - f"Set model runtime config total KV blocks: {runtime_config.total_kv_blocks}" - ) + + # Calculate total_kv_blocks from max_total_num_tokens and page_size + if server_info.get("max_total_num_tokens") is not None and hasattr( + engine, "tokenizer_manager" + ): + page_size = engine.tokenizer_manager.server_args.page_size + if page_size: + runtime_config.total_kv_blocks = ( + server_info["max_total_num_tokens"] // page_size + ) + logging.info( + f"Calculated total KV blocks: {runtime_config.total_kv_blocks} " + f"(max_total_num_tokens={server_info['max_total_num_tokens']}, page_size={page_size})" + ) if server_info.get("max_running_requests") is not None: runtime_config.max_num_seqs = server_info["max_running_requests"] @@ -414,6 +426,7 @@ async def _get_runtime_config(engine: sgl.Engine) -> Optional[ModelRuntimeConfig f"Set model runtime config max num seqs: {runtime_config.max_num_seqs}" ) + # max_num_batched_tokens might be provided directly by SGLang if server_info.get("max_num_batched_tokens") is not None: runtime_config.max_num_batched_tokens = server_info[ "max_num_batched_tokens" @@ -421,19 +434,27 @@ async def _get_runtime_config(engine: sgl.Engine) -> Optional[ModelRuntimeConfig logging.info( f"Set model runtime config max num batched tokens: {runtime_config.max_num_batched_tokens}" ) + # If not provided, we could use max_total_num_tokens as a fallback + elif server_info.get("max_total_num_tokens") is not None: + runtime_config.max_num_batched_tokens = server_info[ + "max_total_num_tokens" + ] + logging.info( + f"Using max_total_num_tokens as max_num_batched_tokens: {runtime_config.max_num_batched_tokens}" + ) return runtime_config except Exception as e: logging.warning( - f"Attempt {attempt + 1}/{MAX_RETRIES} failed to publish runtime config: {e}" + f"Attempt {attempt + 1}/{MAX_RETRIES} failed to get runtime config: {e}" ) if attempt < MAX_RETRIES - 1: await asyncio.sleep(RETRY_DELAY) RETRY_DELAY *= 2 else: logging.error( - f"Failed to publish runtime config after {MAX_RETRIES} attempts" + f"Failed to get runtime config after {MAX_RETRIES} attempts" ) return None From 3b175cf1ce64d594edcaccb610a75eec0ab7817e Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Tue, 12 Aug 2025 23:23:46 -0700 Subject: [PATCH 25/31] direct access to server_Args --- .../sglang/src/dynamo/sglang/worker/main.py | 99 ++++++++++--------- 1 file changed, 54 insertions(+), 45 deletions(-) diff --git a/components/backends/sglang/src/dynamo/sglang/worker/main.py b/components/backends/sglang/src/dynamo/sglang/worker/main.py index 2d2484296a..5d11be601f 100644 --- a/components/backends/sglang/src/dynamo/sglang/worker/main.py +++ b/components/backends/sglang/src/dynamo/sglang/worker/main.py @@ -391,73 +391,82 @@ async def register_llm_with_runtime_config( async def _get_runtime_config(engine: sgl.Engine) -> Optional[ModelRuntimeConfig]: """Get runtime config from SGLang engine""" - MAX_RETRIES = 3 - RETRY_DELAY = 2 - - for attempt in range(MAX_RETRIES): - try: - # Run the synchronous get_server_info in a thread executor to avoid event loop conflict - loop = asyncio.get_event_loop() - server_info = await loop.run_in_executor(None, engine.get_server_info) - - if not server_info: - logging.warning("No server info from SGLang engine") - return None - - runtime_config = ModelRuntimeConfig() - - # Calculate total_kv_blocks from max_total_num_tokens and page_size - if server_info.get("max_total_num_tokens") is not None and hasattr( - engine, "tokenizer_manager" + try: + runtime_config = ModelRuntimeConfig() + + # Access server_args directly from the engine + if hasattr(engine, "tokenizer_manager") and hasattr( + engine.tokenizer_manager, "server_args" + ): + server_args = engine.tokenizer_manager.server_args + + # Calculate total_kv_blocks from max_total_tokens and page_size + if ( + hasattr(server_args, "max_total_tokens") + and server_args.max_total_tokens is not None ): - page_size = engine.tokenizer_manager.server_args.page_size + page_size = getattr(server_args, "page_size", None) if page_size: runtime_config.total_kv_blocks = ( - server_info["max_total_num_tokens"] // page_size + server_args.max_total_tokens // page_size ) logging.info( f"Calculated total KV blocks: {runtime_config.total_kv_blocks} " - f"(max_total_num_tokens={server_info['max_total_num_tokens']}, page_size={page_size})" + f"(max_total_tokens={server_args.max_total_tokens}, page_size={page_size})" ) - if server_info.get("max_running_requests") is not None: - runtime_config.max_num_seqs = server_info["max_running_requests"] + # Set max_num_seqs from max_running_requests + if ( + hasattr(server_args, "max_running_requests") + and server_args.max_running_requests is not None + ): + runtime_config.max_num_seqs = server_args.max_running_requests logging.info( f"Set model runtime config max num seqs: {runtime_config.max_num_seqs}" ) - # max_num_batched_tokens might be provided directly by SGLang - if server_info.get("max_num_batched_tokens") is not None: - runtime_config.max_num_batched_tokens = server_info[ - "max_num_batched_tokens" - ] + # Set max_num_batched_tokens + if ( + hasattr(server_args, "max_num_batched_tokens") + and server_args.max_num_batched_tokens is not None + ): + runtime_config.max_num_batched_tokens = ( + server_args.max_num_batched_tokens + ) logging.info( f"Set model runtime config max num batched tokens: {runtime_config.max_num_batched_tokens}" ) - # If not provided, we could use max_total_num_tokens as a fallback - elif server_info.get("max_total_num_tokens") is not None: - runtime_config.max_num_batched_tokens = server_info[ - "max_total_num_tokens" - ] + # Fallback to max_total_tokens if max_num_batched_tokens not available + elif ( + hasattr(server_args, "max_total_tokens") + and server_args.max_total_tokens is not None + ): + runtime_config.max_num_batched_tokens = server_args.max_total_tokens logging.info( - f"Using max_total_num_tokens as max_num_batched_tokens: {runtime_config.max_num_batched_tokens}" + f"Using max_total_tokens as max_num_batched_tokens: {runtime_config.max_num_batched_tokens}" ) - return runtime_config + # Log a warning if we couldn't get all required values + if runtime_config.total_kv_blocks is None: + logging.warning("Could not determine total_kv_blocks from server_args") + if runtime_config.max_num_seqs is None: + logging.warning("Could not determine max_num_seqs from server_args") + if runtime_config.max_num_batched_tokens is None: + logging.warning( + "Could not determine max_num_batched_tokens from server_args" + ) - except Exception as e: + else: logging.warning( - f"Attempt {attempt + 1}/{MAX_RETRIES} failed to get runtime config: {e}" + "Could not access server_args from engine.tokenizer_manager" ) - if attempt < MAX_RETRIES - 1: - await asyncio.sleep(RETRY_DELAY) - RETRY_DELAY *= 2 - else: - logging.error( - f"Failed to get runtime config after {MAX_RETRIES} attempts" - ) + return None + + return runtime_config - return None + except Exception as e: + logging.error(f"Failed to get runtime config: {e}") + return None def main(): From 10773c790f03f13d74e5ddc8c90be08d1aeaab40 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Wed, 13 Aug 2025 00:07:08 -0700 Subject: [PATCH 26/31] sglang: access total num tokens via scheduler info --- .../sglang/src/dynamo/sglang/worker/main.py | 105 ++++++------------ 1 file changed, 32 insertions(+), 73 deletions(-) diff --git a/components/backends/sglang/src/dynamo/sglang/worker/main.py b/components/backends/sglang/src/dynamo/sglang/worker/main.py index 5d11be601f..a027e09acb 100644 --- a/components/backends/sglang/src/dynamo/sglang/worker/main.py +++ b/components/backends/sglang/src/dynamo/sglang/worker/main.py @@ -9,12 +9,12 @@ import sys from typing import Any, Dict, Optional, Union -import sglang as sgl import uvloop import zmq from sglang.srt.server_args import ServerArgs from sglang.srt.utils import get_ip, get_zmq_socket +import sglang as sgl from dynamo._core import Endpoint from dynamo.llm import ( ForwardPassMetrics, @@ -392,80 +392,39 @@ async def register_llm_with_runtime_config( async def _get_runtime_config(engine: sgl.Engine) -> Optional[ModelRuntimeConfig]: """Get runtime config from SGLang engine""" try: - runtime_config = ModelRuntimeConfig() - - # Access server_args directly from the engine - if hasattr(engine, "tokenizer_manager") and hasattr( - engine.tokenizer_manager, "server_args" - ): - server_args = engine.tokenizer_manager.server_args - - # Calculate total_kv_blocks from max_total_tokens and page_size - if ( - hasattr(server_args, "max_total_tokens") - and server_args.max_total_tokens is not None - ): - page_size = getattr(server_args, "page_size", None) - if page_size: - runtime_config.total_kv_blocks = ( - server_args.max_total_tokens // page_size - ) - logging.info( - f"Calculated total KV blocks: {runtime_config.total_kv_blocks} " - f"(max_total_tokens={server_args.max_total_tokens}, page_size={page_size})" - ) - - # Set max_num_seqs from max_running_requests - if ( - hasattr(server_args, "max_running_requests") - and server_args.max_running_requests is not None - ): - runtime_config.max_num_seqs = server_args.max_running_requests - logging.info( - f"Set model runtime config max num seqs: {runtime_config.max_num_seqs}" - ) - - # Set max_num_batched_tokens - if ( - hasattr(server_args, "max_num_batched_tokens") - and server_args.max_num_batched_tokens is not None - ): - runtime_config.max_num_batched_tokens = ( - server_args.max_num_batched_tokens - ) - logging.info( - f"Set model runtime config max num batched tokens: {runtime_config.max_num_batched_tokens}" - ) - # Fallback to max_total_tokens if max_num_batched_tokens not available - elif ( - hasattr(server_args, "max_total_tokens") - and server_args.max_total_tokens is not None - ): - runtime_config.max_num_batched_tokens = server_args.max_total_tokens - logging.info( - f"Using max_total_tokens as max_num_batched_tokens: {runtime_config.max_num_batched_tokens}" - ) - - # Log a warning if we couldn't get all required values - if runtime_config.total_kv_blocks is None: - logging.warning("Could not determine total_kv_blocks from server_args") - if runtime_config.max_num_seqs is None: - logging.warning("Could not determine max_num_seqs from server_args") - if runtime_config.max_num_batched_tokens is None: - logging.warning( - "Could not determine max_num_batched_tokens from server_args" - ) - - else: - logging.warning( - "Could not access server_args from engine.tokenizer_manager" - ) - return None - - return runtime_config + # Try to check if the engine has a scheduler attribute with the computed values + if hasattr(engine, "scheduler_info") and engine.scheduler_info is not None: + runtime_config = ModelRuntimeConfig() + + # Get max_total_num_tokens from scheduler_info + if "max_total_num_tokens" in engine.scheduler_info: + max_total_tokens = engine.scheduler_info["max_total_num_tokens"] + if max_total_tokens and hasattr( + engine.tokenizer_manager, "server_args" + ): + page_size = engine.tokenizer_manager.server_args.page_size + if page_size: + runtime_config.total_kv_blocks = max_total_tokens // page_size + logging.info( + f"Got total KV blocks from scheduler: {runtime_config.total_kv_blocks} " + f"(max_total_tokens={max_total_tokens}, page_size={page_size})" + ) + + # Note: max_running_requests and max_prefill_tokens are NOT available in scheduler_info + # TODO: figure out where they are + + return runtime_config + + # If scheduler approach doesn't work, log and return None to indicate we'll skip runtime config + logging.warning( + "Could not access runtime config from SGLang engine. " + "The engine may compute these values internally after initialization. " + "Proceeding without runtime config - SGLang will use its internal defaults." + ) + return None except Exception as e: - logging.error(f"Failed to get runtime config: {e}") + logging.warning(f"Failed to get runtime config: {e}. Proceeding without it.") return None From 69d5d80524c31c67d9631e4fb8f2c16334bde1c5 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Wed, 13 Aug 2025 00:12:06 -0700 Subject: [PATCH 27/31] isort --- components/backends/sglang/src/dynamo/sglang/worker/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/components/backends/sglang/src/dynamo/sglang/worker/main.py b/components/backends/sglang/src/dynamo/sglang/worker/main.py index a027e09acb..9a7bfb44b2 100644 --- a/components/backends/sglang/src/dynamo/sglang/worker/main.py +++ b/components/backends/sglang/src/dynamo/sglang/worker/main.py @@ -9,12 +9,12 @@ import sys from typing import Any, Dict, Optional, Union +import sglang as sgl import uvloop import zmq from sglang.srt.server_args import ServerArgs from sglang.srt.utils import get_ip, get_zmq_socket -import sglang as sgl from dynamo._core import Endpoint from dynamo.llm import ( ForwardPassMetrics, From e6de5a233a62e63aa507f5677439251e9fc9be56 Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Wed, 13 Aug 2025 00:43:33 -0700 Subject: [PATCH 28/31] trtllm: extract directly from config --- .../backends/trtllm/src/dynamo/trtllm/main.py | 34 ++++++------------- 1 file changed, 10 insertions(+), 24 deletions(-) diff --git a/components/backends/trtllm/src/dynamo/trtllm/main.py b/components/backends/trtllm/src/dynamo/trtllm/main.py index 7c5486d225..9a5f90ebf7 100644 --- a/components/backends/trtllm/src/dynamo/trtllm/main.py +++ b/components/backends/trtllm/src/dynamo/trtllm/main.py @@ -56,31 +56,17 @@ async def get_engine_runtime_config( runtime_config = ModelRuntimeConfig() try: - # Get runtime stats from the engine - stats_generator = engine.llm.get_stats_async(timeout=5) - async for stat in stats_generator: - # Extract KV cache configuration - if "kvCacheStats" in stat and "maxNumBlocks" in stat["kvCacheStats"]: - runtime_config.total_kv_blocks = stat["kvCacheStats"]["maxNumBlocks"] - logging.info( - f"Set runtime config total KV blocks: {runtime_config.total_kv_blocks}" - ) - - # Extract max number of sequences - if "maxNumActiveRequests" in stat: - runtime_config.max_num_seqs = stat["maxNumActiveRequests"] - logging.info( - f"Set runtime config max num seqs: {runtime_config.max_num_seqs}" - ) - - # Get max_num_batched_tokens from config - runtime_config.max_num_batched_tokens = config.max_num_tokens - logging.info( - f"Set runtime config max num batched tokens: {runtime_config.max_num_batched_tokens}" - ) + # TODO: extract kv_total_blocks + + # Extract max number of sequences + runtime_config.max_num_seqs = config.max_batch_size + logging.info(f"Set runtime config max_num_seqs: {runtime_config.max_num_seqs}") - # Only need the first stat result - break + # Get max_num_batched_tokens from config + runtime_config.max_num_batched_tokens = config.max_num_tokens + logging.info( + f"Set runtime config max_num_batched_tokens: {runtime_config.max_num_batched_tokens}" + ) return runtime_config From 09f1cb0da2bb16e240b603ef24f1bb244872e13c Mon Sep 17 00:00:00 2001 From: PeaBrane Date: Wed, 13 Aug 2025 00:46:24 -0700 Subject: [PATCH 29/31] trtllm: get total_kv_blocks from get_stats_async --- components/backends/trtllm/src/dynamo/trtllm/main.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/components/backends/trtllm/src/dynamo/trtllm/main.py b/components/backends/trtllm/src/dynamo/trtllm/main.py index 9a5f90ebf7..c7d7c35e1f 100644 --- a/components/backends/trtllm/src/dynamo/trtllm/main.py +++ b/components/backends/trtllm/src/dynamo/trtllm/main.py @@ -56,7 +56,13 @@ async def get_engine_runtime_config( runtime_config = ModelRuntimeConfig() try: - # TODO: extract kv_total_blocks + # Extract total_kv_blocks from engine stats + stats = engine.llm.get_stats_async(timeout=5) + stat = await anext(stats) + runtime_config.total_kv_blocks = stat["kvCacheStats"]["maxNumBlocks"] + logging.info( + f"Set runtime config total_kv_blocks: {runtime_config.total_kv_blocks}" + ) # Extract max number of sequences runtime_config.max_num_seqs = config.max_batch_size From 280e98af830e39fe74f5cfa8f9de0c053b7d34da Mon Sep 17 00:00:00 2001 From: Jorge Antonio Date: Wed, 13 Aug 2025 19:08:30 +0100 Subject: [PATCH 30/31] ceil division for sglang total_kv_blocks calculation --- components/backends/sglang/src/dynamo/sglang/worker/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/components/backends/sglang/src/dynamo/sglang/worker/main.py b/components/backends/sglang/src/dynamo/sglang/worker/main.py index 9a7bfb44b2..8c70abc50f 100644 --- a/components/backends/sglang/src/dynamo/sglang/worker/main.py +++ b/components/backends/sglang/src/dynamo/sglang/worker/main.py @@ -404,7 +404,7 @@ async def _get_runtime_config(engine: sgl.Engine) -> Optional[ModelRuntimeConfig ): page_size = engine.tokenizer_manager.server_args.page_size if page_size: - runtime_config.total_kv_blocks = max_total_tokens // page_size + runtime_config.total_kv_blocks = (max_total_tokens + page_size - 1) // page_size logging.info( f"Got total KV blocks from scheduler: {runtime_config.total_kv_blocks} " f"(max_total_tokens={max_total_tokens}, page_size={page_size})" From 3f8bcddb64e86aa4f240f3bf5b2ec05bb18c8907 Mon Sep 17 00:00:00 2001 From: Jorge Antonio Date: Wed, 13 Aug 2025 19:09:35 +0100 Subject: [PATCH 31/31] hooks --- components/backends/sglang/src/dynamo/sglang/worker/main.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/components/backends/sglang/src/dynamo/sglang/worker/main.py b/components/backends/sglang/src/dynamo/sglang/worker/main.py index 8c70abc50f..b925921e5d 100644 --- a/components/backends/sglang/src/dynamo/sglang/worker/main.py +++ b/components/backends/sglang/src/dynamo/sglang/worker/main.py @@ -404,7 +404,9 @@ async def _get_runtime_config(engine: sgl.Engine) -> Optional[ModelRuntimeConfig ): page_size = engine.tokenizer_manager.server_args.page_size if page_size: - runtime_config.total_kv_blocks = (max_total_tokens + page_size - 1) // page_size + runtime_config.total_kv_blocks = ( + max_total_tokens + page_size - 1 + ) // page_size logging.info( f"Got total KV blocks from scheduler: {runtime_config.total_kv_blocks} " f"(max_total_tokens={max_total_tokens}, page_size={page_size})"