Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
b1e6eb4
first commit
jorgeantonio21 Aug 5, 2025
8ffe717
register runtime config after engine initialization
jorgeantonio21 Aug 5, 2025
58d73d2
add sglang runtime config values retrieval
jorgeantonio21 Aug 5, 2025
dfc9154
merge main and resolve conflicts
jorgeantonio21 Aug 5, 2025
87865fc
Merge branch 'main' into feat/ja/runtime-configs-mdc
jorgeantonio21 Aug 6, 2025
5707890
address comments in the PR
jorgeantonio21 Aug 7, 2025
9770e75
Merge branch 'main' into feat/ja/runtime-configs-mdc
jorgeantonio21 Aug 7, 2025
61f6424
refactor logic to pass in engine initialization runtime args directly…
jorgeantonio21 Aug 7, 2025
6fbe951
merge main and resolve conflicts
jorgeantonio21 Aug 11, 2025
b376cfb
resolve _core.py import issues
jorgeantonio21 Aug 11, 2025
9d3cbb1
resolve runtime issues
jorgeantonio21 Aug 11, 2025
d1b87f5
resolve import issues
jorgeantonio21 Aug 11, 2025
d18881b
resolve import issues
jorgeantonio21 Aug 11, 2025
24712cb
resolve vllm cache config issues
jorgeantonio21 Aug 11, 2025
c20f0e1
resolve non-int gpu_mem_integer issue
jorgeantonio21 Aug 11, 2025
af94e4b
resolve non-int gpu_mem_integer issue
jorgeantonio21 Aug 11, 2025
57e12c2
remove uneeded async in python code
jorgeantonio21 Aug 12, 2025
3304c8d
Merge branch 'main' into feat/ja/runtime-configs-mdc
jorgeantonio21 Aug 12, 2025
acddc6b
Merge branch 'main' into feat/ja/runtime-configs-mdc
PeaBrane Aug 12, 2025
d4b1edf
revert llama-cpp version in Cargo.lock
PeaBrane Aug 12, 2025
b7ca2f5
move runtime config into local_model
PeaBrane Aug 12, 2025
becb754
put runtime config in ModelEntry so it gets registered to etcd
PeaBrane Aug 12, 2025
5adaeb1
fmt
PeaBrane Aug 12, 2025
950e6a4
if mocker, override runtime configs
PeaBrane Aug 13, 2025
cbbd03b
router listens to runtime configs (kv total blocks)
PeaBrane Aug 13, 2025
e697253
clippy
PeaBrane Aug 13, 2025
b0dc6f3
mv runtime config bindings to new file local_model.rs
PeaBrane Aug 13, 2025
8004bbd
tensorrtllm support (vibe coded)
PeaBrane Aug 13, 2025
ef3d419
max_num_batched_tokens instead
PeaBrane Aug 13, 2025
6842494
fix sglang server_info args
PeaBrane Aug 13, 2025
3b175cf
direct access to server_Args
PeaBrane Aug 13, 2025
10773c7
sglang: access total num tokens via scheduler info
PeaBrane Aug 13, 2025
69d5d80
isort
PeaBrane Aug 13, 2025
e6de5a2
trtllm: extract directly from config
PeaBrane Aug 13, 2025
09f1cb0
trtllm: get total_kv_blocks from get_stats_async
PeaBrane Aug 13, 2025
36a6fbb
Merge branch 'feat/ja/runtime-configs-mdc' of https://github.com/jorg…
jorgeantonio21 Aug 13, 2025
280e98a
ceil division for sglang total_kv_blocks calculation
jorgeantonio21 Aug 13, 2025
3f8bcdd
hooks
jorgeantonio21 Aug 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion components/backends/sglang/src/dynamo/sglang/worker/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_ip, get_zmq_socket

