Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
c23f86f
Implement add_labels for Endpoint and add model name to StatLoggerFac…
tzulingk Aug 16, 2025
e7e2a2e
format the codes
tzulingk Aug 16, 2025
9a8cc4d
remove unused import
tzulingk Aug 16, 2025
90d64d4
refactor: change create_endpoint from model to labels parameter
keivenchang Aug 17, 2025
5841108
refactor: update metrics endpoint creation to use labels instead of m…
keivenchang Aug 17, 2025
af27fb2
style: fix formatting in service_metrics example
keivenchang Aug 18, 2025
3b11882
Add metrics labels at EndpointConfigBuilder::start()
tzulingk Aug 18, 2025
7c3bdb0
Add metric_labels in EndpointConfig.
tzulingk Aug 20, 2025
acb237f
Merge remote-tracking branch 'origin/main' into tzulingk/model_label_…
tzulingk Aug 20, 2025
386dd2a
Change labels to metrics_labels
tzulingk Aug 20, 2025
402d5e6
Merge branch 'main' into tzulingk/model_label_dis_444
tzulingk Aug 20, 2025
60171e8
Cargo fmt
tzulingk Aug 20, 2025
40d8f86
Merge branch 'main' into tzulingk/model_label_dis_444
tzulingk Aug 20, 2025
91101ad
Merge branch 'main' into tzulingk/model_label_dis_444
tzulingk Aug 20, 2025
946e821
Use different endpoint name to avoid the duplicated metrics registra…
tzulingk Aug 21, 2025
f7f86be
Merge branch 'main' into tzulingk/model_label_dis_444
tzulingk Aug 21, 2025
939808a
Remove the metrics_labels in MockVllmEngine, so that we will have the…
tzulingk Aug 21, 2025
cc8364c
format tests/router/test_router_e2e_with_mockers.py
tzulingk Aug 21, 2025
9ef6e2f
Merge branch 'main' into tzulingk/model_label_dis_444
tzulingk Aug 21, 2025
523984f
Update examples/multimodal/components/worker.py
tzulingk Aug 21, 2025
7f2ea28
Update components/backends/vllm/src/dynamo/vllm/main.py
tzulingk Aug 21, 2025
fd892a1
Update components/backends/vllm/src/dynamo/vllm/publisher.py
tzulingk Aug 21, 2025
ed24e16
Update components/backends/vllm/src/dynamo/vllm/publisher.py
tzulingk Aug 21, 2025
c35521b
Update components/backends/vllm/src/dynamo/vllm/publisher.py
tzulingk Aug 21, 2025
837570b
Merge branch 'main' into tzulingk/model_label_dis_444
tzulingk Aug 21, 2025
651b4f1
Merge branch 'main' into tzulingk/model_label_dis_444
tzulingk Aug 21, 2025
a8a84fa
Merge branch 'main' into tzulingk/model_label_dis_444
tzulingk Aug 21, 2025
406c880
fmt python files
tzulingk Aug 21, 2025
578d25c
python fmt.
tzulingk Aug 21, 2025
b9acce8
pre-commit run --all-files results are different from my local black …
tzulingk Aug 21, 2025
2f294e6
Merge branch 'main' into tzulingk/model_label_dis_444
tzulingk Aug 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions components/backends/vllm/src/dynamo/vllm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,12 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()

generate_endpoint = component.endpoint(config.endpoint)
clear_endpoint = component.endpoint("clear_kv_blocks")
generate_endpoint = component.endpoint(config.endpoint).add_labels(
[("model", config.model)]
)
clear_endpoint = component.endpoint("clear_kv_blocks").add_labels(
[("model", config.model)]
)

engine_client, _, default_sampling_params = setup_vllm_engine(config)

Expand Down Expand Up @@ -168,8 +172,12 @@ async def init(runtime: DistributedRuntime, config: Config):
component = runtime.namespace(config.namespace).component(config.component)
await component.create_service()

generate_endpoint = component.endpoint(config.endpoint)
clear_endpoint = component.endpoint("clear_kv_blocks")
generate_endpoint = component.endpoint(config.endpoint).add_labels(
[("model", config.model)]
)
clear_endpoint = component.endpoint("clear_kv_blocks").add_labels(
[("model", config.model)]
)

prefill_worker_client = (
await runtime.namespace(config.namespace)
Expand All @@ -178,7 +186,9 @@ async def init(runtime: DistributedRuntime, config: Config):
.client()
)

