Merged
Changes from 28 commits
Commits (31)
c23f86f
Implement add_labels for Endpoint and add model name to StatLoggerFac…
tzulingk Aug 16, 2025
e7e2a2e
format the codes
tzulingk Aug 16, 2025
9a8cc4d
remove unused import
tzulingk Aug 16, 2025
90d64d4
refactor: change create_endpoint from model to labels parameter
keivenchang Aug 17, 2025
5841108
refactor: update metrics endpoint creation to use labels instead of m…
keivenchang Aug 17, 2025
af27fb2
style: fix formatting in service_metrics example
keivenchang Aug 18, 2025
3b11882
Add metrics labels at EndpointConfigBuilder::start()
tzulingk Aug 18, 2025
7c3bdb0
Add metric_labels in EndpointConfig.
tzulingk Aug 20, 2025
acb237f
Merge remote-tracking branch 'origin/main' into tzulingk/model_label_…
tzulingk Aug 20, 2025
386dd2a
Change labels to metrics_labels
tzulingk Aug 20, 2025
402d5e6
Merge branch 'main' into tzulingk/model_label_dis_444
tzulingk Aug 20, 2025
60171e8
Cargo fmt
tzulingk Aug 20, 2025
40d8f86
Merge branch 'main' into tzulingk/model_label_dis_444
tzulingk Aug 20, 2025
91101ad
Merge branch 'main' into tzulingk/model_label_dis_444
tzulingk Aug 20, 2025
946e821
Use different endpoint name to avoid the duplicated metrics registra…
tzulingk Aug 21, 2025
f7f86be
Merge branch 'main' into tzulingk/model_label_dis_444
tzulingk Aug 21, 2025
939808a
Remove the metrics_labels in MockVllmEngine, so that we will have the…
tzulingk Aug 21, 2025
cc8364c
format tests/router/test_router_e2e_with_mockers.py
tzulingk Aug 21, 2025
9ef6e2f
Merge branch 'main' into tzulingk/model_label_dis_444
tzulingk Aug 21, 2025
523984f
Update examples/multimodal/components/worker.py
tzulingk Aug 21, 2025
7f2ea28
Update components/backends/vllm/src/dynamo/vllm/main.py
tzulingk Aug 21, 2025
fd892a1
Update components/backends/vllm/src/dynamo/vllm/publisher.py
tzulingk Aug 21, 2025
ed24e16
Update components/backends/vllm/src/dynamo/vllm/publisher.py
tzulingk Aug 21, 2025
c35521b
Update components/backends/vllm/src/dynamo/vllm/publisher.py
tzulingk Aug 21, 2025
837570b
Merge branch 'main' into tzulingk/model_label_dis_444
tzulingk Aug 21, 2025
651b4f1
Merge branch 'main' into tzulingk/model_label_dis_444
tzulingk Aug 21, 2025
a8a84fa
Merge branch 'main' into tzulingk/model_label_dis_444
tzulingk Aug 21, 2025
406c880
fmt python files
tzulingk Aug 21, 2025
578d25c
python fmt.
tzulingk Aug 21, 2025
b9acce8
pre-commit run --all-files results are different from my local black …
tzulingk Aug 21, 2025
2f294e6
Merge branch 'main' into tzulingk/model_label_dis_444
tzulingk Aug 21, 2025
24 changes: 19 additions & 5 deletions components/backends/vllm/src/dynamo/vllm/main.py
@@ -150,8 +150,14 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
# (temp reason): we don't support re-routing prefill requests
# (long-term reason): prefill engine should pull from a global queue so there is
# only a few in-flight requests that can be quickly finished
generate_endpoint.serve_endpoint(handler.generate, graceful_shutdown=True),
clear_endpoint.serve_endpoint(handler.clear_kv_blocks),
generate_endpoint.serve_endpoint(
handler.generate,
graceful_shutdown=True,
metrics_labels=[("model", config.model)],
),
clear_endpoint.serve_endpoint(
handler.clear_kv_blocks, metrics_labels=[("model", config.model)]
),
)
except Exception as e:
logger.error(f"Failed to serve endpoints: {e}")
@@ -178,7 +184,9 @@ async def init(runtime: DistributedRuntime, config: Config):
.client()
)

factory = StatLoggerFactory(component, config.engine_args.data_parallel_rank or 0)
factory = StatLoggerFactory(
component, config.engine_args.data_parallel_rank or 0, metrics_labels=[("model", config.model)]
)
engine_client, vllm_config, default_sampling_params = setup_vllm_engine(
config, factory
)
@@ -239,8 +247,14 @@ async def init(runtime: DistributedRuntime, config: Config):
await asyncio.gather(
# for decode, we want to transfer the in-flight requests to other decode engines,
# because waiting them to finish can take a long time for long OSLs
generate_endpoint.serve_endpoint(handler.generate, graceful_shutdown=False),
clear_endpoint.serve_endpoint(handler.clear_kv_blocks),
generate_endpoint.serve_endpoint(
handler.generate,
graceful_shutdown=False,
metrics_labels=[("model", config.model)],
),
clear_endpoint.serve_endpoint(
handler.clear_kv_blocks, metrics_labels=[("model", config.model)]
),
)
except Exception as e:
logger.error(f"Failed to serve endpoints: {e}")
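The main.py changes above attach a ("model", config.model) label to every served endpoint. A minimal sketch of the resulting call pattern, assuming the usual namespace/component/endpoint helpers from the Python bindings; the import path, handler class, and names below are illustrative, not taken from this PR:

```python
import asyncio

from dynamo.runtime import DistributedRuntime  # assumed import path for the bindings


class Handler:
    """Hypothetical handler; only the method names mirror the diff above."""

    async def generate(self, request):
        yield {"text": "..."}

    async def clear_kv_blocks(self, request):
        yield {"status": "cleared"}


async def serve(runtime: DistributedRuntime, model_name: str) -> None:
    component = runtime.namespace("dynamo").component("backend")
    await component.create_service()

    handler = Handler()
    generate_endpoint = component.endpoint("generate")
    clear_endpoint = component.endpoint("clear_kv_blocks")

    # Both endpoints now tag their work-handler metrics with the model name.
    await asyncio.gather(
        generate_endpoint.serve_endpoint(
            handler.generate,
            graceful_shutdown=False,
            metrics_labels=[("model", model_name)],
        ),
        clear_endpoint.serve_endpoint(
            handler.clear_kv_blocks, metrics_labels=[("model", model_name)]
        ),
    )
```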
25 changes: 20 additions & 5 deletions components/backends/vllm/src/dynamo/vllm/publisher.py
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from typing import Optional
from typing import List, Optional, Tuple

from vllm.config import VllmConfig
from vllm.v1.metrics.loggers import StatLoggerBase
@@ -36,9 +36,16 @@ def log_engine_initialized(self):
class DynamoStatLoggerPublisher(StatLoggerBase):
"""Stat logger publisher. Wrapper for the WorkerMetricsPublisher to match the StatLoggerBase interface."""

def __init__(self, component: Component, dp_rank: int) -> None:
def __init__(
self,
component: Component,
dp_rank: int,
metrics_labels: Optional[List[Tuple[str, str]]] = None,
) -> None:
self.inner = WorkerMetricsPublisher()
self.inner.create_endpoint(component)
# Use labels directly for the new create_endpoint signature
metrics_labels = metrics_labels or []
self.inner.create_endpoint(component, metrics_labels)
self.dp_rank = dp_rank
self.num_gpu_block = 1
self.request_total_slots = 1
@@ -129,15 +136,23 @@ def log_engine_initialized(self) -> None:
class StatLoggerFactory:
"""Factory for creating stat logger publishers. Required by vLLM."""

def __init__(self, component: Component, dp_rank: int = 0) -> None:
def __init__(
self,
component: Component,
dp_rank: int = 0,
metrics_labels: Optional[List[Tuple[str, str]]] = None,
) -> None:
self.component = component
self.created_logger: Optional[DynamoStatLoggerPublisher] = None
self.dp_rank = dp_rank
self.metrics_labels = metrics_labels or []

def create_stat_logger(self, dp_rank: int) -> StatLoggerBase:
if self.dp_rank != dp_rank:
return NullStatLogger()
logger = DynamoStatLoggerPublisher(self.component, dp_rank)
logger = DynamoStatLoggerPublisher(
self.component, dp_rank, metrics_labels=self.metrics_labels
)
self.created_logger = logger

return logger
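A short sketch of the updated factory in isolation. The import path and model name are assumptions, and `component` stands in for a Component obtained from the runtime:

```python
from typing import List, Tuple

from dynamo.vllm.publisher import StatLoggerFactory  # assumed import path


def make_stat_logger_factory(component, model_name: str, dp_rank: int = 0) -> StatLoggerFactory:
    """Illustrative helper: label the worker's KV-metrics endpoint with the model name."""
    labels: List[Tuple[str, str]] = [("model", model_name)]
    return StatLoggerFactory(component, dp_rank, metrics_labels=labels)


# Only the matching data-parallel rank gets a publishing logger; other ranks get a no-op:
#   factory = make_stat_logger_factory(component, "example-model")
#   factory.create_stat_logger(dp_rank=0)  -> DynamoStatLoggerPublisher (labels attached)
#   factory.create_stat_logger(dp_rank=1)  -> NullStatLogger
```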
14 changes: 9 additions & 5 deletions examples/multimodal/components/worker.py
@@ -140,7 +140,9 @@ def setup_vllm_engine(self, component: Component, endpoint: Endpoint):

# Create vLLM engine with metrics logger and KV event publisher attached
self.stats_logger = StatLoggerFactory(
component, self.engine_args.data_parallel_rank or 0
component,
self.engine_args.data_parallel_rank or 0,
metrics_labels=[("model", self.engine_args.model)],
)
self.engine_client = AsyncLLM.from_vllm_config(
vllm_config=vllm_config,
@@ -342,9 +344,9 @@ async def generate(self, request: vLLMMultimodalRequest):
# Update the prompt token id in the decode request to the one
# in response, which has image templated filled in. So that
# the decode worker will fetch correct amount of KV blocks.
decode_request.engine_prompt[
"prompt_token_ids"
] = prefill_response.prompt_token_ids
decode_request.engine_prompt["prompt_token_ids"] = (
prefill_response.prompt_token_ids
)
logger.debug(
f"Prefill response kv_transfer_params: {prefill_response.kv_transfer_params}"
)
@@ -353,7 +355,9 @@ async def generate(self, request: vLLMMultimodalRequest):
extra_args.pop("serialized_request", None)
decode_request.sampling_params.extra_args = extra_args
logger.debug("Decode request: %s", decode_request)
async for decode_response in await self.decode_worker_client.round_robin(
async for (
decode_response
) in await self.decode_worker_client.round_robin(
decode_request.model_dump_json()
):
output = MyRequestOutput.model_validate_json(decode_response.data())
9 changes: 7 additions & 2 deletions lib/bindings/python/rust/lib.rs
@@ -513,19 +513,24 @@ impl Component {

#[pymethods]
impl Endpoint {
#[pyo3(signature = (generator, graceful_shutdown = true))]
#[pyo3(signature = (generator, graceful_shutdown = true, metrics_labels = None))]
fn serve_endpoint<'p>(
&self,
py: Python<'p>,
generator: PyObject,
graceful_shutdown: Option<bool>,
metrics_labels: Option<Vec<(String, String)>>,
) -> PyResult<Bound<'p, PyAny>> {
let engine = Arc::new(engine::PythonAsyncEngine::new(
generator,
self.event_loop.clone(),
)?);
let ingress = JsonServerStreamingIngress::for_engine(engine).map_err(to_pyerr)?;
let builder = self.inner.endpoint_builder().handler(ingress);
let builder = self
.inner
.endpoint_builder()
.metrics_labels(metrics_labels)
.handler(ingress);
let graceful_shutdown = graceful_shutdown.unwrap_or(true);
pyo3_async_runtimes::tokio::future_into_py(py, async move {
builder
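Because the pyo3 signature defaults metrics_labels to None, existing Python call sites keep working unchanged. A two-line sketch (handler and label values are illustrative):

```python
# Existing call sites are unchanged:
await endpoint.serve_endpoint(handler.generate)

# New call sites can attach labels to the endpoint's work-handler metrics:
await endpoint.serve_endpoint(
    handler.generate, metrics_labels=[("model", "example-model")]
)
```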
22 changes: 20 additions & 2 deletions lib/bindings/python/rust/llm/kv.rs
@@ -55,17 +55,35 @@ impl WorkerMetricsPublisher {
})
}

#[pyo3(signature = (component))]
#[pyo3(signature = (component, metrics_labels = None))]
fn create_endpoint<'p>(
&self,
py: Python<'p>,
component: Component,
metrics_labels: Option<Vec<(String, String)>>,
) -> PyResult<Bound<'p, PyAny>> {
let rs_publisher = self.inner.clone();
let rs_component = component.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
// Convert Python labels to Option<&[(&str, &str)]> expected by Rust API
let metrics_labels_ref: Option<Vec<(&str, &str)>> =
if let Some(metrics_labels) = metrics_labels.as_ref() {
if metrics_labels.is_empty() {
None
} else {
Some(
metrics_labels
.iter()
.map(|(k, v)| (k.as_str(), v.as_str()))
.collect(),
)
}
} else {
None
};

rs_publisher
.create_endpoint(rs_component)
.create_endpoint(rs_component, metrics_labels_ref.as_deref())
.await
.map_err(to_pyerr)?;
Ok(())
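From Python, the new argument is optional, and an empty list is normalized to None before reaching the Rust API. A hedged sketch (import path assumed; `component` comes from the runtime):

```python
from dynamo.llm import WorkerMetricsPublisher  # assumed import path

publisher = WorkerMetricsPublisher()

# Labels are forwarded to the endpoint builder behind the KV metrics endpoint.
await publisher.create_endpoint(component, [("model", "example-model")])

# Omitting the argument (or passing []) preserves the previous unlabeled behavior:
#   await publisher.create_endpoint(component)
```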
13 changes: 12 additions & 1 deletion lib/llm/src/kv_router/publisher.rs
@@ -498,7 +498,11 @@ impl WorkerMetricsPublisher {
self.tx.send(metrics)
}

pub async fn create_endpoint(&self, component: Component) -> Result<()> {
pub async fn create_endpoint(
&self,
component: Component,
metrics_labels: Option<&[(&str, &str)]>,
) -> Result<()> {
let mut metrics_rx = self.rx.clone();
let handler = Arc::new(KvLoadEndpointHandler::new(metrics_rx.clone()));
let handler = Ingress::for_engine(handler)?;
@@ -514,13 +518,20 @@

self.start_nats_metrics_publishing(component.namespace().clone(), worker_id);

let metrics_labels = metrics_labels.map(|v| {
v.iter()
.map(|(k, v)| (k.to_string(), v.to_string()))
.collect::<Vec<_>>()
});

component
.endpoint(KV_METRICS_ENDPOINT)
.endpoint_builder()
.stats_handler(move |_| {
let metrics = metrics_rx.borrow_and_update().clone();
serde_json::to_value(&*metrics).unwrap()
})
.metrics_labels(metrics_labels)
.handler(handler)
.start()
.await
2 changes: 1 addition & 1 deletion lib/llm/src/mocker/engine.rs
@@ -189,7 +189,7 @@ impl MockVllmEngine {
tokio::spawn({
let publisher = metrics_publisher.clone();
async move {
if let Err(e) = publisher.create_endpoint(comp.clone()).await {
if let Err(e) = publisher.create_endpoint(comp.clone(), None).await {
tracing::error!("Metrics endpoint failed: {e}");
}
}
1 change: 0 additions & 1 deletion lib/runtime/examples/system_metrics/src/lib.rs
@@ -16,7 +16,6 @@ use std::sync::Arc;
pub const DEFAULT_NAMESPACE: &str = "dyn_example_namespace";
pub const DEFAULT_COMPONENT: &str = "dyn_example_component";
pub const DEFAULT_ENDPOINT: &str = "dyn_example_endpoint";
pub const DEFAULT_MODEL_NAME: &str = "dyn_example_model";

/// Stats structure returned by the endpoint's stats handler
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
11 changes: 9 additions & 2 deletions lib/runtime/src/component/endpoint.rs
@@ -41,6 +41,10 @@ pub struct EndpointConfig {
#[builder(default, private)]
_stats_handler: Option<EndpointStatsHandler>,

/// Additional labels for metrics
#[builder(default, setter(into))]
metrics_labels: Option<Vec<(String, String)>>,

/// Whether to wait for inflight requests to complete during shutdown
#[builder(default = "true")]
graceful_shutdown: bool,
@@ -59,7 +63,7 @@
}

pub async fn start(self) -> Result<()> {
let (endpoint, lease, handler, stats_handler, graceful_shutdown) =
let (endpoint, lease, handler, stats_handler, metrics_labels, graceful_shutdown) =
self.build_internal()?.dissolve();
let lease = lease.or(endpoint.drt().primary_lease());
let lease_id = lease.as_ref().map(|l| l.id()).unwrap_or(0);
@@ -74,8 +78,11 @@
// acquire the registry lock
let registry = endpoint.drt().component_registry.inner.lock().await;

let metrics_labels: Option<Vec<(&str, &str)>> = metrics_labels
.as_ref()
.map(|v| v.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect());
// Add metrics to the handler. The endpoint provides additional information to the handler.
handler.add_metrics(&endpoint)?;
handler.add_metrics(&endpoint, metrics_labels.as_deref())?;

// get the group
let group = registry
14 changes: 11 additions & 3 deletions lib/runtime/src/pipeline/network.rs
@@ -300,8 +300,12 @@ impl<Req: PipelineIO + Sync, Resp: PipelineIO> Ingress<Req, Resp> {
.map_err(|_| anyhow::anyhow!("Segment already set"))
}

pub fn add_metrics(&self, endpoint: &crate::component::Endpoint) -> Result<()> {
let metrics = WorkHandlerMetrics::from_endpoint(endpoint)
pub fn add_metrics(
&self,
endpoint: &crate::component::Endpoint,
metrics_labels: Option<&[(&str, &str)]>,
) -> Result<()> {
let metrics = WorkHandlerMetrics::from_endpoint(endpoint, metrics_labels)
.map_err(|e| anyhow::anyhow!("Failed to create work handler metrics: {}", e))?;

self.metrics
@@ -345,7 +349,11 @@ pub trait PushWorkHandler: Send + Sync {
async fn handle_payload(&self, payload: Bytes) -> Result<(), PipelineError>;

/// Add metrics to the handler
fn add_metrics(&self, endpoint: &crate::component::Endpoint) -> Result<()>;
fn add_metrics(
&self,
endpoint: &crate::component::Endpoint,
metrics_labels: Option<&[(&str, &str)]>,
) -> Result<()>;
}

/*
22 changes: 14 additions & 8 deletions lib/runtime/src/pipeline/network/ingress/push_handler.rs
@@ -54,43 +54,45 @@ impl WorkHandlerMetrics {
/// Create WorkHandlerMetrics from an endpoint using its built-in labeling
pub fn from_endpoint(
endpoint: &crate::component::Endpoint,
metrics_labels: Option<&[(&str, &str)]>,
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
let metrics_labels = metrics_labels.unwrap_or(&[]);
let request_counter = endpoint.create_intcounter(
"requests_total",
"Total number of requests processed by work handler",
&[],
metrics_labels,
)?;

let request_duration = endpoint.create_histogram(
"request_duration_seconds",
"Time spent processing requests by work handler",
&[],
metrics_labels,
None,
)?;

let inflight_requests = endpoint.create_intgauge(
"inflight_requests",
"Number of requests currently being processed by work handler",
&[],
metrics_labels,
)?;

let request_bytes = endpoint.create_intcounter(
"request_bytes_total",
"Total number of bytes received in requests by work handler",
&[],
metrics_labels,
)?;

let response_bytes = endpoint.create_intcounter(
"response_bytes_total",
"Total number of bytes sent in responses by work handler",
&[],
metrics_labels,
)?;

let error_counter = endpoint.create_intcountervec(
"errors_total",
"Total number of errors in work handler processing",
&["error_type"],
&[],
metrics_labels,
)?;

Ok(Self::new(
Expand All @@ -110,10 +112,14 @@ where
T: Data + for<'de> Deserialize<'de> + std::fmt::Debug,
U: Data + Serialize + MaybeError + std::fmt::Debug,
{
fn add_metrics(&self, endpoint: &crate::component::Endpoint) -> Result<()> {
fn add_metrics(
&self,
endpoint: &crate::component::Endpoint,
metrics_labels: Option<&[(&str, &str)]>,
) -> Result<()> {
// Call the Ingress-specific add_metrics implementation
use crate::pipeline::network::Ingress;
Ingress::add_metrics(self, endpoint)
Ingress::add_metrics(self, endpoint, metrics_labels)
}

async fn handle_payload(&self, payload: Bytes) -> Result<(), PipelineError> {