from dynamo._core import Endpoint
from dynamo.llm import (
ForwardPassMetrics,
KvStats,
Expand All @@ -24,7 +25,9 @@
ZmqKvEventPublisher,
ZmqKvEventPublisherConfig,
register_llm,
register_runtime_config,
)
from dynamo.llm.model_card import ModelRuntimeConfig
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.sglang.common import (
Expand Down Expand Up @@ -364,12 +367,60 @@ async def init(
_ = ZmqKvEventPublisher(component=component, config=zmq_config)

tasks = [endpoint.serve_endpoint(handler.generate)]

tasks.append(
register_runtime_config_once_engine_ready(
endpoint, engine, server_args.served_model_name
)
)
tasks.extend(setup_native_endpoints(server_args, component, handler))

await asyncio.gather(*tasks)


async def register_runtime_config_once_engine_ready(
endpoint: Endpoint, engine: sgl.Engine, model_name: str
):
max_retries = 3
retry_delay = 2

for attempt in range(max_retries):
try:
server_info = engine.get_server_info()

runtime_config = ModelRuntimeConfig()

# Use actual computed values from SGLang engine
if server_info.get("max_total_num_tokens") is not None:
runtime_config.with_total_kv_blocks(server_info["max_total_num_tokens"])

if server_info.get("max_running_requests") is not None:
runtime_config.with_max_num_seqs(server_info["max_running_requests"])

if server_info.get("mem_fraction_static") is not None:
gpu_mem_percentage = int(server_info["mem_fraction_static"] * 100)
runtime_config.with_gpu_memory_utilization(gpu_mem_percentage)

# Register the runtime config
await register_runtime_config(endpoint, model_name, runtime_config)

logging.info(
f"Published runtime config for SGLang decode worker: {model_name}"
)
return # Successfully published runtime config, exit loop

except Exception as e:
logging.warning(
f"Attempt {attempt + 1}/{max_retries} failed to publish runtime config: {e}"
)
if attempt < max_retries - 1:
await asyncio.sleep(retry_delay)
retry_delay *= 2
else:
logging.error(
f"Failed to publish runtime config after {max_retries} attempts"
)


def main():
uvloop.install()
asyncio.run(worker())
Expand Down
21 changes: 21 additions & 0 deletions components/backends/vllm/src/dynamo/vllm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
ZmqKvEventPublisher,
ZmqKvEventPublisherConfig,
register_llm,
register_runtime_config,
)
from dynamo.llm.model_card import ModelRuntimeConfig
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging

Expand Down Expand Up @@ -167,6 +169,25 @@ async def init(runtime: DistributedRuntime, config: Config):
component, engine_client, default_sampling_params, prefill_worker_client
)

# Create a runtime config and register it with the MDC
if not config.engine_args.data_parallel_rank:
runtime_config = ModelRuntimeConfig()
# NOTE: This number needs to be queried directly from the engine,
# since this will compute it if no value was set by the user
runtime_config.with_total_kv_blocks(
engine_client.engine.cache_config.num_gpu_blocks
)
runtime_config.with_max_num_seqs(
engine_client.vllm_config.scheduler_config.max_num_seqs
)

gpu_mem_integer = int(
engine_client.engine.cache_config.gpu_memory_utilization * 100
)
runtime_config.with_gpu_memory_utilization(gpu_mem_integer)

await register_runtime_config(generate_endpoint, config.model, runtime_config)

if config.engine_args.enable_prefix_caching:
# TODO: We start off with a valid endpoint, then we increment it by dp_rank
# May no longer be valid. Lets remove the increment behavior from vLLM and here
Expand Down
62 changes: 61 additions & 1 deletion lib/bindings/python/rust/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,16 @@ use dynamo_runtime::{
network::egress::push_router::RouterMode as RsRouterMode, EngineStream, ManyOut, SingleIn,
},
protocols::annotated::Annotated as RsAnnotated,
slug::Slug,
storage::key_value_store::{EtcdStorage, KeyValueStoreManager},
traits::DistributedRuntimeProvider,
};