factory = StatLoggerFactory(component, config.engine_args.data_parallel_rank or 0)
factory = StatLoggerFactory(
component, config.engine_args.data_parallel_rank or 0, config.model
)
engine_client, vllm_config, default_sampling_params = setup_vllm_engine(
config, factory
)
Expand Down
14 changes: 10 additions & 4 deletions components/backends/vllm/src/dynamo/vllm/publisher.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from optparse import Option
from typing import Optional

from vllm.config import VllmConfig
Expand Down Expand Up @@ -36,9 +37,11 @@ def log_engine_initialized(self):
class DynamoStatLoggerPublisher(StatLoggerBase):
"""Stat logger publisher. Wrapper for the WorkerMetricsPublisher to match the StatLoggerBase interface."""

def __init__(self, component: Component, dp_rank: int) -> None:
def __init__(
self, component: Component, dp_rank: int, model: Optional[str] = None
) -> None:
self.inner = WorkerMetricsPublisher()
self.inner.create_endpoint(component)
self.inner.create_endpoint(component, model)
self.dp_rank = dp_rank
self.num_gpu_block = 1
self.request_total_slots = 1
Expand Down Expand Up @@ -129,15 +132,18 @@ def log_engine_initialized(self) -> None:
class StatLoggerFactory:
"""Factory for creating stat logger publishers. Required by vLLM."""

def __init__(self, component: Component, dp_rank: int = 0) -> None:
def __init__(
self, component: Component, dp_rank: int = 0, model: Optional[str] = None
) -> None:
self.component = component
self.created_logger: Optional[DynamoStatLoggerPublisher] = None
self.dp_rank = dp_rank
self.model = model

def create_stat_logger(self, dp_rank: int) -> StatLoggerBase:
if self.dp_rank != dp_rank:
return NullStatLogger()
logger = DynamoStatLoggerPublisher(self.component, dp_rank)
logger = DynamoStatLoggerPublisher(self.component, dp_rank, self.model)
self.created_logger = logger

return logger
Expand Down
9 changes: 8 additions & 1 deletion components/metrics/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,14 @@ async fn app(runtime: Runtime) -> Result<()> {
.context("Unable to create unique instance of Count; possibly one already exists")?;

let target_component = namespace.component(&config.component_name)?;
let target_endpoint = target_component.endpoint(&config.endpoint_name);
let target_endpoint = {
let e = target_component.endpoint(&config.endpoint_name);
if let Some(ref model) = config.model_name {
e.add_labels(&[("model", model.as_str())])?
} else {
e
}
};

let service_path = target_endpoint.path();
let service_subject = target_endpoint.subject();
Expand Down
15 changes: 15 additions & 0 deletions lib/bindings/python/rust/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,21 @@ impl Endpoint {
.map(|l| l.id())
.unwrap_or(0)
}

/// Attach constant metric labels to this Endpoint and return the labeled copy.
///
/// `labels` is a list of `(key, value)` tuples. The receiver is left untouched:
/// the inner endpoint is cloned, the labels are added to the clone, and a new
/// `Endpoint` sharing the same event loop handle is returned.
///
/// Any error reported by the inner `add_labels` call is converted with
/// `to_pyerr` and raised on the Python side.
fn add_labels(&self, labels: Vec<(String, String)>) -> PyResult<Endpoint> {
    use rs::metrics::MetricsRegistry as _;

    // The inner API takes borrowed string pairs, so view the owned tuples as `&str`.
    let mut borrowed: Vec<(&str, &str)> = Vec::with_capacity(labels.len());
    for (key, value) in &labels {
        borrowed.push((key.as_str(), value.as_str()));
    }

    let labeled = self
        .inner
        .clone()
        .add_labels(&borrowed)
        .map_err(to_pyerr)?;

    Ok(Endpoint {
        inner: labeled,
        event_loop: self.event_loop.clone(),
    })
}
}

#[pymethods]
Expand Down
5 changes: 3 additions & 2 deletions lib/bindings/python/rust/llm/kv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,18 @@ impl WorkerMetricsPublisher {
})
}

#[pyo3(signature = (component))]
#[pyo3(signature = (component, model = None))]
fn create_endpoint<'p>(
&self,
py: Python<'p>,
component: Component,
model: Option<String>,
) -> PyResult<Bound<'p, PyAny>> {
let rs_publisher = self.inner.clone();
let rs_component = component.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
rs_publisher
.create_endpoint(rs_component)
.create_endpoint(rs_component, model)
.await
.map_err(to_pyerr)?;
Ok(())
Expand Down
13 changes: 9 additions & 4 deletions lib/llm/src/kv_router/publisher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ impl WorkerMetricsPublisher {
self.tx.send(metrics)
}