use dynamo_llm::{self as llm_rs};
use dynamo_llm::{self as llm_rs, model_card};
use dynamo_llm::{entrypoint::RouterConfig, kv_router::KvRouterConfig};

use crate::llm::model_card::ModelRuntimeConfig;

#[pyclass(eq, eq_int)]
#[derive(Clone, Debug, PartialEq)]
pub enum RouterMode {
Expand Down Expand Up @@ -63,6 +67,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(llm::kv::compute_block_hash_for_seq_py, m)?)?;
m.add_function(wrap_pyfunction!(log_message, m)?)?;
m.add_function(wrap_pyfunction!(register_llm, m)?)?;
m.add_function(wrap_pyfunction!(register_runtime_config, m)?)?;
m.add_function(wrap_pyfunction!(llm::entrypoint::make_engine, m)?)?;
m.add_function(wrap_pyfunction!(llm::entrypoint::run_input, m)?)?;

Expand All @@ -82,6 +87,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<llm::entrypoint::KvRouterConfig>()?;
m.add_class::<llm::kv::WorkerMetricsPublisher>()?;
m.add_class::<llm::model_card::ModelDeploymentCard>()?;
m.add_class::<llm::model_card::ModelRuntimeConfig>()?;
m.add_class::<llm::preprocessor::OAIChatPreprocessor>()?;
m.add_class::<llm::backend::Backend>()?;
m.add_class::<llm::kv::OverlapScores>()?;
Expand Down Expand Up @@ -177,6 +183,60 @@ fn register_llm<'p>(
})
}

#[pyfunction]
fn register_runtime_config<'p>(
py: Python<'p>,
endpoint: Endpoint,
model_identifier: &str,
runtime_config: ModelRuntimeConfig,
) -> PyResult<Bound<'p, PyAny>> {
let model_identifier = model_identifier.to_string();
let runtime_config = runtime_config.inner.clone();
let endpoint = endpoint.inner.clone();

pyo3_async_runtimes::tokio::future_into_py(py, async move {
// Get etcd client from endpoint
let Some(etcd_client) = endpoint.drt().etcd_client() else {
return Err(anyhow::anyhow!(
"Cannot update runtime config on static endpoint"
))
.map_err(to_pyerr);
};

// Create storage manager
let kvstore = EtcdStorage::new(etcd_client.clone());
let card_store = KeyValueStoreManager::new(Box::new(kvstore));

// Generate the model slug - this should match what register_llm used
// The register_llm function uses the model_name (or model_path if no model_name)
// and then calls card.set_name() which sets both display_name and service_name
let model_slug = Slug::slugify(&model_identifier);

// Get existing card
let mut card = card_store
.load::<llm_rs::model_card::ModelDeploymentCard>(model_card::ROOT_PATH, &model_slug)
.await
.map_err(to_pyerr)?
.ok_or_else(|| anyhow::anyhow!("Cannot find model card"))
.map_err(to_pyerr)?;

// Update the card
card.register_runtime_config(runtime_config);

// Publish the card
card_store
.publish(model_card::ROOT_PATH, None, model_slug.as_ref(), &mut card)
.await
.map_err(to_pyerr)?;

tracing::info!(
"Successfully updated runtime config for model card: {}",
model_slug
);
Ok(())
})
}