pub async fn create_endpoint(&self, component: Component) -> Result<()> {
pub async fn create_endpoint(&self, component: Component, model: Option<String>) -> Result<()> {
let mut metrics_rx = self.rx.clone();
let handler = Arc::new(KvLoadEndpointHandler::new(metrics_rx.clone()));
let handler = Ingress::for_engine(handler)?;
Expand All @@ -511,9 +511,14 @@ impl WorkerMetricsPublisher {
// tracing::warn!("Component is static, assuming worker_id of 0");
// 0
// });

component
.endpoint(KV_METRICS_ENDPOINT)
let endpoint = {
let mut e = component.endpoint(KV_METRICS_ENDPOINT);
if let Some(model_name) = model {
e = e.add_labels(&[("model", model_name.as_str())])?;
}
e
};
endpoint
.endpoint_builder()
.stats_handler(move |_| {
let metrics = metrics_rx.borrow_and_update().clone();
Expand Down
5 changes: 4 additions & 1 deletion lib/llm/src/mocker/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,10 @@ impl MockVllmEngine {
tokio::spawn({
let publisher = metrics_publisher.clone();
async move {
if let Err(e) = publisher.create_endpoint(comp.clone()).await {
if let Err(e) = publisher
.create_endpoint(comp.clone(), Some("mock_model".to_string()))
.await
{
tracing::error!("Metrics endpoint failed: {e}");
}
}
Expand Down
1 change: 1 addition & 0 deletions lib/runtime/examples/hello_world/src/bin/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ async fn app(runtime: Runtime) -> Result<()> {
.namespace(DEFAULT_NAMESPACE)?
.component("backend")?
.endpoint("generate")
.add_labels(&[("model", "hello_world_model")])?
.client()
.await?;
client.wait_for_instances().await?;
Expand Down
1 change: 1 addition & 0 deletions lib/runtime/examples/hello_world/src/bin/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ async fn backend(runtime: DistributedRuntime) -> Result<()> {
.create()
.await?
.endpoint("generate")
.add_labels(&[("model", "hello_world_model")])?
.endpoint_builder()
.handler(ingress)
.start()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ async fn app(runtime: Runtime) -> Result<()> {
let namespace = distributed.namespace(DEFAULT_NAMESPACE)?;
let component = namespace.component("backend")?;

let client = component.endpoint("generate").client().await?;
let client = component
.endpoint("generate")
.add_labels(&[("model", "service_metrics_model")])?
.client()
.await?;

client.wait_for_instances().await?;
let router =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ async fn backend(runtime: DistributedRuntime) -> Result<()> {
.create()
.await?
.endpoint("generate")
.add_labels(&[("model", "service_metrics_model")])?
.endpoint_builder()
.stats_handler(|stats| {
println!("stats: {:?}", stats);
Expand Down
3 changes: 2 additions & 1 deletion lib/runtime/examples/system_metrics/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ pub async fn backend(drt: DistributedRuntime, endpoint_name: Option<&str>) -> Re
.service_builder()
.create()
.await?
.endpoint(endpoint_name);
.endpoint(endpoint_name)
.add_labels(&[("model", DEFAULT_MODEL_NAME)])?;

// Create custom metrics for system stats
let system_metrics =
Expand Down
55 changes: 55 additions & 0 deletions lib/runtime/src/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,61 @@ impl MetricsRegistry for Endpoint {
}

impl Endpoint {
/// Returns the labels stored on this endpoint as borrowed `(key, value)` pairs,
/// in the order they are held in the underlying vector.
pub fn stored_labels(&self) -> Vec<(&str, &str)> {
    let mut pairs = Vec::with_capacity(self.labels.len());
    for (key, value) in &self.labels {
        pairs.push((key.as_str(), value.as_str()));
    }
    pairs
}

/// Returns a mutable reference to this endpoint's label storage.
///
/// Private helper: the vector holds owned `(key, value)` pairs and is the same
/// storage that `stored_labels` reads from.
fn labels_mut(&mut self) -> &mut Vec<(String, String)> {
&mut self.labels
}

/// Consumes this Endpoint and returns it with `labels` appended to its stored labels.
///
/// Taking `self` by value allows chaining, e.g.
/// `component.endpoint("generate").add_labels(&[("model", "m")])?`.
///
/// # Errors
///
/// Validation runs before anything is stored, so on error no label is added:
/// - a key appears more than once within `labels`, or
/// - a key in `labels` is already present in the endpoint's stored labels.
pub fn add_labels(mut self, labels: &[(&str, &str)]) -> anyhow::Result<Self>
where
    Self: Sized,
{
    // Pass 1: the input itself must not repeat a key.
    let mut input_keys = std::collections::HashSet::new();
    if let Some((repeated, _)) = labels.iter().find(|(k, _)| !input_keys.insert(*k)) {
        anyhow::bail!("Duplicate label key '{}' found in labels", repeated);
    }

    // Pass 2: no input key may collide with a label that is already stored.
    let already_stored: std::collections::HashSet<&str> =
        self.labels.iter().map(|(k, _)| k.as_str()).collect();
    for (key, _) in labels {
        if already_stored.contains(key) {
            anyhow::bail!(
                "Label key '{}' already exists in registry; refusing to overwrite",
                key
            );
        }
    }

    // Both checks passed: append the new pairs as owned strings, preserving input order.
    self.labels_mut().extend(
        labels
            .iter()
            .map(|(key, value)| (key.to_string(), value.to_string())),
    );
    Ok(self)
}

pub fn id(&self) -> EndpointId {
EndpointId {
namespace: self.component.namespace().name().to_string(),
Expand Down
43 changes: 43 additions & 0 deletions lib/runtime/src/metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1305,6 +1305,49 @@ dynamo_component_nats_total_errors 5"#;

println!("✓ All refactored filter functions work correctly!");
}

#[test]
fn test_add_labels_to_endpoint() {
    // Build namespace -> component -> endpoint and attach two custom labels.
    let runtime = super::test_helpers::create_test_drt();
    let ns = runtime.namespace("ns_labels").unwrap();
    let comp = ns.component("comp_labels").unwrap();
    let labeled_endpoint = comp
        .endpoint("ep_labels")
        .add_labels(&[("label1", "val1"), ("label2", "val2")])
        .unwrap();

    // Register a counter carrying the endpoint's stored labels, then bump it.
    let custom_labels = labeled_endpoint.stored_labels();
    let counter = labeled_endpoint
        .create_counter("test_counter_with_labels", "A test counter", &custom_labels)
        .unwrap();
    counter.inc_by(10.0);

    // Render the endpoint's metrics and pull out the metric lines.
    let rendered = labeled_endpoint.prometheus_metrics_fmt().unwrap();
    let samples = super::test_helpers::extract_metrics(&rendered);

    // Exactly one metric is expected, with the prefixed name and the incremented value.
    assert_eq!(samples.len(), 1);
    let (name, labels, value) =
        super::test_helpers::parse_prometheus_metric(&samples[0]).unwrap();
    assert_eq!(name, "dynamo_component_test_counter_with_labels");
    assert_eq!(value, 10.0);

    // Labels matching the namespace / component / endpoint names used above.
    assert_eq!(labels.get("dynamo_namespace").unwrap(), "ns_labels");
    assert_eq!(labels.get("dynamo_component").unwrap(), "comp_labels");
    assert_eq!(labels.get("dynamo_endpoint").unwrap(), "ep_labels");

    // Labels supplied through `add_labels`.
    assert_eq!(labels.get("label1").unwrap(), "val1");
    assert_eq!(labels.get("label2").unwrap(), "val2");
}
}

#[cfg(feature = "integration")]
Expand Down
13 changes: 7 additions & 6 deletions lib/runtime/src/pipeline/network/ingress/push_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,42 +55,43 @@ impl WorkHandlerMetrics {
pub fn from_endpoint(
endpoint: &crate::component::Endpoint,
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
let labels: Vec<(&str, &str)> = endpoint.stored_labels();
let request_counter = endpoint.create_intcounter(
"requests_total",
"Total number of requests processed by work handler",
&[],
&labels,
)?;

let request_duration = endpoint.create_histogram(
"request_duration_seconds",
"Time spent processing requests by work handler",
&[],
&labels,
None,
)?;

let concurrent_requests = endpoint.create_intgauge(
"concurrent_requests",
"Number of requests currently being processed by work handler",
&[],
&labels,
)?;

let request_bytes = endpoint.create_intcounter(
"request_bytes_total",
"Total number of bytes received in requests by work handler",
&[],
&labels,
)?;

let response_bytes = endpoint.create_intcounter(
"response_bytes_total",
"Total number of bytes sent in responses by work handler",
&[],
&labels,
)?;

let error_counter = endpoint.create_intcountervec(
"errors_total",
"Total number of errors in work handler processing",
&["error_type"],
&[],
&labels,
)?;

Ok(Self::new(
Expand Down
Loading