#[pyclass]
#[derive(Clone)]
struct EtcdKvCache {
Expand Down
94 changes: 94 additions & 0 deletions lib/bindings/python/rust/llm/model_card.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

use super::*;
use llm_rs::model_card::model::ModelDeploymentCard as RsModelDeploymentCard;
use llm_rs::model_card::runtime_config::ModelRuntimeConfig as RsModelRuntimeConfig;

#[pyclass]
#[derive(Clone)]
Expand Down Expand Up @@ -46,4 +47,97 @@ impl ModelDeploymentCard {
let json = self.inner.to_json().map_err(to_pyerr)?;
Ok(json)
}

fn register_runtime_config(&mut self, runtime_config: ModelRuntimeConfig) {
self.inner.register_runtime_config(runtime_config.inner);
}

#[getter]
fn runtime_config(&self) -> Option<ModelRuntimeConfig> {
self.inner
.runtime_config()
.map(|config| ModelRuntimeConfig {
inner: config.clone(),
})
}

#[getter]
fn total_kv_blocks(&self) -> Option<u64> {
self.inner.total_kv_blocks()
}

#[getter]
fn max_num_seqs(&self) -> Option<u64> {
self.inner.max_num_seqs()
}

#[getter]
fn gpu_memory_utilization(&self) -> Option<u64> {
self.inner.gpu_memory_utilization()
}
}

#[pyclass]
#[derive(Clone)]
pub struct ModelRuntimeConfig {
pub(crate) inner: RsModelRuntimeConfig,
}

#[pymethods]
impl ModelRuntimeConfig {
#[new]
fn new() -> Self {
Self {
inner: RsModelRuntimeConfig::new(),
}
}

fn with_total_kv_blocks(&mut self, total_kv_blocks: u64) {
self.inner.with_total_kv_blocks(total_kv_blocks);
}

fn with_max_num_seqs(&mut self, max_num_seqs: u64) {
self.inner.with_max_num_seqs(max_num_seqs);
}

fn with_gpu_memory_utilization(&mut self, gpu_memory_utilization: u64) {
self.inner
.with_gpu_memory_utilization(gpu_memory_utilization);
}

fn set_engine_specific(&mut self, key: &str, value: String) -> PyResult<()> {
let value: serde_json::Value = serde_json::from_str(&value).map_err(to_pyerr)?;
self.inner
.set_engine_specific(key, value)
.map_err(to_pyerr)?;
Ok(())
}

#[getter]
fn total_kv_blocks(&self) -> Option<u64> {
self.inner.total_kv_blocks
}

#[getter]
fn max_num_seqs(&self) -> Option<u64> {
self.inner.max_num_seqs
}

#[getter]
fn gpu_memory_utilization(&self) -> Option<u64> {
self.inner.gpu_memory_utilization
}

#[getter]
fn runtime_data(&self, py: Python<'_>) -> PyResult<PyObject> {
let dict = PyDict::new(py);
for (key, value) in self.inner.runtime_data.clone() {
dict.set_item(key, value.to_string())?;
}
Ok(dict.into())
}

fn get_engine_specific(&self, key: &str) -> PyResult<Option<String>> {
self.inner.get_engine_specific(key).map_err(to_pyerr)
}
}
10 changes: 10 additions & 0 deletions lib/bindings/python/src/dynamo/_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,12 @@ class ModelDeploymentCard:

...

class ModelRuntimeConfig:
"""
A model runtime configuration is a collection of runtime information
"""
...

class OAIChatPreprocessor:
"""
A preprocessor for OpenAI chat completions
Expand Down Expand Up @@ -833,6 +839,10 @@ async def register_llm(model_type: ModelType, endpoint: Endpoint, model_path: st
"""Attach the model at path to the given endpoint, and advertise it as model_type"""
...

async def register_runtime_config(endpoint: Endpoint, model_identifier: str, runtime_config: ModelRuntimeConfig) -> None:
"""Register runtime configuration with the model card"""
...

class EngineConfig:
"""Holds internal configuration for a Dynamo engine."""
...
Expand Down
1 change: 1 addition & 0 deletions lib/bindings/python/src/dynamo/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,5 @@
from dynamo._core import compute_block_hash_for_seq_py as compute_block_hash_for_seq_py
from dynamo._core import make_engine
from dynamo._core import register_llm as register_llm
from dynamo._core import register_runtime_config as register_runtime_config
from dynamo._core import run_input
Loading