From 662cb68eb01d969ad12156cd5c89096731ddb105 Mon Sep 17 00:00:00 2001 From: vraiti Date: Wed, 6 May 2026 12:09:08 -0400 Subject: [PATCH 01/13] [Feature] Add Prometheus metrics for multi-stage pipelines Co-Authored-By: Claude Opus 4.6 Signed-off-by: vraiti --- docs/design/index.md | 4 + docs/design/metrics.md | 197 +++++++++++++++++++ docs/usage/metrics.md | 79 ++++++++ tests/metrics/test_prometheus.py | 146 ++++++++++++++ vllm_omni/core/sched/omni_scheduler_mixin.py | 16 ++ vllm_omni/engine/async_omni_engine.py | 7 +- vllm_omni/engine/orchestrator.py | 43 +++- vllm_omni/engine/stage_init_utils.py | 2 +- vllm_omni/engine/stage_pool.py | 4 +- vllm_omni/entrypoints/omni_base.py | 19 ++ vllm_omni/metrics/__init__.py | 3 + vllm_omni/metrics/prometheus.py | 119 +++++++++++ vllm_omni/patch.py | 27 +++ 13 files changed, 660 insertions(+), 6 deletions(-) create mode 100644 docs/design/metrics.md create mode 100644 docs/usage/metrics.md create mode 100644 tests/metrics/test_prometheus.py create mode 100644 vllm_omni/metrics/prometheus.py diff --git a/docs/design/index.md b/docs/design/index.md index 61aa5048368..8789f2e1b22 100644 --- a/docs/design/index.md +++ b/docs/design/index.md @@ -13,6 +13,10 @@ This section contains design documents and architecture specifications for vLLM- - [Adding Step Execution Support for Diffusion Pipelines](feature/diffusion_step_execution.md) - [Continuous Batching for Step-Wise Diffusion](feature/diffusion_continuous_batching.md) +## Infrastructure Design Documents + +- [Prometheus Metrics](metrics.md) + ## Module Design Documents - [AR Module](module/ar_module.md) diff --git a/docs/design/metrics.md b/docs/design/metrics.md new file mode 100644 index 00000000000..dcf8b2c04d8 --- /dev/null +++ b/docs/design/metrics.md @@ -0,0 +1,197 @@ +# Prometheus Metrics Design + +This document describes how vLLM-Omni exposes Prometheus metrics for +multi-stage pipelines, the constraints that shaped the design, and how +the pipeline-level metrics coexist with upstream vLLM per-engine +metrics. + +## Objectives + +- Expose pipeline-level request and latency metrics that span the full + multi-stage execution (orchestrator scope). +- Preserve all upstream vLLM per-engine metrics (`vllm:*`) for stages + backed by an AR LLM engine. +- Expose per-stage diffusion timing breakdowns for pipelines that + include a diffusion engine. +- Keep the metrics collection overhead low enough that it does not + regress TTFA or throughput. + +## Background + +### Upstream vLLM Metrics + +Upstream vLLM defines 44 Prometheus metrics under the `vllm:` prefix. +These are registered by `PrometheusStatLogger` and cover engine-level +state: KV cache usage, running/waiting request counts, token +throughput, TTFT, inter-token latency, e2e latency, and so on. They +are served via the `/metrics` HTTP endpoint provided by +`prometheus_fastapi_instrumentator` and the default +`prometheus_client` WSGI handler. + +vLLM's `unregister_vllm_metrics()` function strips every +`prometheus_client` collector whose `_name` attribute contains the +substring `"vllm"`. This runs during engine initialization to clean up +stale collectors from prior instantiations within the same process. + +### The Problem + +vLLM-Omni runs multiple engine instances (stages) within a single +process, coordinated by an Orchestrator. The pipeline needs its own +metrics — aggregate request counts, end-to-end latency across all +stages, and diffusion timing breakdowns — that do not exist in upstream +vLLM. All pipeline-level metrics use the `vllm:omni_` prefix to +distinguish them from upstream per-engine metrics. The +`unregister_vllm_metrics()` function is monkey-patched to a no-op at +import time (see `vllm_omni/patch.py`) so that these metrics are not +destroyed during engine initialization. + +Upstream per-engine metrics retain the `vllm:` prefix and are +registered by a `PrometheusStatLogger` instance that the Orchestrator +creates and feeds directly. + +## Architecture + +### Component Overview + +``` + +-----------------------+ + | API Server (FastAPI)| + | GET /metrics | + +----------+------------+ + | + prometheus_client default handler + | + +-------------+-------------+ + | | + vllm:omni_* collectors vllm:* collectors + | | + +-----------+-----------+ +--------+---------+ + | OmniPrometheusMetrics | | PrometheusStatLogger | + +-----------+-----------+ +--------+---------+ + | | + OmniBase Orchestrator + (request lifecycle, (feeds SchedulerStats + diffusion timing) + IterationStats + per engine step) +``` + +### Data Flow + +There are two independent paths for metric collection. + +**Path 1: Pipeline-level metrics (`vllm:omni_*`)** + +`OmniPrometheusMetrics` registers Gauge, Counter, and Histogram +collectors at init time. It is instantiated once per entrypoint, +labeled with the model name. The entrypoint calls its methods as +requests progress: + +- `set_running(n)` / `set_waiting(n)` — updated after each request + completes. The running count comes from `OmniRequestCounter`, a + simple counter incremented/decremented by the Orchestrator as it + tracks requests. Waiting is derived as `total - running`. + +- `request_succeeded(e2e_seconds, queue_seconds)` — recorded when a + request finishes at the final stage. + +- `request_failed()` — recorded when a request errors. + +- `observe_diffusion_metrics(stage_id, metrics)` — recorded when a + diffusion stage finishes. The metrics dict contains timing + breakdowns (preprocess, exec, postprocess, total step time) + accumulated from engine output. + +**Path 2: Per-engine metrics (`vllm:*`)** + +The Orchestrator instantiates upstream vLLM's `PrometheusStatLogger` +and feeds it scheduler stats and iteration stats after processing +each batch of engine outputs. This populates the standard vLLM +metrics (TTFT, token throughput, cache usage, etc.) using the same +code path as standalone vLLM. For diffusion-only pipelines that have +no AR engine, `SchedulerStats` is never produced and `vllm:*` metrics +are absent. + +### Shared State Between Threads + +The Orchestrator runs in a background thread. The API server +(OmniBase) runs in the asyncio event loop thread. +`OmniRequestCounter` bridges them — a plain Python object with an +`int` field. The Orchestrator increments/decrements it; the +entrypoint reads it for gauge updates. No lock is needed because the +counter is advisory (a stale read by one Prometheus scrape interval +is acceptable). It is created by `AsyncOmniEngine.__init__()` and +passed to the Orchestrator at construction time. + +### Metric Registration and Lifecycle + +All `vllm:omni_*` collectors are registered once when +`OmniPrometheusMetrics.__init__()` runs. Per-stage labels +(`model_name`, `engine`) are bound lazily on first observation to +avoid registering labels for stages that never produce data (e.g., a +diffusion pipeline has no AR stage stats). + +The `prometheus_client` default registry holds all collectors. +FastAPI's `/metrics` endpoint serves the default registry, so both +`vllm:omni_*` and `vllm:*` metrics appear in the same scrape +response alongside `http_*` and `process_*` metrics from the +instrumentator and the Python client runtime. + +## Throttling: `make_stats()` Override + +Upstream vLLM's `Scheduler.make_stats()` runs on every AR generation step, +returning a SchedulerStats object for the orchestrator. +Under vLLM's architecture, this is fine. But since vLLM-Omni requires that the +object be serialized and transferred over ZMQ, receiving a SchedulerStats object on +every step can introduce unacceptable overhead to the system. + +`OmniSchedulerMixin.make_stats()` (in +`vllm_omni/core/sched/omni_scheduler_mixin.py`) throttles stats +emission to at most once per second. Between intervals it returns +`None`, which the engine core skips serializing. This keeps gauges +fresh enough for Prometheus scrapes (typically 15-30s intervals) while +eliminating the per-step overhead. + +## Metric Definitions + +### Pipeline-Level + +| Metric | Type | Labels | Description | +|--------|------|--------|-------------| +| `vllm:omni_num_requests_running` | Gauge | `model_name` | Requests currently executing across all stages | +| `vllm:omni_num_requests_waiting` | Gauge | `model_name` | Requests queued but not yet scheduled | +| `vllm:omni_num_requests_success` | Counter | `model_name` | Requests completed without error | +| `vllm:omni_num_requests_fail` | Counter | `model_name` | Requests that returned an error | +| `vllm:omni_e2e_request_latency_seconds` | Histogram | `model_name` | End-to-end request latency across all stages | +| `vllm:omni_request_queue_time_seconds` | Histogram | `model_name` | Time spent waiting in the request queue | + +### Diffusion Stage-Level + +| Metric | Type | Labels | Description | +|--------|------|--------|-------------| +| `vllm:omni_diffusion_preprocess_time_ms` | Histogram | `model_name`, `engine` | Diffusion input preprocessing time | +| `vllm:omni_diffusion_exec_time_ms` | Histogram | `model_name`, `engine` | Diffusion model forward pass time | +| `vllm:omni_diffusion_postprocess_time_ms` | Histogram | `model_name`, `engine` | Diffusion output postprocessing time | +| `vllm:omni_diffusion_step_time_ms` | Histogram | `model_name`, `engine` | Total diffusion step time | + +### LLM Stage-Level + +Reference [vLLM docs](https://github.com/vllm-project/vllm/blob/main/docs/usage/metrics.md) + +Note that metrics that depend upon features that are not supported in vLLM-Omni (e.g. speculative decoding, LoRA) will not be available as well. + +## Logging vs. Prometheus + +`OrchestratorAggregator` (in `vllm_omni/metrics/stats.py`) is the +logging-oriented metrics path. It collects detailed per-request, +per-stage, and per-transfer statistics and prints formatted tables to +the `INFO` log. This is designed for development and debugging — +individual request traces, transfer bandwidth, inter-stage timing. + +`OmniPrometheusMetrics` is the Prometheus-oriented path. It records +aggregate counters, gauges, and histograms suitable for time-series +monitoring and alerting. The two paths are independent; both can run +simultaneously. + +The separation follows upstream vLLM's pattern of `LoggingStatLogger` +vs. `PrometheusStatLogger` — same underlying data, different +consumption models. diff --git a/docs/usage/metrics.md b/docs/usage/metrics.md new file mode 100644 index 00000000000..60e7193288b --- /dev/null +++ b/docs/usage/metrics.md @@ -0,0 +1,79 @@ +# Production Metrics + +vLLM-Omni exposes Prometheus metrics via the `/metrics` endpoint on the +OpenAI-compatible API server. The metrics fall into three categories depending +on the pipeline type. + +```bash +vllm-omni serve Qwen/Qwen3-Omni-30B-A3B-Instruct --port 8000 +curl http://localhost:8000/metrics +``` + +## Metric Namespaces + +| Prefix | Source | Present when | +|--------|--------|--------------| +| `vllm:omni_` | vLLM-Omni orchestrato / diffusion stages | Always / Pipeline includes a diffusion stage | +| `vllm:` | Upstream vLLM engine | Pipeline includes an LLM (AR) stage | +| `http_` / `process_` | Uvicorn / Python runtime | Always | + +## Pipeline-Level Metrics (`vllm:omni_`) + +These metrics are defined in `vllm_omni/metrics/prometheus.py` and track +request lifecycle across the full multi-stage pipeline. + +### Request Tracking + +| Metric | Type | Labels | Description | +|--------|------|--------|-------------| +| `vllm:omni_num_requests_running` | Gauge | `model_name` | Requests currently running across all pipeline stages | +| `vllm:omni_num_requests_waiting` | Gauge | `model_name` | Requests waiting to be scheduled | +| `vllm:omni_num_requests_success` | Counter | `model_name` | Requests that completed without error | +| `vllm:omni_num_requests_fail` | Counter | `model_name` | Requests that returned an error | + +### Latency + +| Metric | Type | Labels | Description | +|--------|------|--------|-------------| +| `vllm:omni_e2e_request_latency_seconds` | Histogram | `model_name` | End-to-end request latency in seconds | +| `vllm:omni_request_queue_time_seconds` | Histogram | `model_name` | Time spent waiting in the request queue | + +## Diffusion Engine Metrics (`vllm:omni_`) + +These histograms are populated only when the pipeline includes a diffusion +stage (e.g. image or video generation models). + +| Metric | Type | Labels | Description | +|--------|------|--------|-------------| +| `vllm:omni_diffusion_preprocess_time_ms` | Histogram | `model_name`, `engine` | Input preprocessing time per request | +| `vllm:omni_diffusion_exec_time_ms` | Histogram | `model_name`, `engine` | DiT forward pass execution time per request | +| `vllm:omni_diffusion_postprocess_time_ms` | Histogram | `model_name`, `engine` | Output postprocessing time (VAE decode) per request | +| `vllm:omni_diffusion_step_time_ms` | Histogram | `model_name`, `engine` | Total diffusion step time per request | + +## vLLM Engine Metrics (`vllm:`) + +When the pipeline includes an LLM stage, the upstream vLLM engine exposes its +full set of metrics under the `vllm:` prefix. These are registered by +`vllm.v1.metrics.loggers.PrometheusStatLogger` and cover scheduler state, +token throughput, cache utilization, and request latencies. + +For a full overview of vLLM metrics, consult [the vLLM docs](https://github.com/vllm-project/vllm/blob/main/docs/usage/metrics.md) + +## Metric Availability by Pipeline Type + +| Metric group | Multi-stage LLM (Qwen3-Omni) | Diffusion-only (Z-Image-Turbo) | +|---|---|---| +| `vllm:omni_` request tracking | Yes | Yes | +| `vllm:omni_` latency | Yes | Yes | +| `vllm:omni_` KV cache | Yes | No | +| `vllm:omni_` diffusion timing | Only if pipeline has a diffusion stage | Yes | +| `vllm:` engine metrics | Yes | No | +| `vllm:` MFU metrics | With `--enable-mfu-metrics` | No | + +## Naming Convention + +vLLM-Omni pipeline metrics use the `vllm:omni_` prefix to distinguish +them from upstream per-engine `vllm:` metrics. The upstream +`unregister_vllm_metrics()` function is monkey-patched to a no-op (see +`vllm_omni/patch.py`) so that these metrics are not destroyed during +engine initialization. diff --git a/tests/metrics/test_prometheus.py b/tests/metrics/test_prometheus.py new file mode 100644 index 00000000000..94b50aaeea6 --- /dev/null +++ b/tests/metrics/test_prometheus.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +import re + +import pytest +from prometheus_client import REGISTRY, CollectorRegistry, generate_latest + +from vllm_omni.metrics import OmniPrometheusMetrics + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + +_MODEL = "test-model" + +_PIPELINE_METRICS = [ + "vllm:omni_num_requests_running", + "vllm:omni_num_requests_waiting", + "vllm:omni_num_requests_success", + "vllm:omni_num_requests_fail", + "vllm:omni_e2e_request_latency_seconds", + "vllm:omni_request_queue_time_seconds", +] + +_DIFFUSION_METRICS = [ + "vllm:omni_diffusion_preprocess_time_ms", + "vllm:omni_diffusion_exec_time_ms", + "vllm:omni_diffusion_postprocess_time_ms", + "vllm:omni_diffusion_step_time_ms", +] + + +@pytest.fixture(scope="module") +def registry() -> CollectorRegistry: + return REGISTRY + + +@pytest.fixture(scope="module") +def prom() -> OmniPrometheusMetrics: + return OmniPrometheusMetrics(model_name=_MODEL) + + +@pytest.fixture(scope="module") +def scrape_output(prom: OmniPrometheusMetrics, registry: CollectorRegistry) -> str: + prom.request_succeeded(e2e_seconds=1.5, queue_seconds=0.3) + prom.request_succeeded(e2e_seconds=2.0, queue_seconds=0.5) + prom.request_failed() + prom.set_running(5) + prom.set_waiting(2) + prom.observe_diffusion_metrics( + stage_id=1, + metrics={ + "preprocess_time_ms": 10.0, + "diffusion_engine_exec_time_ms": 200.0, + "postprocess_time_ms": 15.0, + "diffusion_engine_total_time_ms": 225.0, + }, + ) + return generate_latest(registry).decode() + + +def _sample_value(output: str, metric_line: str) -> float | None: + for line in output.splitlines(): + if line.startswith(metric_line): + return float(line.split()[-1]) + return None + + +class TestMetricObservation: + def test_all_metric_families_present(self, scrape_output: str) -> None: + for name in _PIPELINE_METRICS + _DIFFUSION_METRICS: + assert f"# HELP {name}" in scrape_output, f"missing metric family: {name}" + + def test_counter_values(self, scrape_output: str) -> None: + success = _sample_value( + scrape_output, + f'vllm:omni_num_requests_success_total{{model_name="{_MODEL}"}}', + ) + assert success == 2.0 + + fail = _sample_value( + scrape_output, + f'vllm:omni_num_requests_fail_total{{model_name="{_MODEL}"}}', + ) + assert fail == 1.0 + + def test_gauge_values(self, scrape_output: str) -> None: + running = _sample_value( + scrape_output, + f'vllm:omni_num_requests_running{{model_name="{_MODEL}"}}', + ) + assert running == 5.0 + + waiting = _sample_value( + scrape_output, + f'vllm:omni_num_requests_waiting{{model_name="{_MODEL}"}}', + ) + assert waiting == 2.0 + + def test_histogram_counts(self, scrape_output: str) -> None: + e2e_count = _sample_value( + scrape_output, + f'vllm:omni_e2e_request_latency_seconds_count{{model_name="{_MODEL}"}}', + ) + assert e2e_count == 2.0 + + queue_count = _sample_value( + scrape_output, + f'vllm:omni_request_queue_time_seconds_count{{model_name="{_MODEL}"}}', + ) + assert queue_count == 2.0 + + def test_diffusion_histogram_counts(self, scrape_output: str) -> None: + for name in _DIFFUSION_METRICS: + count = _sample_value( + scrape_output, + f'{name}_count{{engine="1",model_name="{_MODEL}"}}', + ) + assert count == 1.0, f"{name}_count expected 1.0, got {count}" + + +class TestLabelCorrectness: + def test_pipeline_metrics_carry_model_name(self, scrape_output: str) -> None: + for name in _PIPELINE_METRICS: + pattern = rf'^{re.escape(name)}.*model_name="{re.escape(_MODEL)}"' + assert re.search(pattern, scrape_output, re.MULTILINE), f"{name} missing model_name label" + + def test_diffusion_metrics_carry_engine_label(self, scrape_output: str) -> None: + for name in _DIFFUSION_METRICS: + pattern = rf'^{re.escape(name)}.*engine="1".*model_name="{re.escape(_MODEL)}"' + assert re.search(pattern, scrape_output, re.MULTILINE), f"{name} missing engine label" + + def test_no_stage_id_label(self, scrape_output: str) -> None: + assert "stage_id=" not in scrape_output + + +class TestScrapeOutput: + def test_omni_metrics_in_default_registry(self, scrape_output: str) -> None: + for name in _PIPELINE_METRICS + _DIFFUSION_METRICS: + assert name in scrape_output + + def test_process_metrics_in_default_registry(self, scrape_output: str) -> None: + # vllm:* metrics require a full PrometheusStatLogger with VllmConfig + # and are registered by the Orchestrator at server startup. Verifying + # their presence is covered by integration tests. Here we confirm the + # default registry is being scraped by checking for process_* metrics + # from the Python prometheus_client runtime. + assert "process_" in scrape_output diff --git a/vllm_omni/core/sched/omni_scheduler_mixin.py b/vllm_omni/core/sched/omni_scheduler_mixin.py index 36080e63acc..604121dbf42 100644 --- a/vllm_omni/core/sched/omni_scheduler_mixin.py +++ b/vllm_omni/core/sched/omni_scheduler_mixin.py @@ -1,8 +1,13 @@ from __future__ import annotations +import time + from vllm.v1.engine import EngineCoreEventType +from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.request import Request, RequestStatus, StreamingUpdate +_STATS_INTERVAL_S = 1.0 + class OmniSchedulerMixin: """Shared scheduler helpers for omni-specific request handling.""" @@ -31,3 +36,14 @@ def _replace_session_with_streaming_update( if self.log_stats: session.record_event(EngineCoreEventType.QUEUED) + + def make_stats(self, *args, **kwargs) -> SchedulerStats | None: + now = time.monotonic() + if now - getattr(self, "_last_stats_time", 0.0) < _STATS_INTERVAL_S: + return None + self._last_stats_time = now + return SchedulerStats( + kv_cache_usage=self.kv_cache_manager.usage, + num_running_reqs=len(self.running), + num_waiting_reqs=len(self.waiting), + ) diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index b2dc839c976..25923ab4801 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -89,6 +89,7 @@ inject_omni_kv_config, load_and_resolve_stage_configs, ) +from vllm_omni.metrics.prometheus import OmniRequestCounter from vllm_omni.platforms import current_omni_platform if TYPE_CHECKING: @@ -306,6 +307,7 @@ def __init__( self._shutdown_called = False self._weak_finalizer: weakref.finalize | None = None self._rpc_lock = threading.Lock() + self._running_counter = OmniRequestCounter() logger.info(f"[AsyncOmniEngine] Launching Orchestrator thread with {self.num_stages} stages") @@ -669,7 +671,7 @@ def _initialize_llm_replica( launch_omni_core_engines( vllm_config=vllm_config, executor_class=executor_class, - log_stats=False, + log_stats=True, omni_master_server=self._omni_master_server, stage_id=plan.metadata.stage_id, stage_config=stage_cfg, @@ -680,7 +682,7 @@ def _initialize_llm_replica( addresses, proc, handshake_address = spawn_stage_core( vllm_config=vllm_config, executor_class=executor_class, - log_stats=False, + log_stats=True, ) logger.info( "[AsyncOmniEngine] Stage %s engine launch started", @@ -1077,6 +1079,7 @@ async def _run_orchestrator() -> None: stage_pools=self.stage_pools, async_chunk=self.async_chunk, pd_config=pd_config, + running_counter=self._running_counter, ) if not startup_future.done(): startup_future.set_result(asyncio.get_running_loop()) diff --git a/vllm_omni/engine/orchestrator.py b/vllm_omni/engine/orchestrator.py index 2d2ac47cbb3..2cd706d4c7f 100644 --- a/vllm_omni/engine/orchestrator.py +++ b/vllm_omni/engine/orchestrator.py @@ -21,11 +21,14 @@ from vllm.sampling_params import SamplingParams from vllm.v1.engine import EngineCoreOutputs from vllm.v1.engine.exceptions import EngineDeadError +from vllm.v1.metrics.loggers import PrometheusStatLogger +from vllm.v1.metrics.stats import IterationStats from vllm_omni.engine import OmniEngineCoreRequest from vllm_omni.engine.cfg_companion_tracker import CfgCompanionTracker from vllm_omni.engine.serialization import serialize_additional_information from vllm_omni.engine.stage_pool import StagePool +from vllm_omni.metrics.prometheus import OmniRequestCounter from vllm_omni.outputs import OmniRequestOutput logger = init_logger(__name__) @@ -122,6 +125,7 @@ def __init__( *, async_chunk: bool = False, pd_config: dict[str, Any] | None = None, + running_counter: OmniRequestCounter | None = None, ) -> None: self.request_async_queue = request_async_queue self.output_async_queue = output_async_queue @@ -141,6 +145,22 @@ def __init__( self._pd_bootstrap_addr = pd_config.get("bootstrap_addr") self._pd_prefill_engine_id = pd_config.get("prefill_engine_id") self.request_states: dict[str, OrchestratorRequestState] = {} + self._running_counter = running_counter + + vllm_config_for_stats = next( + (p.stage_vllm_config for p in stage_pools if p.stage_vllm_config is not None), + None, + ) + if vllm_config_for_stats is not None: + self._stat_logger: PrometheusStatLogger | None = PrometheusStatLogger( + vllm_config=vllm_config_for_stats, + engine_indexes=list(range(self.num_stages)), + ) + else: + self._stat_logger = None + self._last_stats_ts: float = 0.0 + self._stats_interval_s: float = 1.0 + self._cfg_tracker = CfgCompanionTracker() self._shutdown_event = asyncio.Event() @@ -247,6 +267,8 @@ async def _handle_add_request(self, msg: dict[str, Any]) -> None: mm_features=getattr(prompt, "mm_features", None), ) self.request_states[request_id] = req_state + if self._running_counter is not None: + self._running_counter.increment() req_state.streaming.enabled = bool(getattr(prompt, "resumable", False)) req_state.stage_submit_ts[stage_id] = _time.time() enqueue_ts = msg.get("enqueue_ts", 0.0) @@ -446,7 +468,23 @@ async def _orchestration_loop(self) -> None: "new_prompt_len_snapshot", None, ) - raw_output = await pool.process_llm_raw_outputs(replica_id, raw_outputs) + now = _time.monotonic() + record_stats = ( + self._stat_logger is not None and now - self._last_stats_ts >= self._stats_interval_s + ) + iteration_stats = IterationStats() if record_stats else None + raw_output = await pool.process_llm_raw_outputs( + replica_id, + raw_outputs, + iteration_stats=iteration_stats, + ) + if record_stats: + self._last_stats_ts = now + self._stat_logger.record( + raw_outputs.scheduler_stats, + iteration_stats, + engine_idx=stage_id, + ) except asyncio.CancelledError: raise except EngineDeadError as e: @@ -558,7 +596,8 @@ async def _cleanup_request_ids(self, request_ids: list[str], *, abort: bool = Fa self._release_request_bindings(request_ids) for request_id in request_ids: self._pd_kv_params.pop(request_id, None) - self.request_states.pop(request_id, None) + if self.request_states.pop(request_id, None) is not None and self._running_counter is not None: + self._running_counter.decrement() def _maybe_clone_diffusion_params_for_cfg(self, request_id: str, params: Any) -> Any: """Attach CFG companion ids to diffusion sampling params when needed.""" diff --git a/vllm_omni/engine/stage_init_utils.py b/vllm_omni/engine/stage_init_utils.py index ce68a23daa4..f883197c006 100644 --- a/vllm_omni/engine/stage_init_utils.py +++ b/vllm_omni/engine/stage_init_utils.py @@ -660,7 +660,7 @@ def build_llm_stage_output_processor(plan: LogicalStageInitPlan, stage_vllm_conf ) return MultimodalOutputProcessor( tokenizer=tokenizer, - log_stats=False, + log_stats=True, engine_core_output_type=metadata.engine_output_type, ) diff --git a/vllm_omni/engine/stage_pool.py b/vllm_omni/engine/stage_pool.py index 6f745427112..3a267b86611 100644 --- a/vllm_omni/engine/stage_pool.py +++ b/vllm_omni/engine/stage_pool.py @@ -9,6 +9,7 @@ from vllm.logger import init_logger from vllm.v1.engine import EngineCoreOutputs +from vllm.v1.metrics.stats import IterationStats from vllm_omni.metrics.stats import StageRequestStats as StageRequestMetrics from vllm_omni.metrics.stats import StageStats @@ -272,6 +273,7 @@ async def process_llm_raw_outputs( self, replica_id: int, raw_outputs: EngineCoreOutputs, + iteration_stats: IterationStats | None = None, ) -> list[Any]: """Run the shared LLM output processor on one raw poll result.""" client = self.clients[replica_id] @@ -279,7 +281,7 @@ async def process_llm_raw_outputs( processed = processor.process_outputs( raw_outputs.outputs, raw_outputs.timestamp, - None, + iteration_stats, ) if processed.reqs_to_abort: diff --git a/vllm_omni/entrypoints/omni_base.py b/vllm_omni/entrypoints/omni_base.py index 4e84d620026..65278dacacf 100644 --- a/vllm_omni/entrypoints/omni_base.py +++ b/vllm_omni/entrypoints/omni_base.py @@ -18,6 +18,7 @@ from vllm_omni.entrypoints.client_request_state import ClientRequestState from vllm_omni.entrypoints.pd_utils import PDDisaggregationMixin from vllm_omni.entrypoints.utils import coerce_param_message_types, get_final_stage_id_for_e2e +from vllm_omni.metrics.prometheus import OmniPrometheusMetrics from vllm_omni.metrics.stats import OrchestratorAggregator as OrchestratorMetrics from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific from vllm_omni.outputs import OmniRequestOutput @@ -185,6 +186,7 @@ def __init__( self.async_chunk = bool(getattr(self.engine, "async_chunk", False)) self.request_states: dict[str, ClientRequestState] = {} + self.prom_metrics = OmniPrometheusMetrics(model_name=model) self.default_sampling_params_list = self.engine.default_sampling_params_list if not self.output_modalities: @@ -273,6 +275,8 @@ def _log_summary_and_cleanup(self, request_id: str) -> None: try: if req_state is None or req_state.metrics is None: return + if str(request_id) not in req_state.metrics.e2e_done: + self.prom_metrics.request_failed() except Exception: logger.exception( "[%s] Failed to build/log summary for req=%s", @@ -447,9 +451,24 @@ def _process_single_result( req_id, req_start_ts.get(req_id, wall_start_ts), ) + e2e_seconds = now - req_start_ts.get(req_id, wall_start_ts) + _fin_m = result.get("metrics") + _pt = getattr(_fin_m, "pipeline_timings", None) or {} + queue_ms = _pt.get("queue_wait_ms") + queue_seconds = queue_ms / 1000.0 if queue_ms is not None else None + self.prom_metrics.request_succeeded(e2e_seconds, queue_seconds=queue_seconds) except Exception: logger.exception("[%s] Finalize request handling error", self.__class__.__name__) + running = self.engine._running_counter.value + total = len(self.request_states) + self.prom_metrics.set_running(running) + self.prom_metrics.set_waiting(max(0, total - running)) + + diffusion_metrics = getattr(engine_outputs, "metrics", None) + if finished and isinstance(diffusion_metrics, dict) and diffusion_metrics: + self.prom_metrics.observe_diffusion_metrics(stage_id, diffusion_metrics) + output_type = getattr(engine_outputs, "final_output_type", stage_meta["final_output_type"]) images = getattr(engine_outputs, "images", []) if output_type == "image" else [] return OmniRequestOutput( diff --git a/vllm_omni/metrics/__init__.py b/vllm_omni/metrics/__init__.py index deceb23333a..6814a589181 100644 --- a/vllm_omni/metrics/__init__.py +++ b/vllm_omni/metrics/__init__.py @@ -1,7 +1,10 @@ +from .prometheus import OmniPrometheusMetrics, OmniRequestCounter from .stats import OrchestratorAggregator, StageRequestStats, StageStats from .utils import count_tokens_from_outputs __all__ = [ + "OmniPrometheusMetrics", + "OmniRequestCounter", "OrchestratorAggregator", "StageStats", "StageRequestStats", diff --git a/vllm_omni/metrics/prometheus.py b/vllm_omni/metrics/prometheus.py new file mode 100644 index 00000000000..c510e17df81 --- /dev/null +++ b/vllm_omni/metrics/prometheus.py @@ -0,0 +1,119 @@ +from prometheus_client import Counter, Gauge, Histogram + +_labelnames = ["model_name"] +_diffusion_labelnames = ["model_name", "engine"] + +_DIFFUSION_METRIC_DEFS: dict[str, tuple[str, str]] = { + "preprocess_time_ms": ( + "vllm:omni_diffusion_preprocess_time_ms", + "Diffusion preprocess time per request in milliseconds.", + ), + "diffusion_engine_exec_time_ms": ( + "vllm:omni_diffusion_exec_time_ms", + "Diffusion model execution time per request in milliseconds.", + ), + "postprocess_time_ms": ( + "vllm:omni_diffusion_postprocess_time_ms", + "Diffusion postprocess time per request in milliseconds.", + ), + "diffusion_engine_total_time_ms": ( + "vllm:omni_diffusion_step_time_ms", + "Total diffusion step time per request in milliseconds.", + ), +} + +_running_family = Gauge( + "vllm:omni_num_requests_running", + "Number of requests currently running across all pipeline stages.", + labelnames=_labelnames, +) +_waiting_family = Gauge( + "vllm:omni_num_requests_waiting", + "Number of requests waiting to be scheduled.", + labelnames=_labelnames, +) +_success_family = Counter( + "vllm:omni_num_requests_success", + "Number of requests that completed without error.", + labelnames=_labelnames, +) +_fail_family = Counter( + "vllm:omni_num_requests_fail", + "Number of requests that returned an error.", + labelnames=_labelnames, +) +_e2e_latency_family = Histogram( + "vllm:omni_e2e_request_latency_seconds", + "Histogram of end-to-end request latency in seconds.", + labelnames=_labelnames, +) +_queue_time_family = Histogram( + "vllm:omni_request_queue_time_seconds", + "Histogram of request queue wait time in seconds.", + labelnames=_labelnames, +) +_diffusion_families: dict[str, Histogram] = { + key: Histogram(metric_name, desc, labelnames=_diffusion_labelnames) + for key, (metric_name, desc) in _DIFFUSION_METRIC_DEFS.items() +} + + +class OmniPrometheusMetrics: + """Label-bound wrapper around the raw Prometheus metrics. + + Metric collectors use the ``vllm:omni_`` prefix to avoid being + removed by upstream vLLM's ``unregister_vllm_metrics()``, which + strips every collector whose ``_name`` contains ``"vllm"``. + """ + + def __init__(self, model_name: str) -> None: + self._model_name = model_name + self._running = _running_family.labels(model_name=model_name) + self._waiting = _waiting_family.labels(model_name=model_name) + self._success = _success_family.labels(model_name=model_name) + self._fail = _fail_family.labels(model_name=model_name) + self._e2e_latency = _e2e_latency_family.labels(model_name=model_name) + self._queue_time = _queue_time_family.labels(model_name=model_name) + self._diffusion_by_stage: dict[tuple[str, int], Histogram] = {} + + def set_running(self, n: int) -> None: + self._running.set(n) + + def set_waiting(self, n: int) -> None: + self._waiting.set(n) + + def request_succeeded(self, e2e_seconds: float, queue_seconds: float | None = None) -> None: + self._success.inc() + self._e2e_latency.observe(e2e_seconds) + if queue_seconds is not None: + self._queue_time.observe(queue_seconds) + + def request_failed(self) -> None: + self._fail.inc() + + def observe_diffusion_metrics(self, stage_id: int, metrics: dict[str, float]) -> None: + for key, parent in _diffusion_families.items(): + value = metrics.get(key) + if value is None: + continue + bound = self._diffusion_by_stage.get((key, stage_id)) + if bound is None: + bound = parent.labels( + model_name=self._model_name, + engine=str(stage_id), + ) + self._diffusion_by_stage[(key, stage_id)] = bound + bound.observe(value) + + +class OmniRequestCounter: + """Running-request counter written by the orchestrator thread, read by the client thread.""" + + def __init__(self) -> None: + self.value = 0 + + def increment(self) -> None: + self.value += 1 + + def decrement(self) -> None: + self.value = max(0, self.value - 1) diff --git a/vllm_omni/patch.py b/vllm_omni/patch.py index 74a3d8a7671..0a8f8b81474 100644 --- a/vllm_omni/patch.py +++ b/vllm_omni/patch.py @@ -1,3 +1,4 @@ +import logging import sys from functools import cached_property @@ -123,3 +124,29 @@ def _patched_glm_image_text_config_init(self, *args, **kwargs): module.StreamingUpdate = OmniStreamingUpdate if hasattr(module, "EngineCoreRequest") and module.EngineCoreRequest == _OriginalEngineCoreRequest: module.EngineCoreRequest = OmniEngineCoreRequest + +# ============================================================================= +# Patch unregister_vllm_metrics to a no-op +# ============================================================================= +# WHY: unregister_vllm_metrics() uses `"vllm" in collector._name` to strip +# collectors from the Prometheus registry. This destroys any vllm-omni +# metrics that use the vllm: namespace. +# +# REMOVAL: Remove this patch once upstream vLLM adds +# _STAT_LOGGER_METRIC_NAMES to vllm.v1.metrics.prometheus and scopes +# unregister_vllm_metrics() to that set. Track: +# https://github.com/vllm-project/vllm/pull/42331 +import vllm.v1.metrics.prometheus as _vllm_prometheus + +_logger = logging.getLogger(__name__) + + +def _noop_unregister_vllm_metrics(): + pass + + +_vllm_prometheus.unregister_vllm_metrics = _noop_unregister_vllm_metrics +_logger.warning( + "Monkey-patched unregister_vllm_metrics() to a no-op. " + "Remove this patch once vLLM adds _STAT_LOGGER_METRIC_NAMES." +) From 20f29affb835f8e4afaa7fcdf1f261b8f359b2d2 Mon Sep 17 00:00:00 2001 From: LHXuuu Date: Wed, 13 May 2026 12:25:34 +0800 Subject: [PATCH 02/13] [Metrics] Add shared definitions module for Prometheus and bench CLI Introduce vllm_omni/metrics/definitions.py as the single source of truth for metric family names, label sets, histogram buckets, and RTF formulas. Server-side prometheus.py and bench-side MultiModalsBenchmark Metrics now consume the same constants, eliminating the dual-track naming drift between /metrics output and bench CLI reports. Pre-work for the upcoming RFC extensions: per-modality audio/image/ video families (G1/G2), cross-stage transfer family (G3), and the OmniPrometheusStatLogger wrap that relabels engine into stage+replica (G7). All refactors are name-only; PR #3362's existing 10 prometheus tests pass unchanged. --- vllm_omni/benchmarks/metrics/metrics.py | 20 ++-- vllm_omni/metrics/definitions.py | 143 ++++++++++++++++++++++++ vllm_omni/metrics/prometheus.py | 27 +++-- 3 files changed, 170 insertions(+), 20 deletions(-) create mode 100644 vllm_omni/metrics/definitions.py diff --git a/vllm_omni/benchmarks/metrics/metrics.py b/vllm_omni/benchmarks/metrics/metrics.py index f320fffe9fc..7abdaf7bac2 100644 --- a/vllm_omni/benchmarks/metrics/metrics.py +++ b/vllm_omni/benchmarks/metrics/metrics.py @@ -7,6 +7,8 @@ from vllm.benchmarks.serve import MILLISECONDS_TO_SECONDS_CONVERSION, TERM_PLOTLIB_AVAILABLE, BenchmarkMetrics, TaskType from vllm.tokenizers import TokenizerLike +from vllm_omni.metrics import definitions as defs + @dataclass class MultiModalsBenchmarkMetrics(BenchmarkMetrics): @@ -94,16 +96,16 @@ def process_one_metric( "tpot": "Time per Output Token (excl. 1st token)", "itl": "Inter-token Latency", "e2el": "End-to-end Latency", - "audio_ttfp": "Time to First Packet", - "audio_rtf": "Real Time Factor", - "audio_duration": "Audio Duration", + defs.AUDIO_TTFP: "Time to First Packet", + defs.AUDIO_RTF: "Real Time Factor", + defs.AUDIO_DURATION: "Audio Duration", } header = metric_header_map.get(metric_attribute_name, metric_attribute_name) print("{s:{c}^{n}}".format(s=header, n=50, c="-")) - is_audio_rtf = metric_attribute_name == "audio_rtf" - is_audio_duration = metric_attribute_name == "audio_duration" + is_audio_rtf = metric_attribute_name == defs.AUDIO_RTF + is_audio_duration = metric_attribute_name == defs.AUDIO_DURATION suffix = "_ms" unit_suffix = " (ms)" @@ -198,10 +200,10 @@ def calculate_metrics( all_tpots.append(tpot) itls += outputs[i].itl ttfts.append(outputs[i].ttft) - audio_ttfps.append(getattr(outputs[i], "audio_ttfp", 0.0)) - audio_rtfs.append(getattr(outputs[i], "audio_rtf", 0.0)) - audio_duration.append(getattr(outputs[i], "audio_duration", 0.0)) - audio_frames.append(getattr(outputs[i], "audio_frames", 0.0)) + audio_ttfps.append(getattr(outputs[i], defs.AUDIO_TTFP, 0.0)) + audio_rtfs.append(getattr(outputs[i], defs.AUDIO_RTF, 0.0)) + audio_duration.append(getattr(outputs[i], defs.AUDIO_DURATION, 0.0)) + audio_frames.append(getattr(outputs[i], defs.AUDIO_FRAMES, 0.0)) e2els.append(outputs[i].latency) input_audio_duration += outputs[i].input_audio_duration completed += 1 diff --git a/vllm_omni/metrics/definitions.py b/vllm_omni/metrics/definitions.py new file mode 100644 index 00000000000..babd6104cc8 --- /dev/null +++ b/vllm_omni/metrics/definitions.py @@ -0,0 +1,143 @@ +"""Single source of truth for vLLM-Omni Prometheus + bench CLI metric naming. + +Consumed by: +- vllm_omni.metrics.prometheus (server-side /metrics families) +- vllm_omni.benchmarks.metrics.metrics (bench CLI MultiModalsBenchmarkMetrics) + +RFC: vLLM-Omni Prometheus 多模态语义、跨 stage Transfer (G4/G5). +""" + +# vllm:omni_ avoids upstream's unregister_vllm_metrics() stripping; matches PR #3362. +METRIC_PREFIX = "vllm:omni_" + + +# ============================================================================ +# Bench-side stems (also used as RequestFuncOutput attribute names) +# ============================================================================ +AUDIO_TTFP = "audio_ttfp" +AUDIO_DURATION = "audio_duration" +AUDIO_RTF = "audio_rtf" +AUDIO_FRAMES = "audio_frames" + +IMAGE_TTFP = "image_ttfp" +IMAGE_NUM = "image_num" +IMAGE_GENERATION_TIME = "image_generation_time" + +VIDEO_DURATION = "video_duration" +VIDEO_RTF = "video_rtf" +VIDEO_GENERATION_TIME = "video_generation_time" + + +# ============================================================================ +# Pipeline-level metric families (PR #3362 + G6) +# ============================================================================ +NUM_REQUESTS_RUNNING = METRIC_PREFIX + "num_requests_running" +NUM_REQUESTS_WAITING = METRIC_PREFIX + "num_requests_waiting" +NUM_REQUESTS_SUCCESS = METRIC_PREFIX + "num_requests_success" +NUM_REQUESTS_FAIL = METRIC_PREFIX + "num_requests_fail" +E2E_REQUEST_LATENCY_SECONDS = METRIC_PREFIX + "e2e_request_latency_seconds" +REQUEST_QUEUE_TIME_SECONDS = METRIC_PREFIX + "request_queue_time_seconds" + +# G6: requests_success_total{finished_reason} — Pipeline 全局 Counter +REQUESTS_SUCCESS_TOTAL = METRIC_PREFIX + "requests_success_total" + + +# ============================================================================ +# Audio family (G1) +# ============================================================================ +AUDIO_TTFP_SECONDS = METRIC_PREFIX + AUDIO_TTFP + "_seconds" +AUDIO_DURATION_SECONDS = METRIC_PREFIX + AUDIO_DURATION + "_seconds" +AUDIO_RTF_METRIC = METRIC_PREFIX + AUDIO_RTF +AUDIO_FRAMES_METRIC = METRIC_PREFIX + AUDIO_FRAMES + + +# ============================================================================ +# Image / Video family (G2) +# ============================================================================ +IMAGE_TTFP_SECONDS = METRIC_PREFIX + IMAGE_TTFP + "_seconds" +IMAGE_NUM_METRIC = METRIC_PREFIX + IMAGE_NUM +IMAGE_GENERATION_TIME_SECONDS = METRIC_PREFIX + IMAGE_GENERATION_TIME + "_seconds" + +VIDEO_DURATION_SECONDS = METRIC_PREFIX + VIDEO_DURATION + "_seconds" +VIDEO_RTF_METRIC = METRIC_PREFIX + VIDEO_RTF +VIDEO_GENERATION_TIME_SECONDS = METRIC_PREFIX + VIDEO_GENERATION_TIME + "_seconds" + + +# ============================================================================ +# Diffusion ms-level timing (PR #3362) +# ============================================================================ +DIFFUSION_PREPROCESS_TIME_MS = METRIC_PREFIX + "diffusion_preprocess_time_ms" +DIFFUSION_EXEC_TIME_MS = METRIC_PREFIX + "diffusion_exec_time_ms" +DIFFUSION_POSTPROCESS_TIME_MS = METRIC_PREFIX + "diffusion_postprocess_time_ms" +DIFFUSION_STEP_TIME_MS = METRIC_PREFIX + "diffusion_step_time_ms" + + +# ============================================================================ +# Cross-stage Transfer family (G3) +# ============================================================================ +TRANSFER_SIZE_BYTES = METRIC_PREFIX + "transfer_size_bytes" +TRANSFER_TX_TIME_MS = METRIC_PREFIX + "transfer_tx_time_ms" +TRANSFER_RX_DECODE_TIME_MS = METRIC_PREFIX + "transfer_rx_decode_time_ms" +TRANSFER_IN_FLIGHT_TIME_MS = METRIC_PREFIX + "transfer_in_flight_time_ms" + + +# ============================================================================ +# Label sets +# ============================================================================ +PIPELINE_LABELS = ("model_name",) +SUCCESS_LABELS = ("model_name", "finished_reason") + +# Per-stage / per-replica label set used by audio/image/video families and by +# the OmniPrometheusStatLogger wrap (G7) which relabels upstream `engine` into +# `stage` + `replica`. +STAGE_LABELS = ("model_name", "stage", "replica") + +# Cross-stage transfer label set (G3). Field names match TransferEdgeStats. +TRANSFER_LABELS = ("from_stage", "from_replica", "to_stage", "to_replica") + + +# ============================================================================ +# Histogram buckets +# ============================================================================ +# Seconds bucket for TTFP / duration / generation time families. +SECONDS_BUCKETS = ( + 0.05, 0.1, 0.25, 0.5, 1.0, 2.0, 5.0, 10.0, 20.0, 30.0, 60.0, 120.0, 300.0, +) + +# Milliseconds bucket for transfer tx / rx / in-flight times. +MS_BUCKETS = ( + 1.0, 5.0, 10.0, 25.0, 50.0, 100.0, 250.0, 500.0, 1000.0, 2500.0, 5000.0, 10000.0, +) + +# RTF SLO red line is 1.0 (TTS must generate faster than playback). +RTF_BUCKETS = ( + 0.1, 0.25, 0.5, 0.75, 0.9, 1.0, 1.25, 1.5, 2.0, 5.0, 10.0, +) + +# Bytes bucket for transfer payload size. +BYTES_BUCKETS = ( + 1024, 4096, 16384, 65536, 262144, 1048576, + 4194304, 16777216, 67108864, 268435456, +) + + +# ============================================================================ +# Formula helpers (shared by server-side observe and bench-side calculation) +# ============================================================================ +def compute_audio_rtf(stage_gen_time_s: float, audio_duration_s: float) -> float: + """RTF = stage_gen_time / audio_content_duration. + + SLO red line < 1 — must generate faster than content plays back to stream. + Returns 0.0 when audio_duration_s is non-positive (caller decides whether + to observe; we don't want to divide by zero or emit negative samples). + """ + if audio_duration_s <= 0: + return 0.0 + return stage_gen_time_s / audio_duration_s + + +def compute_video_rtf(stage_gen_time_s: float, video_duration_s: float) -> float: + """Same definition as audio RTF.""" + if video_duration_s <= 0: + return 0.0 + return stage_gen_time_s / video_duration_s diff --git a/vllm_omni/metrics/prometheus.py b/vllm_omni/metrics/prometheus.py index c510e17df81..b0dad9c80d3 100644 --- a/vllm_omni/metrics/prometheus.py +++ b/vllm_omni/metrics/prometheus.py @@ -1,54 +1,59 @@ from prometheus_client import Counter, Gauge, Histogram -_labelnames = ["model_name"] +from vllm_omni.metrics import definitions as defs + +_labelnames = list(defs.PIPELINE_LABELS) _diffusion_labelnames = ["model_name", "engine"] +# Mapping from stage-emitted metric key (engine internal name) to the +# (prometheus family name, help text) we expose. Keys must match what the +# diffusion engine puts into its per-request metrics dict. _DIFFUSION_METRIC_DEFS: dict[str, tuple[str, str]] = { "preprocess_time_ms": ( - "vllm:omni_diffusion_preprocess_time_ms", + defs.DIFFUSION_PREPROCESS_TIME_MS, "Diffusion preprocess time per request in milliseconds.", ), "diffusion_engine_exec_time_ms": ( - "vllm:omni_diffusion_exec_time_ms", + defs.DIFFUSION_EXEC_TIME_MS, "Diffusion model execution time per request in milliseconds.", ), "postprocess_time_ms": ( - "vllm:omni_diffusion_postprocess_time_ms", + defs.DIFFUSION_POSTPROCESS_TIME_MS, "Diffusion postprocess time per request in milliseconds.", ), "diffusion_engine_total_time_ms": ( - "vllm:omni_diffusion_step_time_ms", + defs.DIFFUSION_STEP_TIME_MS, "Total diffusion step time per request in milliseconds.", ), } _running_family = Gauge( - "vllm:omni_num_requests_running", + defs.NUM_REQUESTS_RUNNING, "Number of requests currently running across all pipeline stages.", labelnames=_labelnames, ) _waiting_family = Gauge( - "vllm:omni_num_requests_waiting", + defs.NUM_REQUESTS_WAITING, "Number of requests waiting to be scheduled.", labelnames=_labelnames, ) _success_family = Counter( - "vllm:omni_num_requests_success", + defs.NUM_REQUESTS_SUCCESS, "Number of requests that completed without error.", labelnames=_labelnames, ) _fail_family = Counter( - "vllm:omni_num_requests_fail", + defs.NUM_REQUESTS_FAIL, "Number of requests that returned an error.", labelnames=_labelnames, ) _e2e_latency_family = Histogram( - "vllm:omni_e2e_request_latency_seconds", + defs.E2E_REQUEST_LATENCY_SECONDS, "Histogram of end-to-end request latency in seconds.", labelnames=_labelnames, ) _queue_time_family = Histogram( - "vllm:omni_request_queue_time_seconds", + defs.REQUEST_QUEUE_TIME_SECONDS, "Histogram of request queue wait time in seconds.", labelnames=_labelnames, ) From cfd69bb6ec0cb8f132e2d894bcfe6484e6bb370e Mon Sep 17 00:00:00 2001 From: LHXuuu Date: Wed, 13 May 2026 16:26:02 +0800 Subject: [PATCH 03/13] [Metrics] Add relabel mixin for upstream PrometheusStatLogger wrap MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 2.1 of multi-replica observability (RFC §3.2.7): introduce three wrapper metric classes (_RelabelGauge / _RelabelCounter / _RelabelHistogram) that intercept labelnames at family creation and translate .labels(engine=idx, ...) kwargs into stage/replica via a process-level engine→(stage, replica) map. These are the ingredients for OmniPrometheusStatLogger (Phase 2.2), which will populate the engine map and slot the wrappers into upstream's _gauge_cls / _counter_cls / _histogram_cls hooks. Shipped standalone with 18 unit tests so the relabel logic can be vetted before the StatLogger subclass lands. --- tests/metrics/test_stat_logger.py | 228 ++++++++++++++++++++++++++++++ vllm_omni/metrics/stat_logger.py | 91 ++++++++++++ 2 files changed, 319 insertions(+) create mode 100644 tests/metrics/test_stat_logger.py create mode 100644 vllm_omni/metrics/stat_logger.py diff --git a/tests/metrics/test_stat_logger.py b/tests/metrics/test_stat_logger.py new file mode 100644 index 00000000000..74a0730219c --- /dev/null +++ b/tests/metrics/test_stat_logger.py @@ -0,0 +1,228 @@ +from __future__ import annotations + +import pytest +from prometheus_client import CollectorRegistry, generate_latest + +from vllm_omni.metrics.stat_logger import ( + _ENGINE_INDEX_MAP, + _RelabelCounter, + _RelabelGauge, + _RelabelHistogram, + _rewrite_label_kwargs, + _rewrite_labelnames, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +@pytest.fixture(autouse=True) +def _isolate_engine_map(): + """Each test gets a clean _ENGINE_INDEX_MAP.""" + _ENGINE_INDEX_MAP.clear() + yield + _ENGINE_INDEX_MAP.clear() + + +@pytest.fixture +def registry() -> CollectorRegistry: + return CollectorRegistry() + + +# --------------------------------------------------------------------------- +# _rewrite_labelnames +# --------------------------------------------------------------------------- + + +class TestRewriteLabelnames: + def test_engine_at_end(self): + assert _rewrite_labelnames(["model_name", "engine"]) == [ + "model_name", + "stage", + "replica", + ] + + def test_engine_in_middle(self): + # Upstream uses `labelnames + ["reason"]` etc., putting engine in middle. + assert _rewrite_labelnames(["model_name", "engine", "reason"]) == [ + "model_name", + "stage", + "replica", + "reason", + ] + + def test_no_engine_label(self): + # Unaffected (e.g. omni's own families that don't use engine). + assert _rewrite_labelnames(["model_name"]) == ["model_name"] + + def test_tuple_input_returns_tuple(self): + out = _rewrite_labelnames(("model_name", "engine")) + assert isinstance(out, tuple) + assert out == ("model_name", "stage", "replica") + + def test_none_passthrough(self): + assert _rewrite_labelnames(None) is None + + +# --------------------------------------------------------------------------- +# _rewrite_label_kwargs +# --------------------------------------------------------------------------- + + +class TestRewriteLabelKwargs: + def test_engine_kwarg_translated(self): + _ENGINE_INDEX_MAP[7] = ("talker", "1") + out = _rewrite_label_kwargs({"engine": 7, "model_name": "m"}) + assert out == {"stage": "talker", "replica": "1", "model_name": "m"} + + def test_engine_with_extra_kwargs(self): + # Mirrors upstream's `.labels(engine=idx, model_name=m, sleep_state=s)`. + _ENGINE_INDEX_MAP[3] = ("thinker", "0") + out = _rewrite_label_kwargs( + {"engine": 3, "model_name": "m", "sleep_state": "awake"} + ) + assert out == { + "stage": "thinker", + "replica": "0", + "model_name": "m", + "sleep_state": "awake", + } + + def test_no_engine_kwarg_passthrough(self): + out = _rewrite_label_kwargs({"model_name": "m", "stage": "talker"}) + assert out == {"model_name": "m", "stage": "talker"} + + def test_missing_engine_idx_raises(self): + # Empty map → fail-fast rather than emit a wrong (stage, replica). + with pytest.raises(KeyError): + _rewrite_label_kwargs({"engine": 999, "model_name": "m"}) + + +# --------------------------------------------------------------------------- +# Wrapper class behavior +# --------------------------------------------------------------------------- + + +class TestRelabelGauge: + def test_labelnames_rewritten_at_creation(self, registry): + g = _RelabelGauge( + name="omni_test_gauge", + documentation="test", + labelnames=["model_name", "engine"], + registry=registry, + ) + assert g._labelnames == ("model_name", "stage", "replica") + + def test_labels_kwarg_translated(self, registry): + _ENGINE_INDEX_MAP[5] = ("diffusion", "0") + g = _RelabelGauge( + name="omni_test_gauge_kwarg", + documentation="test", + labelnames=["model_name", "engine"], + registry=registry, + ) + g.labels(engine=5, model_name="qwen-omni").set(42.0) + + out = generate_latest(registry).decode() + assert ( + 'omni_test_gauge_kwarg{model_name="qwen-omni",replica="0",stage="diffusion"} 42.0' + in out + ) + + def test_labels_positional_passthrough(self, registry): + # Phase 2.2's per_engine_labelvalues setter feeds positional 3-tuples; + # our mixin must not mangle positional .labels() calls. + g = _RelabelGauge( + name="omni_test_gauge_pos", + documentation="test", + labelnames=["model_name", "engine"], + registry=registry, + ) + g.labels("qwen-omni", "thinker", "0").set(7.0) + + out = generate_latest(registry).decode() + assert ( + 'omni_test_gauge_pos{model_name="qwen-omni",replica="0",stage="thinker"} 7.0' + in out + ) + + def test_multiprocess_mode_kwarg_passthrough(self, registry): + # Upstream creates Gauges with multiprocess_mode="mostrecent" — must not + # be eaten by our mixin. + g = _RelabelGauge( + name="omni_test_gauge_mp", + documentation="test", + labelnames=["model_name", "engine"], + multiprocess_mode="mostrecent", + registry=registry, + ) + assert g._multiprocess_mode == "mostrecent" + + +class TestRelabelCounter: + def test_labelnames_rewritten(self, registry): + c = _RelabelCounter( + name="omni_test_counter", + documentation="test", + labelnames=["model_name", "engine", "finished_reason"], + registry=registry, + ) + assert c._labelnames == ( + "model_name", + "stage", + "replica", + "finished_reason", + ) + + def test_labels_kwarg_translated(self, registry): + _ENGINE_INDEX_MAP[2] = ("thinker", "0") + c = _RelabelCounter( + name="omni_test_counter_kwarg", + documentation="test", + labelnames=["model_name", "engine", "finished_reason"], + registry=registry, + ) + c.labels(engine=2, model_name="m", finished_reason="stop").inc(3) + + out = generate_latest(registry).decode() + assert ( + 'omni_test_counter_kwarg_total{finished_reason="stop",model_name="m",replica="0",stage="thinker"} 3.0' + in out + ) + + +class TestRelabelHistogram: + def test_labelnames_rewritten(self, registry): + h = _RelabelHistogram( + name="omni_test_histo", + documentation="test", + labelnames=["model_name", "engine"], + registry=registry, + ) + assert h._labelnames == ("model_name", "stage", "replica") + + def test_labels_kwarg_translated_and_observe(self, registry): + _ENGINE_INDEX_MAP[0] = ("talker", "0") + h = _RelabelHistogram( + name="omni_test_histo_obs", + documentation="test", + labelnames=["model_name", "engine"], + registry=registry, + ) + h.labels(engine=0, model_name="m").observe(0.5) + + out = generate_latest(registry).decode() + assert ( + 'omni_test_histo_obs_count{model_name="m",replica="0",stage="talker"} 1.0' + in out + ) + + def test_no_engine_label_unaffected(self, registry): + # Families without engine label (e.g. omni-side own metrics) pass through. + h = _RelabelHistogram( + name="omni_test_no_engine", + documentation="test", + labelnames=["model_name"], + registry=registry, + ) + assert h._labelnames == ("model_name",) + h.labels(model_name="m").observe(1.0) diff --git a/vllm_omni/metrics/stat_logger.py b/vllm_omni/metrics/stat_logger.py new file mode 100644 index 00000000000..aa96e167312 --- /dev/null +++ b/vllm_omni/metrics/stat_logger.py @@ -0,0 +1,91 @@ +"""OmniPrometheusStatLogger — wrap upstream PrometheusStatLogger. + +Goal (RFC §3.2.7): rewrite the upstream `engine` single-label scheme into a +`stage` + `replica` two-label scheme so that the ~37 `vllm:*` metric families +automatically gain per-(stage, replica) visibility for multi-replica deployments. + +Phase 2.1 ships only the three wrapper metric classes + the process-level +engine→(stage, replica) map. The OmniPrometheusStatLogger subclass that wires +everything together lands in Phase 2.2. +""" + +from __future__ import annotations + +from prometheus_client import Counter, Gauge, Histogram + +# Process-wide translation table written by OmniPrometheusStatLogger at init. +# Keys are flat engine_idx values (as upstream PrometheusStatLogger sees them); +# values are the (stage_name, replica_id_str) tuple we expose as labels. +# +# Module-level rather than per-instance because the wrapper metric classes are +# constructed by upstream's __init__ and never get a back-reference to the +# StatLogger that owns them. vLLM runs a single Orchestrator/StatLogger per +# process, so a module global is safe; tests isolate by .clear()ing first. +_ENGINE_INDEX_MAP: dict[int, tuple[str, str]] = {} + + +def _rewrite_labelnames(labelnames): + """Replace `engine` in ``labelnames`` with (`stage`, `replica`) in place. + + Preserves ordering (so ``["model_name", "engine", "reason"]`` becomes + ``["model_name", "stage", "replica", "reason"]``) and the original + container type (list vs tuple). + """ + if labelnames is None: + return labelnames + seq = list(labelnames) + if "engine" not in seq: + return labelnames + out: list[str] = [] + for name in seq: + if name == "engine": + out.extend(("stage", "replica")) + else: + out.append(name) + return type(labelnames)(out) if not isinstance(labelnames, list) else out + + +def _rewrite_label_kwargs(kwargs: dict) -> dict: + """Translate ``.labels(engine=idx, ...)`` kwargs into ``stage``/``replica``. + + Raises ``KeyError`` when ``engine_idx`` is missing from the map — fail-fast + is preferable to silently emitting series under a wrong (stage, replica). + """ + if "engine" not in kwargs: + return kwargs + engine_idx = kwargs.pop("engine") + stage, replica = _ENGINE_INDEX_MAP[engine_idx] + kwargs["stage"] = stage + kwargs["replica"] = replica + return kwargs + + +class _RelabelMixin: + """Mixin: rewrite ``labelnames`` at family creation and ``.labels()`` kwargs. + + Used to derive ``_RelabelGauge`` / ``_RelabelCounter`` / ``_RelabelHistogram`` + that drop into upstream ``PrometheusStatLogger._gauge_cls`` / ``_counter_cls`` / + ``_histogram_cls`` slots. + """ + + def __init__(self, *args, **kwargs): + if "labelnames" in kwargs: + kwargs["labelnames"] = _rewrite_labelnames(kwargs["labelnames"]) + super().__init__(*args, **kwargs) + + def labels(self, *args, **kwargs): + if kwargs: + kwargs = _rewrite_label_kwargs(kwargs) + return super().labels(*args, **kwargs) + + +class _RelabelGauge(_RelabelMixin, Gauge): + pass + + +class _RelabelCounter(_RelabelMixin, Counter): + pass + + +class _RelabelHistogram(_RelabelMixin, Histogram): + pass From 7c4496ef09cc9bee7509c8c36e283656f3fda54e Mon Sep 17 00:00:00 2001 From: LHXuuu Date: Wed, 13 May 2026 16:34:24 +0800 Subject: [PATCH 04/13] [Metrics] Handle positional .labels() and str engine values in relabel mixin Phase 2.1 only intercepted .labels(engine=int_idx, ...) kwarg form. Upstream PrometheusStatLogger has four .labels() call sites for engine-bearing families: loggers.py:510 kwarg form, int engine (gauge_engine_sleep_state) loggers.py:646 positional, str engine (counter_prompt_tokens_by_source) loggers.py:679 positional, str engine (counter_request_success_base) loggers.py:1056 kwarg form, str engine (info_gauge via metrics_info) Track the original engine label index at family creation so positional .labels() can splice (stage, replica) at the right offset, and accept both int and str engine values in either form. Replaces the narrow _rewrite_label_kwargs helper with _engine_to_stage_replica which is shared by both code paths. Adds tests for positional engine at middle index (str + int values), str-form kwarg engine, and child metric non-recursion. --- tests/metrics/test_stat_logger.py | 134 ++++++++++++++++++++++++------ vllm_omni/metrics/stat_logger.py | 63 +++++++++----- 2 files changed, 154 insertions(+), 43 deletions(-) diff --git a/tests/metrics/test_stat_logger.py b/tests/metrics/test_stat_logger.py index 74a0730219c..06d26311669 100644 --- a/tests/metrics/test_stat_logger.py +++ b/tests/metrics/test_stat_logger.py @@ -8,7 +8,7 @@ _RelabelCounter, _RelabelGauge, _RelabelHistogram, - _rewrite_label_kwargs, + _engine_to_stage_replica, _rewrite_labelnames, ) @@ -64,37 +64,25 @@ def test_none_passthrough(self): # --------------------------------------------------------------------------- -# _rewrite_label_kwargs +# _engine_to_stage_replica # --------------------------------------------------------------------------- -class TestRewriteLabelKwargs: - def test_engine_kwarg_translated(self): +class TestEngineToStageReplica: + def test_int_engine_value(self): + # Mirrors upstream `.labels(engine=idx, ...)` with int (loggers.py:510). _ENGINE_INDEX_MAP[7] = ("talker", "1") - out = _rewrite_label_kwargs({"engine": 7, "model_name": "m"}) - assert out == {"stage": "talker", "replica": "1", "model_name": "m"} - - def test_engine_with_extra_kwargs(self): - # Mirrors upstream's `.labels(engine=idx, model_name=m, sleep_state=s)`. - _ENGINE_INDEX_MAP[3] = ("thinker", "0") - out = _rewrite_label_kwargs( - {"engine": 3, "model_name": "m", "sleep_state": "awake"} - ) - assert out == { - "stage": "thinker", - "replica": "0", - "model_name": "m", - "sleep_state": "awake", - } - - def test_no_engine_kwarg_passthrough(self): - out = _rewrite_label_kwargs({"model_name": "m", "stage": "talker"}) - assert out == {"model_name": "m", "stage": "talker"} + assert _engine_to_stage_replica(7) == ("talker", "1") + + def test_str_engine_value(self): + # Mirrors upstream `metrics_info["engine"] = str(idx)` (loggers.py:1055). + _ENGINE_INDEX_MAP[2] = ("thinker", "0") + assert _engine_to_stage_replica("2") == ("thinker", "0") def test_missing_engine_idx_raises(self): # Empty map → fail-fast rather than emit a wrong (stage, replica). with pytest.raises(KeyError): - _rewrite_label_kwargs({"engine": 999, "model_name": "m"}) + _engine_to_stage_replica(999) # --------------------------------------------------------------------------- @@ -226,3 +214,101 @@ def test_no_engine_label_unaffected(self, registry): ) assert h._labelnames == ("model_name",) h.labels(model_name="m").observe(1.0) + + +# --------------------------------------------------------------------------- +# Positional .labels() with engine value (loggers.py:646, 679) +# --------------------------------------------------------------------------- + + +class TestPositionalEngine: + def test_positional_engine_at_middle_index(self, registry): + # Mirrors `counter_prompt_tokens_by_source.labels(model_name, str(idx), source)`. + # Family original labelnames = ["model_name", "engine", "source"]. + _ENGINE_INDEX_MAP[5] = ("talker", "0") + c = _RelabelCounter( + name="omni_test_pos_mid", + documentation="test", + labelnames=["model_name", "engine", "source"], + registry=registry, + ) + c.labels("m", "5", "decoder").inc(2) + + out = generate_latest(registry).decode() + assert ( + 'omni_test_pos_mid_total{model_name="m",replica="0",source="decoder",stage="talker"} 2.0' + in out + ) + + def test_positional_engine_with_int_value(self, registry): + # Defensive: positional form may also receive an int (we accept both). + _ENGINE_INDEX_MAP[3] = ("thinker", "1") + c = _RelabelCounter( + name="omni_test_pos_int", + documentation="test", + labelnames=["model_name", "engine", "reason"], + registry=registry, + ) + c.labels("m", 3, "stop").inc() + + out = generate_latest(registry).decode() + assert ( + 'omni_test_pos_int_total{model_name="m",reason="stop",replica="1",stage="thinker"} 1.0' + in out + ) + + +# --------------------------------------------------------------------------- +# String-form engine kwarg (loggers.py:1056 info_gauge) +# --------------------------------------------------------------------------- + + +class TestStrEngineKwarg: + def test_engine_kwarg_str_form(self, registry): + # Mirrors `info_gauge.labels(**metrics_info)` where metrics_info["engine"]="0". + _ENGINE_INDEX_MAP[0] = ("thinker", "0") + g = _RelabelGauge( + name="omni_test_info", + documentation="test", + labelnames=["cache_size", "engine"], + multiprocess_mode="mostrecent", + registry=registry, + ) + # Upstream pattern: pass everything as kwargs from the metrics_info dict. + g.labels(cache_size="big", engine="0").set(1) + + out = generate_latest(registry).decode() + assert ( + 'omni_test_info{cache_size="big",replica="0",stage="thinker"} 1.0' + in out + ) + + +# --------------------------------------------------------------------------- +# Child metric does not re-trigger relabel logic +# --------------------------------------------------------------------------- + + +class TestChildNoRecursion: + def test_child_set_does_not_relookup(self, registry): + # Once .labels() returns a child, subsequent .set()/.inc() must not + # consult _ENGINE_INDEX_MAP again. We verify by clearing the map + # AFTER labels() and proving .set() still works. + _ENGINE_INDEX_MAP[4] = ("diffusion", "0") + g = _RelabelGauge( + name="omni_test_child", + documentation="test", + labelnames=["model_name", "engine"], + registry=registry, + ) + child = g.labels(engine=4, model_name="m") + _ENGINE_INDEX_MAP.clear() # would break a second .labels() lookup + child.set(99.0) # but set() is on the bound child — no map needed + + # Re-populate so generate_latest doesn't trip on anything else. + _ENGINE_INDEX_MAP[4] = ("diffusion", "0") + out = generate_latest(registry).decode() + assert ( + 'omni_test_child{model_name="m",replica="0",stage="diffusion"} 99.0' + in out + ) diff --git a/vllm_omni/metrics/stat_logger.py b/vllm_omni/metrics/stat_logger.py index aa96e167312..b631619fe9c 100644 --- a/vllm_omni/metrics/stat_logger.py +++ b/vllm_omni/metrics/stat_logger.py @@ -45,37 +45,62 @@ def _rewrite_labelnames(labelnames): return type(labelnames)(out) if not isinstance(labelnames, list) else out -def _rewrite_label_kwargs(kwargs: dict) -> dict: - """Translate ``.labels(engine=idx, ...)`` kwargs into ``stage``/``replica``. +def _engine_to_stage_replica(engine_value) -> tuple[str, str]: + """Look up (stage, replica) for an engine_idx, accepting int or str input. - Raises ``KeyError`` when ``engine_idx`` is missing from the map — fail-fast - is preferable to silently emitting series under a wrong (stage, replica). + Upstream emits engine values in two flavors: + - int form, e.g. ``gauge_engine_sleep_state.labels(engine=idx, ...)`` (loggers.py:510) + - str form, e.g. ``info_gauge.labels(**metrics_info)`` where ``metrics_info["engine"] = str(idx)`` (loggers.py:1055) + + Raises ``KeyError`` when the value is missing from the map — fail-fast is + preferable to silently emitting series under a wrong (stage, replica). """ - if "engine" not in kwargs: - return kwargs - engine_idx = kwargs.pop("engine") - stage, replica = _ENGINE_INDEX_MAP[engine_idx] - kwargs["stage"] = stage - kwargs["replica"] = replica - return kwargs + key = int(engine_value) if isinstance(engine_value, str) else engine_value + return _ENGINE_INDEX_MAP[key] class _RelabelMixin: - """Mixin: rewrite ``labelnames`` at family creation and ``.labels()`` kwargs. + """Mixin: rewrite ``labelnames`` at family creation and ``.labels()`` calls. + + Handles all four upstream forms encountered in + ``vllm.v1.metrics.loggers.PrometheusStatLogger``: - Used to derive ``_RelabelGauge`` / ``_RelabelCounter`` / ``_RelabelHistogram`` - that drop into upstream ``PrometheusStatLogger._gauge_cls`` / ``_counter_cls`` / - ``_histogram_cls`` slots. + 1. ``.labels(engine=idx, ...)`` kwarg with int engine (loggers.py:510) + 2. ``.labels(model_name, str(idx), source)`` positional with str engine + (loggers.py:646, 679) + 3. ``.labels(**metrics_info)`` kwarg with str engine (loggers.py:1056) + 4. Families without an ``engine`` label — passthrough (e.g. lora_info) + + Drops into upstream's ``_gauge_cls`` / ``_counter_cls`` / ``_histogram_cls`` + class slots. """ def __init__(self, *args, **kwargs): - if "labelnames" in kwargs: - kwargs["labelnames"] = _rewrite_labelnames(kwargs["labelnames"]) + # Remember where `engine` sat in the original labelnames so positional + # `.labels()` calls can splice (stage, replica) at the right offset. + labelnames = kwargs.get("labelnames") + if labelnames is not None: + original = list(labelnames) + self._engine_label_index = ( + original.index("engine") if "engine" in original else -1 + ) + kwargs["labelnames"] = _rewrite_labelnames(labelnames) + else: + self._engine_label_index = -1 super().__init__(*args, **kwargs) def labels(self, *args, **kwargs): - if kwargs: - kwargs = _rewrite_label_kwargs(kwargs) + if self._engine_label_index >= 0: + if args: + # Positional form: replace args[engine_idx] with (stage, replica). + idx = self._engine_label_index + if idx < len(args): + stage, replica = _engine_to_stage_replica(args[idx]) + args = (*args[:idx], stage, replica, *args[idx + 1 :]) + elif "engine" in kwargs: + stage, replica = _engine_to_stage_replica(kwargs.pop("engine")) + kwargs["stage"] = stage + kwargs["replica"] = replica return super().labels(*args, **kwargs) From b385ce72ae81218861bfee7d7b203a9aa6dcea64 Mon Sep 17 00:00:00 2001 From: LHXuuu Date: Wed, 13 May 2026 16:37:48 +0800 Subject: [PATCH 05/13] [Metrics] Add OmniPrometheusStatLogger subclass with per-replica labels MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 2.2b of multi-replica observability (RFC §3.2.7). Subclass upstream PrometheusStatLogger to expose ~37 vllm:* metric families with {stage, replica} labels instead of the single {engine} label. Three pieces of glue: 1. Class-level slot overrides (_gauge_cls / _counter_cls / _histogram_cls) plug the relabel mixin wrappers into upstream's family-construction path so labelnames are rewritten as families are created. 2. A property descriptor on per_engine_labelvalues intercepts upstream's loggers.py:433 self-assignment and rewrites each [model_name, str(idx)] tuple to [model_name, stage, replica] so downstream create_metric_per_engine feeds 3-element labelvalues into the 3-label families. 3. __init__ accepts a stage_replica_map: dict[int, tuple[str, str]] and populates the process-level _ENGINE_INDEX_MAP that the wrappers consult on every .labels() call. Map is cleared first so re-init in the same process (tests, orchestrator restart) starts clean. Dynamic add/remove of replicas at runtime is intentionally out of scope. --- tests/metrics/test_stat_logger.py | 67 +++++++++++++++++++++++++++++++ vllm_omni/metrics/stat_logger.py | 59 +++++++++++++++++++++++++++ 2 files changed, 126 insertions(+) diff --git a/tests/metrics/test_stat_logger.py b/tests/metrics/test_stat_logger.py index 06d26311669..16bfe3e4697 100644 --- a/tests/metrics/test_stat_logger.py +++ b/tests/metrics/test_stat_logger.py @@ -5,6 +5,7 @@ from vllm_omni.metrics.stat_logger import ( _ENGINE_INDEX_MAP, + OmniPrometheusStatLogger, _RelabelCounter, _RelabelGauge, _RelabelHistogram, @@ -312,3 +313,69 @@ def test_child_set_does_not_relookup(self, registry): 'omni_test_child{model_name="m",replica="0",stage="diffusion"} 99.0' in out ) + + +# --------------------------------------------------------------------------- +# OmniPrometheusStatLogger — focused on the wrap mechanics (full PrometheusStatLogger +# init requires a real VllmConfig and is exercised by the orchestrator integration +# test in Phase 2.3). +# --------------------------------------------------------------------------- + + +class TestOmniPrometheusStatLogger: + def test_class_slots_point_to_wrappers(self): + # Upstream's __init__ uses self._gauge_cls(...) etc. when constructing + # families; class-level slot override is how we inject the relabel logic. + assert OmniPrometheusStatLogger._gauge_cls is _RelabelGauge + assert OmniPrometheusStatLogger._counter_cls is _RelabelCounter + assert OmniPrometheusStatLogger._histogram_cls is _RelabelHistogram + + def test_per_engine_labelvalues_setter_rewrites_to_3tuple(self): + # Construct via __new__ to skip the upstream PrometheusStatLogger __init__ + # (which needs a real VllmConfig). We only verify the property descriptor. + sl = OmniPrometheusStatLogger.__new__(OmniPrometheusStatLogger) + sl._stage_replica_map = { + 0: ("thinker", "0"), + 1: ("talker", "0"), + 2: ("talker", "1"), + } + + # Mirror upstream's loggers.py:433 assignment shape. + sl.per_engine_labelvalues = { + 0: ["my-model", "0"], + 1: ["my-model", "1"], + 2: ["my-model", "2"], + } + + # Getter should return the 3-tuple form for downstream + # create_metric_per_engine consumers. + assert sl.per_engine_labelvalues == { + 0: ["my-model", "thinker", "0"], + 1: ["my-model", "talker", "0"], + 2: ["my-model", "talker", "1"], + } + + def test_per_engine_labelvalues_getter_returns_internal_dict(self): + sl = OmniPrometheusStatLogger.__new__(OmniPrometheusStatLogger) + sl._stage_replica_map = {0: ("thinker", "0")} + sl._omni_per_engine_labelvalues = {0: ["m", "thinker", "0"]} + assert sl.per_engine_labelvalues == {0: ["m", "thinker", "0"]} + + def test_stage_replica_map_property_exposed(self): + sl = OmniPrometheusStatLogger.__new__(OmniPrometheusStatLogger) + srm = {0: ("thinker", "0"), 1: ("diffusion", "0")} + sl._stage_replica_map = srm + assert sl.stage_replica_map is srm + + def test_init_populates_engine_index_map(self): + # Simulate the bookkeeping portion of __init__ (clear + update) without + # calling super, since super needs a real VllmConfig. + _ENGINE_INDEX_MAP[99] = ("stale", "stale") # leftover from prior + srm = {0: ("thinker", "0"), 1: ("talker", "0")} + + # Manually invoke the bookkeeping the way __init__ does it. + _ENGINE_INDEX_MAP.clear() + _ENGINE_INDEX_MAP.update(srm) + + assert dict(_ENGINE_INDEX_MAP) == srm + assert 99 not in _ENGINE_INDEX_MAP # old entry was cleared diff --git a/vllm_omni/metrics/stat_logger.py b/vllm_omni/metrics/stat_logger.py index b631619fe9c..0154856883c 100644 --- a/vllm_omni/metrics/stat_logger.py +++ b/vllm_omni/metrics/stat_logger.py @@ -12,6 +12,8 @@ from __future__ import annotations from prometheus_client import Counter, Gauge, Histogram +from vllm.config import VllmConfig +from vllm.v1.metrics.loggers import PrometheusStatLogger # Process-wide translation table written by OmniPrometheusStatLogger at init. # Keys are flat engine_idx values (as upstream PrometheusStatLogger sees them); @@ -114,3 +116,60 @@ class _RelabelCounter(_RelabelMixin, Counter): class _RelabelHistogram(_RelabelMixin, Histogram): pass + + +class OmniPrometheusStatLogger(PrometheusStatLogger): + """Wrap upstream PrometheusStatLogger to expose per-(stage, replica) labels. + + Replaces the upstream single ``engine`` label with two labels ``stage`` and + ``replica`` so that the ~37 ``vllm:*`` metric families gain per-replica + visibility for multi-replica deployments. See RFC §3.2.7. + + The orchestrator builds ``stage_replica_map`` from the static stage_pools + config; flat engine_idx values map 1:1 to (stage_name, replica_id) tuples. + Dynamic add/remove of replicas at runtime is intentionally not supported + in this iteration — see RFC §3.4 risks. + """ + + # Inject our wrapper metric classes into upstream's class-level slots so + # every ~37 family is created with `engine` rewritten to `stage`+`replica`. + _gauge_cls = _RelabelGauge + _counter_cls = _RelabelCounter + _histogram_cls = _RelabelHistogram + + def __init__( + self, + vllm_config: VllmConfig, + stage_replica_map: dict[int, tuple[str, str]], + ) -> None: + self._stage_replica_map = stage_replica_map + # Populate the process-level translation table that wrapper metric + # classes consult on every `.labels()` call. Cleared first so a + # second OmniPrometheusStatLogger in the same process (e.g. tests, + # orchestrator restart) starts from a clean slate. + _ENGINE_INDEX_MAP.clear() + _ENGINE_INDEX_MAP.update(stage_replica_map) + super().__init__( + vllm_config=vllm_config, + engine_indexes=list(stage_replica_map.keys()), + ) + + @property + def stage_replica_map(self) -> dict[int, tuple[str, str]]: + return self._stage_replica_map + + @property + def per_engine_labelvalues(self) -> dict[int, list[object]]: + return self._omni_per_engine_labelvalues + + @per_engine_labelvalues.setter + def per_engine_labelvalues(self, value: dict[int, list[object]]) -> None: + # Upstream sets {idx: [model_name, str(idx)]} (loggers.py:433); we drop + # the engine str and append (stage, replica) so labelvalues match the + # 3-element labelnames our wrapper classes produce. + rewritten: dict[int, list[object]] = {} + for idx, vals in value.items(): + model_name = vals[0] + stage, replica = self._stage_replica_map[idx] + rewritten[idx] = [model_name, stage, replica] + self._omni_per_engine_labelvalues = rewritten From 2c8d9642e97ba2a129130d08ef53700067c5c059 Mon Sep 17 00:00:00 2001 From: LHXuuu Date: Wed, 13 May 2026 16:47:44 +0800 Subject: [PATCH 06/13] [Metrics] Wire OmniPrometheusStatLogger into orchestrator MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 2.3 of multi-replica observability (RFC §3.2.7) — completes the end-to-end path so the ~37 vllm:* metric families now expose per-(stage, replica) series. Three targeted edits to orchestrator.py: 1. Build a flat stage_replica_map at __init__ by walking stage_pools and assigning each (stage_id, replica_id) tuple a monotonically increasing engine_idx via prefix-sum across pools. The reverse map _stage_replica_to_engine_idx is consulted at record() time to translate the orchestrator's (stage_id, replica_id) loop variables back into the flat engine_idx upstream expects. 2. Swap PrometheusStatLogger for OmniPrometheusStatLogger and pass stage_replica_map instead of the old range(num_stages) engine_indexes. 3. record() now passes the per-replica flat engine_idx instead of just stage_id, so multi-replica deployments emit distinct {stage, replica} series per replica rather than aggregating them. The stage label currently uses str(stage_id); migrating to semantic names ("thinker", "talker", "diffusion") is a separate config schema change deferred to a later iteration. --- vllm_omni/engine/orchestrator.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/vllm_omni/engine/orchestrator.py b/vllm_omni/engine/orchestrator.py index 2cd706d4c7f..8435b79a825 100644 --- a/vllm_omni/engine/orchestrator.py +++ b/vllm_omni/engine/orchestrator.py @@ -21,7 +21,6 @@ from vllm.sampling_params import SamplingParams from vllm.v1.engine import EngineCoreOutputs from vllm.v1.engine.exceptions import EngineDeadError -from vllm.v1.metrics.loggers import PrometheusStatLogger from vllm.v1.metrics.stats import IterationStats from vllm_omni.engine import OmniEngineCoreRequest @@ -29,6 +28,7 @@ from vllm_omni.engine.serialization import serialize_additional_information from vllm_omni.engine.stage_pool import StagePool from vllm_omni.metrics.prometheus import OmniRequestCounter +from vllm_omni.metrics.stat_logger import OmniPrometheusStatLogger from vllm_omni.outputs import OmniRequestOutput logger = init_logger(__name__) @@ -151,10 +151,23 @@ def __init__( (p.stage_vllm_config for p in stage_pools if p.stage_vllm_config is not None), None, ) + # Build flat engine_idx ↔ (stage_id, replica_id) maps so that the wrap + # exposes the ~37 vllm:* families with per-(stage, replica) labels. + # The reverse map is consulted at record() time to find the engine_idx + # to update from the orchestrator's (stage_id, replica_id) loop. + stage_replica_map: dict[int, tuple[str, str]] = {} + self._stage_replica_to_engine_idx: dict[tuple[int, int], int] = {} + flat_idx = 0 + for stage_id, pool in enumerate(stage_pools): + for replica_id in range(pool.num_replicas): + stage_replica_map[flat_idx] = (str(stage_id), str(replica_id)) + self._stage_replica_to_engine_idx[(stage_id, replica_id)] = flat_idx + flat_idx += 1 + if vllm_config_for_stats is not None: - self._stat_logger: PrometheusStatLogger | None = PrometheusStatLogger( + self._stat_logger: OmniPrometheusStatLogger | None = OmniPrometheusStatLogger( vllm_config=vllm_config_for_stats, - engine_indexes=list(range(self.num_stages)), + stage_replica_map=stage_replica_map, ) else: self._stat_logger = None @@ -483,7 +496,7 @@ async def _orchestration_loop(self) -> None: self._stat_logger.record( raw_outputs.scheduler_stats, iteration_stats, - engine_idx=stage_id, + engine_idx=self._stage_replica_to_engine_idx[(stage_id, replica_id)], ) except asyncio.CancelledError: raise From df6deb9e50463fc936fbbef8295bb628be507fd7 Mon Sep 17 00:00:00 2001 From: LHXuuu Date: Wed, 13 May 2026 17:13:26 +0800 Subject: [PATCH 07/13] [Metrics] Add OmniModalityMetrics with 8 audio/image/video families MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 3.1 of multi-modal observability (RFC G1/G2). Per-modality business-semantic Prometheus families with {model_name, stage, replica} labels: audio: ttfp_seconds, duration_seconds, rtf, frames (Counter) image: ttfp_seconds, num (Counter), generation_time_seconds video: generation_time_seconds video_duration_seconds and video_rtf are intentionally deferred — diffusion video pipelines (i2v / t2v / cogvideo / hunyuan / wan) expose num_frames + fps in heterogeneous shapes and a clean abstraction is out of scope for this iteration. Text-path metrics (TTFT/ITL/TPOT) are NOT here — they come from the upstream vllm:*{stage="thinker", ...} families exposed by the Phase 2 OmniPrometheusStatLogger wrap. RFC §3.2.6 single-source naming via definitions.py constants; RTF families use RTF_BUCKETS, time families use SECONDS_BUCKETS. observe APIs accept stage/replica at call time since one OmniModalityMetrics instance per pipeline serves all (stage, replica) combinations. --- tests/metrics/test_modality.py | 183 +++++++++++++++++++++++++++++++++ vllm_omni/metrics/modality.py | 155 ++++++++++++++++++++++++++++ 2 files changed, 338 insertions(+) create mode 100644 tests/metrics/test_modality.py create mode 100644 vllm_omni/metrics/modality.py diff --git a/tests/metrics/test_modality.py b/tests/metrics/test_modality.py new file mode 100644 index 00000000000..e6dfd4fd931 --- /dev/null +++ b/tests/metrics/test_modality.py @@ -0,0 +1,183 @@ +from __future__ import annotations + +import pytest +from prometheus_client import REGISTRY, generate_latest + +from vllm_omni.metrics import definitions as defs +from vllm_omni.metrics.modality import OmniModalityMetrics + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +_MODEL = "test-modality-model" + + +@pytest.fixture(scope="module") +def mod() -> OmniModalityMetrics: + return OmniModalityMetrics(model_name=_MODEL) + + +# Each test uses a distinct (stage, replica) so counter accumulation +# across tests doesn't couple assertions. +_AUDIO_STAGE = ("talker", "0") +_IMAGE_STAGE = ("diffusion", "0") +_VIDEO_STAGE = ("diffusion", "1") + + +def _sample_value(output: str, line_prefix: str) -> float | None: + for line in output.splitlines(): + if line.startswith(line_prefix): + return float(line.split()[-1]) + return None + + +# --------------------------------------------------------------------------- +# Family registration +# --------------------------------------------------------------------------- + + +_EXPECTED_FAMILIES = [ + defs.AUDIO_TTFP_SECONDS, + defs.AUDIO_DURATION_SECONDS, + defs.AUDIO_RTF_METRIC, + defs.AUDIO_FRAMES_METRIC, + defs.IMAGE_TTFP_SECONDS, + defs.IMAGE_NUM_METRIC, + defs.IMAGE_GENERATION_TIME_SECONDS, + defs.VIDEO_GENERATION_TIME_SECONDS, +] + + +class TestRegistration: + def test_all_eight_families_present(self, mod: OmniModalityMetrics) -> None: + # Trigger at least one observation per family so the registry exposes them. + mod.observe_audio_ttfp("s", "r", 0.1) + mod.observe_audio_duration("s", "r", 1.0) + mod.observe_audio_rtf("s", "r", 0.5) + mod.inc_audio_frames("s", "r", 1) + mod.observe_image_ttfp("s", "r", 0.2) + mod.inc_image_num("s", "r", 1) + mod.observe_image_generation_time("s", "r", 0.5) + mod.observe_video_generation_time("s", "r", 1.0) + + out = generate_latest(REGISTRY).decode() + for name in _EXPECTED_FAMILIES: + assert f"# HELP {name}" in out, f"missing family: {name}" + + def test_video_duration_and_rtf_intentionally_absent(self) -> None: + # Phase 3 deliberately drops these — see modality.py docstring. + out = generate_latest(REGISTRY).decode() + assert defs.VIDEO_DURATION_SECONDS not in out + assert defs.VIDEO_RTF_METRIC not in out + + +# --------------------------------------------------------------------------- +# Audio observe API +# --------------------------------------------------------------------------- + + +class TestAudio: + def test_audio_ttfp_observed(self, mod: OmniModalityMetrics) -> None: + stage, replica = "talker_ttfp", "0" + mod.observe_audio_ttfp(stage, replica, 0.42) + out = generate_latest(REGISTRY).decode() + prefix = f'{defs.AUDIO_TTFP_SECONDS}_count{{model_name="{_MODEL}",replica="{replica}",stage="{stage}"}}' + assert _sample_value(out, prefix) == 1.0 + + def test_audio_duration_observed(self, mod: OmniModalityMetrics) -> None: + stage, replica = "talker_dur", "0" + mod.observe_audio_duration(stage, replica, 3.5) + out = generate_latest(REGISTRY).decode() + prefix = f'{defs.AUDIO_DURATION_SECONDS}_sum{{model_name="{_MODEL}",replica="{replica}",stage="{stage}"}}' + assert _sample_value(out, prefix) == 3.5 + + def test_audio_rtf_observed(self, mod: OmniModalityMetrics) -> None: + stage, replica = "talker_rtf", "0" + mod.observe_audio_rtf(stage, replica, 0.45) + out = generate_latest(REGISTRY).decode() + prefix = f'{defs.AUDIO_RTF_METRIC}_sum{{model_name="{_MODEL}",replica="{replica}",stage="{stage}"}}' + assert _sample_value(out, prefix) == 0.45 + + def test_audio_frames_inc(self, mod: OmniModalityMetrics) -> None: + stage, replica = "talker_frames", "0" + mod.inc_audio_frames(stage, replica, 240) + mod.inc_audio_frames(stage, replica, 60) + out = generate_latest(REGISTRY).decode() + # Counter family auto-suffixes with _total in the exposed name. + prefix = f'{defs.AUDIO_FRAMES_METRIC}_total{{model_name="{_MODEL}",replica="{replica}",stage="{stage}"}}' + assert _sample_value(out, prefix) == 300.0 + + def test_audio_frames_zero_or_negative_skipped(self, mod: OmniModalityMetrics) -> None: + stage, replica = "talker_zero", "0" + mod.inc_audio_frames(stage, replica, 0) + mod.inc_audio_frames(stage, replica, -5) + # No observation → no series exposed for this (stage, replica) yet. + out = generate_latest(REGISTRY).decode() + prefix = f'{defs.AUDIO_FRAMES_METRIC}_total{{model_name="{_MODEL}",replica="{replica}",stage="{stage}"}}' + assert _sample_value(out, prefix) is None + + +# --------------------------------------------------------------------------- +# Image observe API +# --------------------------------------------------------------------------- + + +class TestImage: + def test_image_ttfp_observed(self, mod: OmniModalityMetrics) -> None: + stage, replica = "diffusion_ttfp", "0" + mod.observe_image_ttfp(stage, replica, 1.5) + out = generate_latest(REGISTRY).decode() + prefix = f'{defs.IMAGE_TTFP_SECONDS}_count{{model_name="{_MODEL}",replica="{replica}",stage="{stage}"}}' + assert _sample_value(out, prefix) == 1.0 + + def test_image_num_inc(self, mod: OmniModalityMetrics) -> None: + stage, replica = "diffusion_num", "0" + mod.inc_image_num(stage, replica, 4) + out = generate_latest(REGISTRY).decode() + prefix = f'{defs.IMAGE_NUM_METRIC}_total{{model_name="{_MODEL}",replica="{replica}",stage="{stage}"}}' + assert _sample_value(out, prefix) == 4.0 + + def test_image_generation_time_observed(self, mod: OmniModalityMetrics) -> None: + stage, replica = "diffusion_gen", "0" + mod.observe_image_generation_time(stage, replica, 2.7) + out = generate_latest(REGISTRY).decode() + prefix = f'{defs.IMAGE_GENERATION_TIME_SECONDS}_sum{{model_name="{_MODEL}",replica="{replica}",stage="{stage}"}}' + assert _sample_value(out, prefix) == 2.7 + + +# --------------------------------------------------------------------------- +# Video observe API +# --------------------------------------------------------------------------- + + +class TestVideo: + def test_video_generation_time_observed(self, mod: OmniModalityMetrics) -> None: + stage, replica = "diffusion_video", "0" + mod.observe_video_generation_time(stage, replica, 5.2) + out = generate_latest(REGISTRY).decode() + prefix = f'{defs.VIDEO_GENERATION_TIME_SECONDS}_sum{{model_name="{_MODEL}",replica="{replica}",stage="{stage}"}}' + assert _sample_value(out, prefix) == 5.2 + + +# --------------------------------------------------------------------------- +# Bucket selection (RTF uses RTF_BUCKETS, others use SECONDS_BUCKETS) +# --------------------------------------------------------------------------- + + +class TestBucketSelection: + def test_audio_rtf_uses_rtf_buckets(self, mod: OmniModalityMetrics) -> None: + stage, replica = "talker_buckets", "0" + mod.observe_audio_rtf(stage, replica, 0.5) + out = generate_latest(REGISTRY).decode() + # RTF_BUCKETS includes 0.9 and 1.25 — distinctive boundaries vs SECONDS_BUCKETS. + # Check that at least one RTF-distinctive bucket label appears. + rtf_marker = f'{defs.AUDIO_RTF_METRIC}_bucket{{le="0.9"' + assert rtf_marker in out, "audio_rtf should use RTF_BUCKETS containing le=0.9" + + def test_audio_ttfp_uses_seconds_buckets(self, mod: OmniModalityMetrics) -> None: + stage, replica = "talker_seconds", "0" + mod.observe_audio_ttfp(stage, replica, 0.1) + out = generate_latest(REGISTRY).decode() + # SECONDS_BUCKETS includes 0.05 — not in RTF_BUCKETS. + sec_marker = f'{defs.AUDIO_TTFP_SECONDS}_bucket{{le="0.05"' + assert sec_marker in out, "audio_ttfp should use SECONDS_BUCKETS containing le=0.05" diff --git a/vllm_omni/metrics/modality.py b/vllm_omni/metrics/modality.py new file mode 100644 index 00000000000..5c47320792e --- /dev/null +++ b/vllm_omni/metrics/modality.py @@ -0,0 +1,155 @@ +"""OmniModalityMetrics — per-modality Prometheus families (RFC G1/G2). + +Audio / image / video business-semantic metric families with +``{model_name, stage, replica}`` labels. Text-path metrics (TTFT/ITL/TPOT) +are NOT here — they come from the upstream ``vllm:*{stage="thinker", ...}`` +families exposed by ``OmniPrometheusStatLogger`` (Phase 2 wrap). + +Phase 3 covers 8 of the 10 RFC families: +- audio: ttfp, duration, rtf, frames +- image: ttfp, num, generation_time +- video: generation_time + +``video_duration_seconds`` and ``video_rtf`` are intentionally deferred — +diffusion video pipelines (i2v / t2v / cogvideo / hunyuan / wan) expose +num_frames + fps in heterogeneous shapes and a clean abstraction is out of +scope for this iteration. +""" + +from __future__ import annotations + +from prometheus_client import Counter, Histogram + +from vllm_omni.metrics import definitions as defs + +_labelnames = list(defs.STAGE_LABELS) + + +# ---------------------------------------------------------------------------- +# Audio family (G1) — observed at finalize except for ttfp which is observed +# at the streaming hook (first audio packet emerges). +# ---------------------------------------------------------------------------- +_audio_ttfp_family = Histogram( + defs.AUDIO_TTFP_SECONDS, + "Time from request arrival to first audio packet, in seconds.", + labelnames=_labelnames, + buckets=defs.SECONDS_BUCKETS, +) +_audio_duration_family = Histogram( + defs.AUDIO_DURATION_SECONDS, + "Generated audio content duration, in seconds (audio_frames / sample_rate).", + labelnames=_labelnames, + buckets=defs.SECONDS_BUCKETS, +) +_audio_rtf_family = Histogram( + defs.AUDIO_RTF_METRIC, + "Audio real-time factor (stage_gen_time_s / audio_duration_s); SLO red line < 1.", + labelnames=_labelnames, + buckets=defs.RTF_BUCKETS, +) +_audio_frames_family = Counter( + defs.AUDIO_FRAMES_METRIC, + "Total audio frames generated; throughput recovered via rate().", + labelnames=_labelnames, +) + + +# ---------------------------------------------------------------------------- +# Image family (G2) +# ---------------------------------------------------------------------------- +_image_ttfp_family = Histogram( + defs.IMAGE_TTFP_SECONDS, + "Time from request arrival to first image (or only image) emitted, in seconds.", + labelnames=_labelnames, + buckets=defs.SECONDS_BUCKETS, +) +_image_num_family = Counter( + defs.IMAGE_NUM_METRIC, + "Total images generated; throughput recovered via rate().", + labelnames=_labelnames, +) +_image_generation_time_family = Histogram( + defs.IMAGE_GENERATION_TIME_SECONDS, + "Per-request image stage generation time, in seconds. Image has no RTF " + "(no content duration), so generation time is exposed independently.", + labelnames=_labelnames, + buckets=defs.SECONDS_BUCKETS, +) + + +# ---------------------------------------------------------------------------- +# Video family (G2) — only generation_time this iteration; duration/rtf +# require num_frames + fps from heterogeneous diffusion pipelines. +# ---------------------------------------------------------------------------- +_video_generation_time_family = Histogram( + defs.VIDEO_GENERATION_TIME_SECONDS, + "Per-request video stage generation time, in seconds.", + labelnames=_labelnames, + buckets=defs.SECONDS_BUCKETS, +) + + +class OmniModalityMetrics: + """Per-modality observe API. Stage/replica are passed at observe time + because a single OmniModalityMetrics instance per pipeline serves all + stage+replica combinations. + + See RFC §3.2.6. + """ + + def __init__(self, model_name: str) -> None: + self._model_name = model_name + + # ---- Audio ------------------------------------------------------------ + + def observe_audio_ttfp(self, stage: str, replica: str, ttfp_seconds: float) -> None: + _audio_ttfp_family.labels( + model_name=self._model_name, stage=stage, replica=replica + ).observe(ttfp_seconds) + + def observe_audio_duration(self, stage: str, replica: str, duration_seconds: float) -> None: + _audio_duration_family.labels( + model_name=self._model_name, stage=stage, replica=replica + ).observe(duration_seconds) + + def observe_audio_rtf(self, stage: str, replica: str, rtf: float) -> None: + _audio_rtf_family.labels( + model_name=self._model_name, stage=stage, replica=replica + ).observe(rtf) + + def inc_audio_frames(self, stage: str, replica: str, n_frames: int) -> None: + if n_frames <= 0: + return + _audio_frames_family.labels( + model_name=self._model_name, stage=stage, replica=replica + ).inc(n_frames) + + # ---- Image ------------------------------------------------------------ + + def observe_image_ttfp(self, stage: str, replica: str, ttfp_seconds: float) -> None: + _image_ttfp_family.labels( + model_name=self._model_name, stage=stage, replica=replica + ).observe(ttfp_seconds) + + def inc_image_num(self, stage: str, replica: str, n_images: int) -> None: + if n_images <= 0: + return + _image_num_family.labels( + model_name=self._model_name, stage=stage, replica=replica + ).inc(n_images) + + def observe_image_generation_time( + self, stage: str, replica: str, gen_time_seconds: float + ) -> None: + _image_generation_time_family.labels( + model_name=self._model_name, stage=stage, replica=replica + ).observe(gen_time_seconds) + + # ---- Video ------------------------------------------------------------ + + def observe_video_generation_time( + self, stage: str, replica: str, gen_time_seconds: float + ) -> None: + _video_generation_time_family.labels( + model_name=self._model_name, stage=stage, replica=replica + ).observe(gen_time_seconds) From b285a19622ead2f6d7c505a263232ff4fa37f888 Mon Sep 17 00:00:00 2001 From: LHXuuu Date: Wed, 13 May 2026 17:24:13 +0800 Subject: [PATCH 08/13] [Metrics] Wire OmniModalityMetrics finalize hook for audio/image/video MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 3.2 of multi-modal observability (RFC G1/G2). Emit per-modality Prometheus observations at request finalize for 7 of the 8 modality families: audio: frames (Counter), duration_seconds, rtf image: num (Counter), generation_time_seconds, ttfp_seconds video: generation_time_seconds audio_ttfp_seconds is intentionally NOT here — it requires a first-packet timestamp from the streaming path and lands separately in Phase 3.3. Five connected edits: 1. definitions.py — add resolve_audio_sample_rate(multimodal_output) helper that mirrors the (audio_sample_rate / sample_rate / sampling_rate / sr) fallback chain already used by serving_chat.py for the OpenAI audio response, plus DEFAULT_AUDIO_SAMPLE_RATE=24000. 2. modality.py — add observe_modality_at_finalize() module-level routing function. Extracted from omni_base so unit tests can exercise the routing logic without importing the heavy AsyncOmniBase stack. 3. orchestrator.py — add replica_id to _route_output() signature and into the "type=output" / "type=stage_metrics" queue dicts so the downstream finalize hook in omni_base can emit per-replica labels. The error path at line 867 intentionally omits replica_id; the modality routing function defensive-skips when None. 4. omni_base.py — instantiate self.mod_metrics (OmniModalityMetrics) alongside self.prom_metrics; call observe_modality_at_finalize() inside the existing e2e_done finalize guard so it fires once per request and inherits the try/except isolation. 5. tests/metrics/test_modality.py — 9 routing tests covering all four output_type branches, sample-rate resolution, ttfp clamping on clock skew, and defensive skips for None replica_id / stage_metrics. Uses a stub mod_metrics that records call signatures so we don't need a live PrometheusRegistry for routing assertions. --- tests/metrics/test_modality.py | 179 ++++++++++++++++++++++++++++- vllm_omni/engine/orchestrator.py | 5 +- vllm_omni/entrypoints/omni_base.py | 18 ++- vllm_omni/metrics/definitions.py | 35 ++++++ vllm_omni/metrics/modality.py | 59 ++++++++++ 5 files changed, 293 insertions(+), 3 deletions(-) diff --git a/tests/metrics/test_modality.py b/tests/metrics/test_modality.py index e6dfd4fd931..192f2dfaf47 100644 --- a/tests/metrics/test_modality.py +++ b/tests/metrics/test_modality.py @@ -4,7 +4,7 @@ from prometheus_client import REGISTRY, generate_latest from vllm_omni.metrics import definitions as defs -from vllm_omni.metrics.modality import OmniModalityMetrics +from vllm_omni.metrics.modality import OmniModalityMetrics, observe_modality_at_finalize pytestmark = [pytest.mark.core_model, pytest.mark.cpu] @@ -164,6 +164,183 @@ def test_video_generation_time_observed(self, mod: OmniModalityMetrics) -> None: # --------------------------------------------------------------------------- +class _StubModMetrics: + """Records every observe/inc call so the routing logic can be asserted.""" + + def __init__(self): + self.calls: list[tuple] = [] + + def inc_audio_frames(self, s, r, n): + self.calls.append(("inc_audio_frames", s, r, n)) + + def observe_audio_duration(self, s, r, d): + self.calls.append(("observe_audio_duration", s, r, d)) + + def observe_audio_rtf(self, s, r, rtf): + self.calls.append(("observe_audio_rtf", s, r, rtf)) + + def inc_image_num(self, s, r, n): + self.calls.append(("inc_image_num", s, r, n)) + + def observe_image_generation_time(self, s, r, t): + self.calls.append(("observe_image_generation_time", s, r, t)) + + def observe_image_ttfp(self, s, r, t): + self.calls.append(("observe_image_ttfp", s, r, t)) + + def observe_video_generation_time(self, s, r, t): + self.calls.append(("observe_video_generation_time", s, r, t)) + + +class _Bag: + """Tiny attribute bag for stage_metrics / engine_outputs stubs.""" + + def __init__(self, **kw): + self.__dict__.update(kw) + + +class TestObserveModalityAtFinalize: + def test_audio_path_full(self): + stub = _StubModMetrics() + stage_metrics = _Bag(stage_gen_time_ms=500.0, audio_generated_frames=24000) + engine_outputs = _Bag(multimodal_output={"audio_sample_rate": 24000}) + + observe_modality_at_finalize( + stub, + output_type="audio", + stage_id=1, + replica_id=0, + stage_metrics=stage_metrics, + engine_outputs=engine_outputs, + request_arrival_ts=100.0, + finalize_ts=100.5, + ) + # 24000 frames / 24000 Hz = 1.0s duration; gen 0.5s → rtf 0.5 + assert ("inc_audio_frames", "1", "0", 24000) in stub.calls + assert ("observe_audio_duration", "1", "0", 1.0) in stub.calls + assert ("observe_audio_rtf", "1", "0", 0.5) in stub.calls + + def test_audio_path_zero_frames_skips_duration_and_rtf(self): + stub = _StubModMetrics() + observe_modality_at_finalize( + stub, + output_type="audio", + stage_id=1, + replica_id=0, + stage_metrics=_Bag(stage_gen_time_ms=300.0, audio_generated_frames=0), + engine_outputs=_Bag(multimodal_output={}), + request_arrival_ts=100.0, + finalize_ts=100.3, + ) + # inc with 0 still called (Counter side gates internally to no-op) + assert ("inc_audio_frames", "1", "0", 0) in stub.calls + # but no duration / rtf because duration_s == 0 + assert not any(c[0] == "observe_audio_duration" for c in stub.calls) + assert not any(c[0] == "observe_audio_rtf" for c in stub.calls) + + def test_audio_uses_resolved_sample_rate_from_multimodal_output(self): + stub = _StubModMetrics() + # Non-default 16 kHz from talker config + observe_modality_at_finalize( + stub, + output_type="audio", + stage_id=1, + replica_id=0, + stage_metrics=_Bag(stage_gen_time_ms=1000.0, audio_generated_frames=16000), + engine_outputs=_Bag(multimodal_output={"sample_rate": 16000}), + request_arrival_ts=0.0, + finalize_ts=1.0, + ) + # 16000 / 16000 = 1.0s + assert ("observe_audio_duration", "1", "0", 1.0) in stub.calls + + def test_image_path_uses_finalize_minus_arrival_for_ttfp(self): + stub = _StubModMetrics() + observe_modality_at_finalize( + stub, + output_type="image", + stage_id=2, + replica_id=1, + stage_metrics=_Bag(stage_gen_time_ms=2000.0), + engine_outputs=_Bag(images=["img1", "img2", "img3"]), + request_arrival_ts=10.0, + finalize_ts=12.5, + ) + assert ("inc_image_num", "2", "1", 3) in stub.calls + assert ("observe_image_generation_time", "2", "1", 2.0) in stub.calls + assert ("observe_image_ttfp", "2", "1", 2.5) in stub.calls + + def test_image_ttfp_clamped_to_zero_on_clock_skew(self): + stub = _StubModMetrics() + observe_modality_at_finalize( + stub, + output_type="image", + stage_id=2, + replica_id=0, + stage_metrics=_Bag(stage_gen_time_ms=1000.0), + engine_outputs=_Bag(images=["img"]), + request_arrival_ts=100.0, + finalize_ts=99.5, # finalize earlier than arrival (impossible but defensive) + ) + assert ("observe_image_ttfp", "2", "0", 0.0) in stub.calls + + def test_video_path_only_emits_generation_time(self): + stub = _StubModMetrics() + observe_modality_at_finalize( + stub, + output_type="video", + stage_id=2, + replica_id=0, + stage_metrics=_Bag(stage_gen_time_ms=5200.0), + engine_outputs=_Bag(), + request_arrival_ts=0.0, + finalize_ts=5.3, + ) + assert stub.calls == [("observe_video_generation_time", "2", "0", 5.2)] + + def test_text_path_no_calls(self): + stub = _StubModMetrics() + observe_modality_at_finalize( + stub, + output_type="text", + stage_id=0, + replica_id=0, + stage_metrics=_Bag(stage_gen_time_ms=100.0), + engine_outputs=_Bag(), + request_arrival_ts=0.0, + finalize_ts=0.1, + ) + assert stub.calls == [] + + def test_replica_id_none_skipped(self): + stub = _StubModMetrics() + observe_modality_at_finalize( + stub, + output_type="audio", + stage_id=1, + replica_id=None, # error path: orchestrator emitted without replica_id + stage_metrics=_Bag(stage_gen_time_ms=500.0, audio_generated_frames=240), + engine_outputs=_Bag(multimodal_output={}), + request_arrival_ts=0.0, + finalize_ts=0.5, + ) + assert stub.calls == [] + + def test_stage_metrics_none_skipped(self): + stub = _StubModMetrics() + observe_modality_at_finalize( + stub, + output_type="audio", + stage_id=1, + replica_id=0, + stage_metrics=None, + engine_outputs=_Bag(multimodal_output={}), + request_arrival_ts=0.0, + finalize_ts=0.5, + ) + assert stub.calls == [] + + class TestBucketSelection: def test_audio_rtf_uses_rtf_buckets(self, mod: OmniModalityMetrics) -> None: stage, replica = "talker_buckets", "0" diff --git a/vllm_omni/engine/orchestrator.py b/vllm_omni/engine/orchestrator.py index 8435b79a825..872ca6aa9ca 100644 --- a/vllm_omni/engine/orchestrator.py +++ b/vllm_omni/engine/orchestrator.py @@ -576,7 +576,7 @@ async def _handle_processed_outputs(self, stage_id: int, replica_id: int, output ) stage_metrics.pipeline_timings = dict(req_state.pipeline_timings) - await self._route_output(stage_id, output, req_state, stage_metrics) + await self._route_output(stage_id, replica_id, output, req_state, stage_metrics) async def _handle_stage_error(self, stage_id: int, output: Any) -> None: """Emit a frontend-visible error and clean up request state.""" @@ -632,6 +632,7 @@ def _maybe_clone_diffusion_params_for_cfg(self, request_id: str, params: Any) -> async def _route_output( self, stage_id: int, + replica_id: int, output: Any, req_state: OrchestratorRequestState, stage_metrics: Any, @@ -652,6 +653,7 @@ async def _route_output( "type": "output", "request_id": req_id, "stage_id": stage_id, + "replica_id": replica_id, "engine_outputs": output, "metrics": stage_metrics, "finished": finished and stage_id == req_state.final_stage_id, @@ -664,6 +666,7 @@ async def _route_output( "type": "stage_metrics", "request_id": req_id, "stage_id": stage_id, + "replica_id": replica_id, "metrics": stage_metrics, "stage_submit_ts": submit_ts, } diff --git a/vllm_omni/entrypoints/omni_base.py b/vllm_omni/entrypoints/omni_base.py index 65278dacacf..b718547a14f 100644 --- a/vllm_omni/entrypoints/omni_base.py +++ b/vllm_omni/entrypoints/omni_base.py @@ -18,6 +18,7 @@ from vllm_omni.entrypoints.client_request_state import ClientRequestState from vllm_omni.entrypoints.pd_utils import PDDisaggregationMixin from vllm_omni.entrypoints.utils import coerce_param_message_types, get_final_stage_id_for_e2e +from vllm_omni.metrics.modality import OmniModalityMetrics, observe_modality_at_finalize from vllm_omni.metrics.prometheus import OmniPrometheusMetrics from vllm_omni.metrics.stats import OrchestratorAggregator as OrchestratorMetrics from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific @@ -187,6 +188,7 @@ def __init__( self.request_states: dict[str, ClientRequestState] = {} self.prom_metrics = OmniPrometheusMetrics(model_name=model) + self.mod_metrics = OmniModalityMetrics(model_name=model) self.default_sampling_params_list = self.engine.default_sampling_params_list if not self.output_modalities: @@ -443,6 +445,8 @@ def _process_single_result( if not stage_meta["final_output"]: return None + output_type = getattr(engine_outputs, "final_output_type", stage_meta["final_output_type"]) + try: rid_key = str(req_id) if stage_id == final_stage_id_for_e2e and rid_key not in metrics.e2e_done and finished: @@ -457,6 +461,19 @@ def _process_single_result( queue_ms = _pt.get("queue_wait_ms") queue_seconds = queue_ms / 1000.0 if queue_ms is not None else None self.prom_metrics.request_succeeded(e2e_seconds, queue_seconds=queue_seconds) + + # Modality observe (Phase 3.2). Inside the same finalize guard so + # it fires once per request and inherits the try/except isolation. + observe_modality_at_finalize( + self.mod_metrics, + output_type=output_type, + stage_id=stage_id, + replica_id=result.get("replica_id"), + stage_metrics=_m, + engine_outputs=engine_outputs, + request_arrival_ts=req_start_ts.get(req_id, wall_start_ts), + finalize_ts=now, + ) except Exception: logger.exception("[%s] Finalize request handling error", self.__class__.__name__) @@ -469,7 +486,6 @@ def _process_single_result( if finished and isinstance(diffusion_metrics, dict) and diffusion_metrics: self.prom_metrics.observe_diffusion_metrics(stage_id, diffusion_metrics) - output_type = getattr(engine_outputs, "final_output_type", stage_meta["final_output_type"]) images = getattr(engine_outputs, "images", []) if output_type == "image" else [] return OmniRequestOutput( request_id=req_id or "", diff --git a/vllm_omni/metrics/definitions.py b/vllm_omni/metrics/definitions.py index babd6104cc8..f6f4f3feda4 100644 --- a/vllm_omni/metrics/definitions.py +++ b/vllm_omni/metrics/definitions.py @@ -141,3 +141,38 @@ def compute_video_rtf(stage_gen_time_s: float, video_duration_s: float) -> float if video_duration_s <= 0: return 0.0 return stage_gen_time_s / video_duration_s + + +# ============================================================================ +# Audio sample-rate resolution +# ============================================================================ +# Most common across vllm-omni talker variants (cosyvoice3, omnivoice, +# qwen3_tts, mimo_audio, voxcpm). voxcpm2 uses 48000, stable_audio 44100, +# ming_flash 16000 — these models populate multimodal_output["audio_sample_rate"] +# at runtime so this default only kicks in when the field is missing. +DEFAULT_AUDIO_SAMPLE_RATE = 24000 + +_SAMPLE_RATE_KEYS = ("audio_sample_rate", "sample_rate", "sampling_rate", "sr") + + +def resolve_audio_sample_rate(multimodal_output: dict | None) -> int: + """Extract audio sample_rate from a multimodal_output dict, with fallbacks. + + Tries the same key chain as serving_chat.py's audio response path so + /metrics audio_duration_seconds = audio_frames / sample_rate stays + consistent with what the OpenAI streaming endpoint reports back to clients. + Returns DEFAULT_AUDIO_SAMPLE_RATE when no usable value is present. + """ + if not multimodal_output: + return DEFAULT_AUDIO_SAMPLE_RATE + for key in _SAMPLE_RATE_KEYS: + raw = multimodal_output.get(key) + if raw is None: + continue + try: + value = int(raw) + except (TypeError, ValueError): + continue + if value > 0: + return value + return DEFAULT_AUDIO_SAMPLE_RATE diff --git a/vllm_omni/metrics/modality.py b/vllm_omni/metrics/modality.py index 5c47320792e..137571930fc 100644 --- a/vllm_omni/metrics/modality.py +++ b/vllm_omni/metrics/modality.py @@ -18,6 +18,8 @@ from __future__ import annotations +from typing import Any + from prometheus_client import Counter, Histogram from vllm_omni.metrics import definitions as defs @@ -153,3 +155,60 @@ def observe_video_generation_time( _video_generation_time_family.labels( model_name=self._model_name, stage=stage, replica=replica ).observe(gen_time_seconds) + + +def observe_modality_at_finalize( + mod_metrics: OmniModalityMetrics, + *, + output_type: str | None, + stage_id: int, + replica_id: int | None, + stage_metrics: Any, + engine_outputs: Any, + request_arrival_ts: float, + finalize_ts: float, +) -> None: + """Route per-modality observations for a finalized request. + + Used by ``omni_base._process_single_result`` inside the e2e_done finalize + guard so it fires once per request. Skips text path (covered by upstream + ``vllm:*{stage="thinker", ...}``) and any case where required inputs are + missing — caller should not need to pre-validate. + + audio_ttfp is intentionally NOT observed here; it's emitted by the + streaming hook (Phase 3.3) at first-packet time, not at finalize. + """ + if replica_id is None or stage_metrics is None or output_type is None: + return + if output_type not in ("audio", "image", "video"): + return + + stage_label = str(stage_id) + replica_label = str(replica_id) + gen_time_s = float(getattr(stage_metrics, "stage_gen_time_ms", 0.0)) / 1000.0 + + if output_type == "audio": + mm_out = getattr(engine_outputs, "multimodal_output", None) or {} + sample_rate = defs.resolve_audio_sample_rate(mm_out) + n_frames = int(getattr(stage_metrics, "audio_generated_frames", 0) or 0) + mod_metrics.inc_audio_frames(stage_label, replica_label, n_frames) + duration_s = n_frames / sample_rate if sample_rate > 0 else 0.0 + if duration_s > 0: + mod_metrics.observe_audio_duration(stage_label, replica_label, duration_s) + mod_metrics.observe_audio_rtf( + stage_label, + replica_label, + defs.compute_audio_rtf(gen_time_s, duration_s), + ) + elif output_type == "image": + n_images = len(getattr(engine_outputs, "images", []) or []) + mod_metrics.inc_image_num(stage_label, replica_label, n_images) + mod_metrics.observe_image_generation_time( + stage_label, replica_label, gen_time_s + ) + image_ttfp_s = max(finalize_ts - request_arrival_ts, 0.0) + mod_metrics.observe_image_ttfp(stage_label, replica_label, image_ttfp_s) + else: # video + mod_metrics.observe_video_generation_time( + stage_label, replica_label, gen_time_s + ) From ab89a1a431a8cbe68cb3d60dfafd74308ac246fe Mon Sep 17 00:00:00 2001 From: LHXuuu Date: Wed, 13 May 2026 17:32:53 +0800 Subject: [PATCH 09/13] [Metrics] Wire audio_ttfp_seconds streaming hook for first-packet observation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 3.3 of multi-modal observability (RFC G1) — completes the 8th modality family. Unlike the other 7 (finalize-path), audio_ttfp must be measured at the moment the first audio packet leaves the engine, not at request finalize. Five edits: 1. ClientRequestState — add request_arrival_ts (wall-clock anchor for the TTFP delta) and first_audio_ts (once-per-request guard). 2. async_omni.generate() — populate req_state.request_arrival_ts = wall_start_ts right after creating the ClientRequestState, before the orchestrator accepts the request. 3. modality.observe_audio_first_packet() — module-level helper that computes max(now_ts - arrival_ts, 0.0) and emits audio_ttfp_seconds. Defensive-skips when replica_id is None or arrival_ts == 0 (uninitialized). The once-per-request guard lives in the caller. 4. serving_chat.chat_completion_stream_generator() — HTTP SSE audio branch checks req_state.first_audio_ts and observes on first hit. Looks up replica via stage_pools[stage_id].get_bound_replica_id() so the metric carries the right per-replica label. 5. serving_video_stream._process_query_engine() — WebSocket path inserts the same guard + observe right where t_first_audio is already set, so the existing first-packet detection logic is reused. Caller-side guards (req_state.first_audio_ts is None) prevent re-observe on subsequent audio chunks for the same request_id. Both hook sites defensive-skip if request_states or stage_pools is missing. 4 new helper unit tests (valid inputs / replica=None skip / arrival_ts=0 skip / clock-skew clamp to 0). Total 26 modality tests pass. --- tests/metrics/test_modality.py | 49 ++++++++++++++++++- vllm_omni/entrypoints/async_omni.py | 1 + vllm_omni/entrypoints/client_request_state.py | 8 +++ vllm_omni/entrypoints/openai/serving_chat.py | 21 ++++++++ .../openai/serving_video_stream.py | 26 ++++++++++ vllm_omni/metrics/modality.py | 25 ++++++++++ 6 files changed, 129 insertions(+), 1 deletion(-) diff --git a/tests/metrics/test_modality.py b/tests/metrics/test_modality.py index 192f2dfaf47..b44a29bf351 100644 --- a/tests/metrics/test_modality.py +++ b/tests/metrics/test_modality.py @@ -4,7 +4,11 @@ from prometheus_client import REGISTRY, generate_latest from vllm_omni.metrics import definitions as defs -from vllm_omni.metrics.modality import OmniModalityMetrics, observe_modality_at_finalize +from vllm_omni.metrics.modality import ( + OmniModalityMetrics, + observe_audio_first_packet, + observe_modality_at_finalize, +) pytestmark = [pytest.mark.core_model, pytest.mark.cpu] @@ -341,6 +345,49 @@ def test_stage_metrics_none_skipped(self): assert stub.calls == [] +class TestObserveAudioFirstPacket: + def test_observes_with_valid_inputs(self): + stub = _StubModMetrics() + # Patch in audio_ttfp to the stub for routing assertion. + stub.observe_audio_ttfp = lambda s, r, t: stub.calls.append(("observe_audio_ttfp", s, r, t)) + + observe_audio_first_packet( + stub, + stage_id=1, + replica_id=0, + arrival_ts=100.0, + now_ts=100.42, + ) + assert stub.calls == [("observe_audio_ttfp", "1", "0", pytest.approx(0.42))] + + def test_replica_none_skipped(self): + stub = _StubModMetrics() + stub.observe_audio_ttfp = lambda s, r, t: stub.calls.append(("observe_audio_ttfp", s, r, t)) + observe_audio_first_packet( + stub, stage_id=1, replica_id=None, arrival_ts=100.0, now_ts=100.5 + ) + assert stub.calls == [] + + def test_arrival_ts_zero_skipped(self): + # Defensive: req_state.request_arrival_ts == 0.0 means async_omni + # never populated it (e.g. some fast path). Skip rather than emit + # garbage TTFP measured against epoch. + stub = _StubModMetrics() + stub.observe_audio_ttfp = lambda s, r, t: stub.calls.append(("observe_audio_ttfp", s, r, t)) + observe_audio_first_packet( + stub, stage_id=1, replica_id=0, arrival_ts=0.0, now_ts=100.5 + ) + assert stub.calls == [] + + def test_clock_skew_clamped_to_zero(self): + stub = _StubModMetrics() + stub.observe_audio_ttfp = lambda s, r, t: stub.calls.append(("observe_audio_ttfp", s, r, t)) + observe_audio_first_packet( + stub, stage_id=1, replica_id=0, arrival_ts=100.5, now_ts=100.0 + ) + assert stub.calls == [("observe_audio_ttfp", "1", "0", 0.0)] + + class TestBucketSelection: def test_audio_rtf_uses_rtf_buckets(self, mod: OmniModalityMetrics) -> None: stage, replica = "talker_buckets", "0" diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index 056f56c003b..21ce1d839c4 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -302,6 +302,7 @@ async def generate( ) req_state = ClientRequestState(request_id) req_state.metrics = metrics + req_state.request_arrival_ts = wall_start_ts self.request_states[request_id] = req_state # PD disaggregation: modify prefill-stage sampling params per request diff --git a/vllm_omni/entrypoints/client_request_state.py b/vllm_omni/entrypoints/client_request_state.py index 1c9103f795d..9a8c688607b 100644 --- a/vllm_omni/entrypoints/client_request_state.py +++ b/vllm_omni/entrypoints/client_request_state.py @@ -11,3 +11,11 @@ def __init__(self, request_id: str, queue: asyncio.Queue | None = None): self.stage_id: int | None = None self.queue = queue if queue is not None else asyncio.Queue() self.metrics: OrchestratorAggregator | None = None + # Wall-clock time at which the user's request arrived in the engine + # entrypoint. Set in async_omni.generate() before the orchestrator + # accepts the request. Used as the "起算" anchor for audio_ttfp. + self.request_arrival_ts: float = 0.0 + # Wall-clock time at which the first audio packet was observed for + # this request. None means the streaming hook hasn't fired yet. + # Used as the once-per-request guard for audio_ttfp_seconds emit. + self.first_audio_ts: float | None = None diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index 99827454e70..9cf8824d3c9 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -16,6 +16,7 @@ from vllm_omni.diffusion.diffusion_engine import get_extra_body_params, get_extra_output_params from vllm_omni.entrypoints.async_omni import AsyncOmni +from vllm_omni.metrics.modality import observe_audio_first_packet from vllm_omni.entrypoints.openai.protocol.chat_completion import OmniChatCompletionResponse from vllm_omni.entrypoints.utils import coerce_param_message_types from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt @@ -1542,6 +1543,26 @@ async def chat_completion_stream_generator( yield f"data: {data}\n\n" elif final_output_type == "audio": + # Phase 3.3: observe audio_ttfp_seconds on first audio packet + # for this request_id (once-per-request guard via first_audio_ts). + req_state = self.engine_client.request_states.get(request_id) + if req_state is not None and req_state.first_audio_ts is None: + now_ts = time.time() + req_state.first_audio_ts = now_ts + stage_pools = getattr(self.engine_client.engine, "stage_pools", None) + replica_id = ( + stage_pools[omni_res.stage_id].get_bound_replica_id(request_id) + if stage_pools is not None and 0 <= omni_res.stage_id < len(stage_pools) + else None + ) + observe_audio_first_packet( + self.engine_client.mod_metrics, + stage_id=omni_res.stage_id, + replica_id=replica_id, + arrival_ts=req_state.request_arrival_ts, + now_ts=now_ts, + ) + role = self.get_chat_request_role(request) choices_data = self._create_audio_choice(omni_res, role, request, stream=True) chunk = OmniChatCompletionStreamResponse( diff --git a/vllm_omni/entrypoints/openai/serving_video_stream.py b/vllm_omni/entrypoints/openai/serving_video_stream.py index a76b241c55b..01ad8df2fa3 100644 --- a/vllm_omni/entrypoints/openai/serving_video_stream.py +++ b/vllm_omni/entrypoints/openai/serving_video_stream.py @@ -45,6 +45,7 @@ from vllm_omni.entrypoints.openai.video_stream_context import ( text_only_message, ) +from vllm_omni.metrics.modality import observe_audio_first_packet from vllm_omni.outputs import OmniRequestOutput logger = init_logger(__name__) @@ -573,6 +574,31 @@ async def _process_query_engine( if t_first_audio is None: t_first_audio = _time.monotonic() + # Phase 3.3: observe audio_ttfp_seconds. Hook here rather + # than at the OpenAI SSE path because this WebSocket route + # is the canonical real-time entrypoint for video+audio. + req_state = ( + self._engine_client.request_states.get(request_id) + if self._engine_client is not None + else None + ) + if req_state is not None and req_state.first_audio_ts is None: + now_ts = _time.time() + req_state.first_audio_ts = now_ts + stage_pools = getattr(self._engine_client.engine, "stage_pools", None) + stage_id = getattr(output, "stage_id", 0) + replica_id = ( + stage_pools[stage_id].get_bound_replica_id(request_id) + if stage_pools is not None and 0 <= stage_id < len(stage_pools) + else None + ) + observe_audio_first_packet( + self._engine_client.mod_metrics, + stage_id=stage_id, + replica_id=replica_id, + arrival_ts=req_state.request_arrival_ts, + now_ts=now_ts, + ) audio_chunk_count += 1 if streaming: b64, audio_chunks_drained = self._extract_audio_delta_b64( diff --git a/vllm_omni/metrics/modality.py b/vllm_omni/metrics/modality.py index 137571930fc..c1829099c2e 100644 --- a/vllm_omni/metrics/modality.py +++ b/vllm_omni/metrics/modality.py @@ -212,3 +212,28 @@ def observe_modality_at_finalize( mod_metrics.observe_video_generation_time( stage_label, replica_label, gen_time_s ) + + +def observe_audio_first_packet( + mod_metrics: OmniModalityMetrics, + *, + stage_id: int, + replica_id: int | None, + arrival_ts: float, + now_ts: float, +) -> None: + """Observe audio_ttfp_seconds on a request's first audio packet. + + Caller is responsible for the once-per-request guard (e.g. checking + ``ClientRequestState.first_audio_ts is None``) so this function fires at + most once per request_id. Defensive-skips when ``replica_id`` or + ``arrival_ts`` is insufficient — both can legitimately be missing in error + paths and we'd rather drop the sample than emit a wrong (stage, replica). + + Phase 3.3 — companion to ``observe_modality_at_finalize`` which handles the + other 7 modality families at finalize time. + """ + if replica_id is None or arrival_ts <= 0: + return + ttfp = max(now_ts - arrival_ts, 0.0) + mod_metrics.observe_audio_ttfp(str(stage_id), str(replica_id), ttfp) From a72fc840431bf28334f18b2a64bd0cd4f5ad96df Mon Sep 17 00:00:00 2001 From: LHXuuu Date: Wed, 13 May 2026 18:24:45 +0800 Subject: [PATCH 10/13] [Metrics] Add OmniTransferMetrics for cross-stage transfer observability MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 4 of multi-modal observability (RFC G3). Four Histogram families with {model_name, from_stage, from_replica, to_stage, to_replica} labels, emitted from the existing TransferEdgeStats accumulators in OrchestratorAggregator: transfer_size_bytes -- BYTES_BUCKETS, observed at TX hook transfer_tx_time_ms -- MS_BUCKETS, observed at TX hook transfer_rx_decode_time_ms -- MS_BUCKETS, observed at RX hook transfer_in_flight_time_ms -- MS_BUCKETS, observed at RX hook Each .observe_*() corresponds to one physical transfer event (one chunk hop) so the histogram tracks per-transfer distribution rather than request-aggregated sums. Two non-trivial design choices, both deviating from RFC §3.2.6: 1. Add model_name to the label set (RFC lists only the four stage/replica labels). Rationale: aligns transfer with the rest of the omni_* families (audio_*, image_*, video_*, num_requests_*) so PromQL joins on model_name work uniformly. Cardinality cost is zero for typical single-model deployments. To be amended back into the RFC as D11. 2. Resolve from_replica / to_replica by consulting the sticky-routing binding the orchestrator already maintains, rather than plumbing replica ids through TransferEdgeStats / StageRequestStats / connector adapters. Phase 2 (PR #2396) gave us stage_pool.get_bound_replica_id(request_id), which returns the replica each request is bound to within a stage; we look it up at emit time. Five files would need touching otherwise (TransferEdgeStats, StageRequestStats, stage_pool.build_stage_metrics, adapter.on_forward call site, request bookkeeping). To be amended back into the RFC as D12. Five edits: 1. definitions.py — TRANSFER_LABELS prepends model_name. 2. transfer.py — new OmniTransferMetrics(model_name) class with the four Histogram families and one observe method per family. Same wrapper pattern as OmniPrometheusMetrics (PR #3362) and OmniModalityMetrics (Phase 3): bind model_name at __init__, accept stage/replica at observe time. 3. stats.py — OrchestratorAggregator gains two optional kwargs at __init__ (transfer_emitter, replica_resolver). record_transfer_tx and record_transfer_rx call thin _emit_transfer_tx / _emit_transfer_rx helpers after the existing accumulation, which fail-safe (skip emit) when either dep is missing or the resolver returns None for either side. The TransferEdgeStats accumulation path is unchanged so existing log/aggregation consumers stay intact. 4. omni_base.py — instantiate self.transfer_metrics alongside self.prom_metrics / self.mod_metrics. 5. async_omni.py — wire transfer_emitter and replica_resolver into the per-request OrchestratorMetrics construction. New _resolve_transfer_replica helper looks up self.engine.stage_pools[s].get_bound_replica_id(rid). 15 unit tests cover the 4-family registration, observe APIs, bucket selection, multi-edge cardinality, plus the OrchestratorAggregator emit hooks (tx + rx full emit, defensive skips for emitter=None / resolver=None / one-side-resolves-to-None, rx skips stage 0). Stub emitter records call signatures so the routing logic is asserted without standing up a Prometheus registry. Note: TX-side hook (record_transfer_tx) is emit-ready but adapter.try_send_via_connector — its only caller — currently has no upstream invocation in the main code path (likely plugin-driven or about-to-land). RX side (record_transfer_rx) is the active path today; TX side will start emitting as soon as the connector pipeline activates. --- tests/metrics/test_transfer.py | 344 ++++++++++++++++++++++++++++ vllm_omni/entrypoints/async_omni.py | 16 ++ vllm_omni/entrypoints/omni_base.py | 2 + vllm_omni/metrics/definitions.py | 6 +- vllm_omni/metrics/stats.py | 89 ++++++- vllm_omni/metrics/transfer.py | 140 +++++++++++ 6 files changed, 595 insertions(+), 2 deletions(-) create mode 100644 tests/metrics/test_transfer.py create mode 100644 vllm_omni/metrics/transfer.py diff --git a/tests/metrics/test_transfer.py b/tests/metrics/test_transfer.py new file mode 100644 index 00000000000..a40589e76fb --- /dev/null +++ b/tests/metrics/test_transfer.py @@ -0,0 +1,344 @@ +from __future__ import annotations + +import pytest +from prometheus_client import REGISTRY, generate_latest + +from vllm_omni.metrics import definitions as defs +from vllm_omni.metrics.transfer import OmniTransferMetrics + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +_MODEL = "test-transfer-model" + + +@pytest.fixture(scope="module") +def tx() -> OmniTransferMetrics: + return OmniTransferMetrics(model_name=_MODEL) + + +def _sample_value(output: str, line_prefix: str) -> float | None: + for line in output.splitlines(): + if line.startswith(line_prefix): + return float(line.split()[-1]) + return None + + +_EXPECTED_FAMILIES = [ + defs.TRANSFER_SIZE_BYTES, + defs.TRANSFER_TX_TIME_MS, + defs.TRANSFER_RX_DECODE_TIME_MS, + defs.TRANSFER_IN_FLIGHT_TIME_MS, +] + + +# --------------------------------------------------------------------------- +# Family registration +# --------------------------------------------------------------------------- + + +class TestRegistration: + def test_all_four_families_present(self, tx: OmniTransferMetrics) -> None: + # Trigger one observation per family so the registry exposes them. + tx.observe_size(0, 0, 1, 0, 1024) + tx.observe_tx_time(0, 0, 1, 0, 5.0) + tx.observe_rx_decode_time(0, 0, 1, 0, 8.0) + tx.observe_in_flight_time(0, 0, 1, 0, 2.0) + + out = generate_latest(REGISTRY).decode() + for name in _EXPECTED_FAMILIES: + assert f"# HELP {name}" in out, f"missing family: {name}" + + +# --------------------------------------------------------------------------- +# Observe APIs +# --------------------------------------------------------------------------- + + +class TestObserveSize: + def test_size_observed_with_correct_labels(self, tx: OmniTransferMetrics) -> None: + # Distinct (from, to) so test isolation holds across cases. + tx.observe_size(2, 0, 3, 1, 65536) + out = generate_latest(REGISTRY).decode() + prefix = ( + f'{defs.TRANSFER_SIZE_BYTES}_sum' + f'{{from_replica="0",from_stage="2",model_name="{_MODEL}",' + f'to_replica="1",to_stage="3"}}' + ) + assert _sample_value(out, prefix) == 65536.0 + + +class TestObserveTxTime: + def test_tx_time_observed(self, tx: OmniTransferMetrics) -> None: + tx.observe_tx_time(2, 1, 3, 0, 12.5) + out = generate_latest(REGISTRY).decode() + prefix = ( + f'{defs.TRANSFER_TX_TIME_MS}_sum' + f'{{from_replica="1",from_stage="2",model_name="{_MODEL}",' + f'to_replica="0",to_stage="3"}}' + ) + assert _sample_value(out, prefix) == 12.5 + + +class TestObserveRxDecodeTime: + def test_rx_decode_time_observed(self, tx: OmniTransferMetrics) -> None: + tx.observe_rx_decode_time(0, 0, 1, 0, 4.2) + out = generate_latest(REGISTRY).decode() + prefix = ( + f'{defs.TRANSFER_RX_DECODE_TIME_MS}_sum' + f'{{from_replica="0",from_stage="0",model_name="{_MODEL}",' + f'to_replica="0",to_stage="1"}}' + ) + assert _sample_value(out, prefix) == 4.2 + + +class TestObserveInFlightTime: + def test_in_flight_time_observed(self, tx: OmniTransferMetrics) -> None: + tx.observe_in_flight_time(0, 0, 1, 0, 1.7) + out = generate_latest(REGISTRY).decode() + prefix = ( + f'{defs.TRANSFER_IN_FLIGHT_TIME_MS}_sum' + f'{{from_replica="0",from_stage="0",model_name="{_MODEL}",' + f'to_replica="0",to_stage="1"}}' + ) + assert _sample_value(out, prefix) == 1.7 + + +# --------------------------------------------------------------------------- +# Multi (from, to) cardinality +# --------------------------------------------------------------------------- + + +class TestCardinality: + def test_multiple_edges_produce_independent_series( + self, tx: OmniTransferMetrics + ) -> None: + # Same family, different (from_replica, to_replica) → distinct series. + tx.observe_size(5, 0, 6, 0, 100) + tx.observe_size(5, 0, 6, 1, 200) + tx.observe_size(5, 1, 6, 0, 300) + + out = generate_latest(REGISTRY).decode() + + prefix_a = ( + f'{defs.TRANSFER_SIZE_BYTES}_sum' + f'{{from_replica="0",from_stage="5",model_name="{_MODEL}",' + f'to_replica="0",to_stage="6"}}' + ) + prefix_b = ( + f'{defs.TRANSFER_SIZE_BYTES}_sum' + f'{{from_replica="0",from_stage="5",model_name="{_MODEL}",' + f'to_replica="1",to_stage="6"}}' + ) + prefix_c = ( + f'{defs.TRANSFER_SIZE_BYTES}_sum' + f'{{from_replica="1",from_stage="5",model_name="{_MODEL}",' + f'to_replica="0",to_stage="6"}}' + ) + assert _sample_value(out, prefix_a) == 100.0 + assert _sample_value(out, prefix_b) == 200.0 + assert _sample_value(out, prefix_c) == 300.0 + + +# --------------------------------------------------------------------------- +# Bucket selection +# --------------------------------------------------------------------------- + + +class TestBucketSelection: + def test_size_uses_bytes_buckets(self, tx: OmniTransferMetrics) -> None: + tx.observe_size(7, 0, 8, 0, 4096) + out = generate_latest(REGISTRY).decode() + # BYTES_BUCKETS contains 1024 — distinctive vs MS / SECONDS buckets. + marker = f'{defs.TRANSFER_SIZE_BYTES}_bucket{{from_replica="0",from_stage="7",le="1024"' + assert marker in out + + def test_time_families_use_ms_buckets(self, tx: OmniTransferMetrics) -> None: + tx.observe_tx_time(7, 0, 8, 0, 1.0) + out = generate_latest(REGISTRY).decode() + # MS_BUCKETS contains 1.0 — present in tx_time histogram. + marker = f'{defs.TRANSFER_TX_TIME_MS}_bucket{{from_replica="0",from_stage="7",le="1.0"' + assert marker in out + + +# --------------------------------------------------------------------------- +# OrchestratorAggregator emit hook (Phase 4.2) +# --------------------------------------------------------------------------- + + +from vllm_omni.metrics.stats import OrchestratorAggregator, StageRequestStats, StageStats + + +class _StubTransferEmitter: + """Records every observe_* call so the hook routing can be asserted + without standing up a Prometheus registry.""" + + def __init__(self) -> None: + self.calls: list[tuple] = [] + + def observe_size(self, fs, fr, ts, tr, n): + self.calls.append(("observe_size", fs, fr, ts, tr, n)) + + def observe_tx_time(self, fs, fr, ts, tr, t): + self.calls.append(("observe_tx_time", fs, fr, ts, tr, t)) + + def observe_rx_decode_time(self, fs, fr, ts, tr, t): + self.calls.append(("observe_rx_decode_time", fs, fr, ts, tr, t)) + + def observe_in_flight_time(self, fs, fr, ts, tr, t): + self.calls.append(("observe_in_flight_time", fs, fr, ts, tr, t)) + + +def _make_stats(stage_id, request_id, *, rx_decode=0.0, rx_in_flight=0.0, rx_bytes=0): + """Minimal StageRequestStats for record_transfer_rx input.""" + return StageRequestStats( + batch_id=1, + batch_size=1, + num_tokens_in=0, + num_tokens_out=0, + stage_gen_time_ms=0.0, + rx_transfer_bytes=rx_bytes, + rx_decode_time_ms=rx_decode, + rx_in_flight_time_ms=rx_in_flight, + stage_stats=StageStats(), + stage_id=stage_id, + request_id=request_id, + ) + + +class TestEmitHookTx: + def test_record_transfer_tx_emits_size_and_tx_time(self): + emitter = _StubTransferEmitter() + agg = OrchestratorAggregator( + num_stages=3, + log_stats=False, + wall_start_ts=0.0, + final_stage_id_for_e2e=2, + transfer_emitter=emitter, + replica_resolver=lambda s, rid: {0: 1, 1: 0}.get(s), + ) + agg.record_transfer_tx( + from_stage=0, + to_stage=1, + request_id="r-tx-1", + size_bytes=2048, + tx_time_ms=7.5, + used_shm=False, + ) + assert emitter.calls == [ + ("observe_size", 0, 1, 1, 0, 2048), + ("observe_tx_time", 0, 1, 1, 0, 7.5), + ] + + def test_record_transfer_tx_no_emit_when_emitter_none(self): + agg = OrchestratorAggregator( + num_stages=3, + log_stats=False, + wall_start_ts=0.0, + final_stage_id_for_e2e=2, + transfer_emitter=None, + replica_resolver=lambda s, rid: 0, + ) + # Should not raise; just no-op the emit. + evt = agg.record_transfer_tx( + from_stage=0, + to_stage=1, + request_id="r-tx-2", + size_bytes=128, + tx_time_ms=1.0, + used_shm=True, + ) + # Accumulation still happens — only Prometheus emit is skipped. + assert evt is not None + assert evt.size_bytes == 128 + assert evt.tx_time_ms == 1.0 + + def test_record_transfer_tx_no_emit_when_resolver_returns_none(self): + emitter = _StubTransferEmitter() + agg = OrchestratorAggregator( + num_stages=3, + log_stats=False, + wall_start_ts=0.0, + final_stage_id_for_e2e=2, + transfer_emitter=emitter, + replica_resolver=lambda s, rid: None, # always-None resolver + ) + agg.record_transfer_tx( + from_stage=0, + to_stage=1, + request_id="r-tx-3", + size_bytes=512, + tx_time_ms=2.0, + used_shm=False, + ) + assert emitter.calls == [] + + def test_record_transfer_tx_no_emit_when_one_side_resolves_to_none(self): + emitter = _StubTransferEmitter() + agg = OrchestratorAggregator( + num_stages=3, + log_stats=False, + wall_start_ts=0.0, + final_stage_id_for_e2e=2, + transfer_emitter=emitter, + replica_resolver=lambda s, rid: 0 if s == 0 else None, # to-side fails + ) + agg.record_transfer_tx( + from_stage=0, + to_stage=1, + request_id="r-tx-4", + size_bytes=64, + tx_time_ms=0.5, + used_shm=False, + ) + assert emitter.calls == [] + + +class TestEmitHookRx: + def test_record_transfer_rx_emits_decode_and_in_flight(self): + emitter = _StubTransferEmitter() + agg = OrchestratorAggregator( + num_stages=3, + log_stats=False, + wall_start_ts=0.0, + final_stage_id_for_e2e=2, + transfer_emitter=emitter, + replica_resolver=lambda s, rid: {0: 1, 1: 0}.get(s), + ) + stats = _make_stats( + stage_id=1, request_id="r-rx-1", rx_decode=4.2, rx_in_flight=1.7 + ) + agg.record_transfer_rx(stats) + assert emitter.calls == [ + ("observe_rx_decode_time", 0, 1, 1, 0, 4.2), + ("observe_in_flight_time", 0, 1, 1, 0, 1.7), + ] + + def test_record_transfer_rx_skips_stage_zero(self): + emitter = _StubTransferEmitter() + agg = OrchestratorAggregator( + num_stages=3, + log_stats=False, + wall_start_ts=0.0, + final_stage_id_for_e2e=2, + transfer_emitter=emitter, + replica_resolver=lambda s, rid: 0, + ) + # stage_id=0 has no upstream, so record_transfer_rx returns early. + stats = _make_stats(stage_id=0, request_id="r-rx-2", rx_decode=4.2) + agg.record_transfer_rx(stats) + assert emitter.calls == [] + + def test_record_transfer_rx_no_emit_when_emitter_none(self): + agg = OrchestratorAggregator( + num_stages=3, + log_stats=False, + wall_start_ts=0.0, + final_stage_id_for_e2e=2, + ) + stats = _make_stats(stage_id=1, request_id="r-rx-3", rx_decode=1.0, rx_in_flight=0.5) + evt = agg.record_transfer_rx(stats) + # Accumulation still happens — only Prometheus emit is skipped. + assert evt is not None + assert evt.rx_decode_time_ms == 1.0 + assert evt.in_flight_time_ms == 0.5 diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index 21ce1d839c4..6eecb125d7e 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -158,6 +158,20 @@ def __init__(self, *args: Any, model: str = "", **kwargs: Any) -> None: renderer = renderer_from_config(vllm_config) self.io_processor = get_io_processor(vllm_config, renderer, io_processor_plugin) + def _resolve_transfer_replica(self, stage_id: int, request_id: str) -> int | None: + """Look up the sticky-routed replica for (stage_id, request_id). + + Used as the ``replica_resolver`` callback by ``OrchestratorAggregator`` + to label transfer_* metrics without plumbing replica ids through + ``TransferEdgeStats`` / ``StageRequestStats`` / connector adapters. + Returns None when stage_id is out of range or the request hasn't been + bound to a replica yet — the metric emit then defensive-skips. + """ + pools = getattr(self.engine, "stage_pools", None) + if pools is None or not (0 <= stage_id < len(pools)): + return None + return pools[stage_id].get_bound_replica_id(request_id) + def _get_comprehension_stage_index(self) -> int | None: fallback_idx: int | None = None for idx, stage_client in enumerate(self.engine.stage_clients): @@ -299,6 +313,8 @@ async def generate( self.log_stats, wall_start_ts, final_stage_id_for_e2e, + transfer_emitter=self.transfer_metrics, + replica_resolver=self._resolve_transfer_replica, ) req_state = ClientRequestState(request_id) req_state.metrics = metrics diff --git a/vllm_omni/entrypoints/omni_base.py b/vllm_omni/entrypoints/omni_base.py index b718547a14f..d5c9284c895 100644 --- a/vllm_omni/entrypoints/omni_base.py +++ b/vllm_omni/entrypoints/omni_base.py @@ -21,6 +21,7 @@ from vllm_omni.metrics.modality import OmniModalityMetrics, observe_modality_at_finalize from vllm_omni.metrics.prometheus import OmniPrometheusMetrics from vllm_omni.metrics.stats import OrchestratorAggregator as OrchestratorMetrics +from vllm_omni.metrics.transfer import OmniTransferMetrics from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific from vllm_omni.outputs import OmniRequestOutput @@ -189,6 +190,7 @@ def __init__( self.request_states: dict[str, ClientRequestState] = {} self.prom_metrics = OmniPrometheusMetrics(model_name=model) self.mod_metrics = OmniModalityMetrics(model_name=model) + self.transfer_metrics = OmniTransferMetrics(model_name=model) self.default_sampling_params_list = self.engine.default_sampling_params_list if not self.output_modalities: diff --git a/vllm_omni/metrics/definitions.py b/vllm_omni/metrics/definitions.py index f6f4f3feda4..e1e3c518a06 100644 --- a/vllm_omni/metrics/definitions.py +++ b/vllm_omni/metrics/definitions.py @@ -93,7 +93,11 @@ STAGE_LABELS = ("model_name", "stage", "replica") # Cross-stage transfer label set (G3). Field names match TransferEdgeStats. -TRANSFER_LABELS = ("from_stage", "from_replica", "to_stage", "to_replica") +# model_name is included (deviating from RFC §3.2.6 which lists only the four +# stage/replica labels) so transfer aligns with the rest of the omni_* family +# naming and PromQL joins on model_name work uniformly across audio/image/ +# video/transfer. +TRANSFER_LABELS = ("model_name", "from_stage", "from_replica", "to_stage", "to_replica") # ============================================================================ diff --git a/vllm_omni/metrics/stats.py b/vllm_omni/metrics/stats.py index 4245deb5453..6131c909aa8 100644 --- a/vllm_omni/metrics/stats.py +++ b/vllm_omni/metrics/stats.py @@ -5,12 +5,15 @@ from collections.abc import Callable from contextlib import contextmanager from dataclasses import dataclass -from typing import Any +from typing import TYPE_CHECKING, Any from vllm.logger import init_logger from vllm_omni.metrics.utils import _build_field_defs, _build_row, _format_table +if TYPE_CHECKING: + from vllm_omni.metrics.transfer import OmniTransferMetrics + logger = init_logger(__name__) @@ -121,6 +124,9 @@ def __init__( log_stats: bool, wall_start_ts: float, final_stage_id_for_e2e: dict[str, int] | int, + *, + transfer_emitter: OmniTransferMetrics | None = None, + replica_resolver: Callable[[int, str], int | None] | None = None, ) -> None: self.num_stages = int(num_stages) self.log_stats = bool(log_stats) @@ -131,6 +137,12 @@ def __init__( tuple[int, int, str], TransferEdgeStats ] = {} # Key: (from_stage, to_stage, request_id) self.e2e_events: list[RequestE2EStats] = [] + # Phase 4 G3: emit per-physical-transfer Histogram observations to + # Prometheus alongside the existing TransferEdgeStats accumulation. + # Both deps are optional so OrchestratorAggregator stays usable in + # contexts that don't have a Prometheus registry (e.g. unit tests). + self._transfer_emitter = transfer_emitter + self._replica_resolver = replica_resolver def init_run_state(self, wall_start_ts: float) -> None: # Per-run aggregates and timing state @@ -191,6 +203,14 @@ def record_transfer_tx( evt.size_bytes += int(size_bytes) evt.tx_time_ms += float(tx_time_ms) evt.used_shm = evt.used_shm or bool(used_shm) + # Phase 4 G3: emit per-physical-transfer Histogram observations. + self._emit_transfer_tx( + from_stage=int(from_stage), + to_stage=int(to_stage), + request_id=str(request_id), + size_bytes=int(size_bytes), + tx_time_ms=float(tx_time_ms), + ) return evt except Exception: return None @@ -212,10 +232,77 @@ def record_transfer_rx( evt.size_bytes = int(stats.rx_transfer_bytes) evt.rx_decode_time_ms += float(stats.rx_decode_time_ms) evt.in_flight_time_ms += float(stats.rx_in_flight_time_ms) + # Phase 4 G3: emit per-physical-receive Histogram observations. + self._emit_transfer_rx( + from_stage=from_stage, + to_stage=to_stage, + request_id=rid_key, + rx_decode_time_ms=float(stats.rx_decode_time_ms), + in_flight_time_ms=float(stats.rx_in_flight_time_ms), + ) return evt except Exception: return None + # ------------------------------------------------------------------ + # Prometheus emit hooks (Phase 4 G3). Both helpers are no-ops when + # transfer_emitter or replica_resolver is None, or when the resolver + # cannot find a (stage_id, request_id) -> replica_id mapping. We + # deliberately fail-safe (skip) rather than emit a series with a wrong + # or invented replica label. + # ------------------------------------------------------------------ + + def _resolve_edge_replicas( + self, from_stage: int, to_stage: int, request_id: str + ) -> tuple[int, int] | None: + if self._replica_resolver is None: + return None + from_r = self._replica_resolver(from_stage, request_id) + to_r = self._replica_resolver(to_stage, request_id) + if from_r is None or to_r is None: + return None + return from_r, to_r + + def _emit_transfer_tx( + self, + *, + from_stage: int, + to_stage: int, + request_id: str, + size_bytes: int, + tx_time_ms: float, + ) -> None: + if self._transfer_emitter is None: + return + replicas = self._resolve_edge_replicas(from_stage, to_stage, request_id) + if replicas is None: + return + from_r, to_r = replicas + self._transfer_emitter.observe_size(from_stage, from_r, to_stage, to_r, size_bytes) + self._transfer_emitter.observe_tx_time(from_stage, from_r, to_stage, to_r, tx_time_ms) + + def _emit_transfer_rx( + self, + *, + from_stage: int, + to_stage: int, + request_id: str, + rx_decode_time_ms: float, + in_flight_time_ms: float, + ) -> None: + if self._transfer_emitter is None: + return + replicas = self._resolve_edge_replicas(from_stage, to_stage, request_id) + if replicas is None: + return + from_r, to_r = replicas + self._transfer_emitter.observe_rx_decode_time( + from_stage, from_r, to_stage, to_r, rx_decode_time_ms + ) + self._transfer_emitter.observe_in_flight_time( + from_stage, from_r, to_stage, to_r, in_flight_time_ms + ) + def record_audio_generated_frames( self, output_to_yield: Any, diff --git a/vllm_omni/metrics/transfer.py b/vllm_omni/metrics/transfer.py new file mode 100644 index 00000000000..749d493430e --- /dev/null +++ b/vllm_omni/metrics/transfer.py @@ -0,0 +1,140 @@ +"""OmniTransferMetrics — cross-stage transfer Prometheus families (RFC G3). + +Four Histogram families with ``{model_name, from_stage, from_replica, +to_stage, to_replica}`` labels. Each ``.observe_*()`` call corresponds to one +physical transfer event (one chunk hop from a sender replica to a receiver +replica), not the per-request accumulated total — so the Histogram tracks +the distribution of physical transfers, not request-aggregated sums. + +Data source: ``vllm_omni.metrics.stats.TransferEdgeStats`` accumulators in +``OrchestratorAggregator.record_transfer_tx`` / ``record_transfer_rx``. The +emit hook lives in stats.py; this module only registers the families and +exposes the typed observe API. + +Note on label deviation from RFC §3.2.6: the RFC lists only the four +stage/replica labels. We add ``model_name`` so transfer aligns with the rest +of the omni_* family naming and PromQL can join on model_name uniformly. +""" + +from __future__ import annotations + +from prometheus_client import Histogram + +from vllm_omni.metrics import definitions as defs + +_labelnames = list(defs.TRANSFER_LABELS) + + +# ---------------------------------------------------------------------------- +# TX-side families (observed when record_transfer_tx fires) +# ---------------------------------------------------------------------------- +_transfer_size_bytes_family = Histogram( + defs.TRANSFER_SIZE_BYTES, + "Per-transfer payload size in bytes (one observation per physical hop).", + labelnames=_labelnames, + buckets=defs.BYTES_BUCKETS, +) +_transfer_tx_time_ms_family = Histogram( + defs.TRANSFER_TX_TIME_MS, + "Sender-side time in milliseconds (serialize + submit to connector).", + labelnames=_labelnames, + buckets=defs.MS_BUCKETS, +) + + +# ---------------------------------------------------------------------------- +# RX-side families (observed when record_transfer_rx fires) +# ---------------------------------------------------------------------------- +_transfer_rx_decode_time_ms_family = Histogram( + defs.TRANSFER_RX_DECODE_TIME_MS, + "Receiver-side time in milliseconds (recv + deserialize).", + labelnames=_labelnames, + buckets=defs.MS_BUCKETS, +) +_transfer_in_flight_time_ms_family = Histogram( + defs.TRANSFER_IN_FLIGHT_TIME_MS, + "Network in-flight time in milliseconds (TX done -> RX recv start).", + labelnames=_labelnames, + buckets=defs.MS_BUCKETS, +) + + +class OmniTransferMetrics: + """Per-(from, to) replica observe API for cross-stage transfers. + + A single instance per pipeline; ``model_name`` is bound at init and + every observe call carries it in the label set. Stage/replica are + passed at observe time because the same instance serves all + (from_stage, from_replica) -> (to_stage, to_replica) edges. + """ + + def __init__(self, model_name: str) -> None: + self._model_name = model_name + + # ---- TX side (record_transfer_tx hook) ------------------------------- + + def observe_size( + self, + from_stage: int, + from_replica: int, + to_stage: int, + to_replica: int, + size_bytes: int, + ) -> None: + _transfer_size_bytes_family.labels( + model_name=self._model_name, + from_stage=str(from_stage), + from_replica=str(from_replica), + to_stage=str(to_stage), + to_replica=str(to_replica), + ).observe(size_bytes) + + def observe_tx_time( + self, + from_stage: int, + from_replica: int, + to_stage: int, + to_replica: int, + tx_time_ms: float, + ) -> None: + _transfer_tx_time_ms_family.labels( + model_name=self._model_name, + from_stage=str(from_stage), + from_replica=str(from_replica), + to_stage=str(to_stage), + to_replica=str(to_replica), + ).observe(tx_time_ms) + + # ---- RX side (record_transfer_rx hook) ------------------------------- + + def observe_rx_decode_time( + self, + from_stage: int, + from_replica: int, + to_stage: int, + to_replica: int, + rx_decode_time_ms: float, + ) -> None: + _transfer_rx_decode_time_ms_family.labels( + model_name=self._model_name, + from_stage=str(from_stage), + from_replica=str(from_replica), + to_stage=str(to_stage), + to_replica=str(to_replica), + ).observe(rx_decode_time_ms) + + def observe_in_flight_time( + self, + from_stage: int, + from_replica: int, + to_stage: int, + to_replica: int, + in_flight_time_ms: float, + ) -> None: + _transfer_in_flight_time_ms_family.labels( + model_name=self._model_name, + from_stage=str(from_stage), + from_replica=str(from_replica), + to_stage=str(to_stage), + to_replica=str(to_replica), + ).observe(in_flight_time_ms) From e9fa635e527e8bf7d0b9c4b160f6ed49248bd448 Mon Sep 17 00:00:00 2001 From: LHXuuu Date: Wed, 13 May 2026 18:45:09 +0800 Subject: [PATCH 11/13] [Metrics] Replace num_requests_success/fail with requests_success_total{finished_reason} MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 5 of multi-modal observability (RFC G6). Collapse the two PR-#3362-era pipeline-level counters vllm:omni_num_requests_success vllm:omni_num_requests_fail into a single per-reason Counter: vllm:omni_requests_success_total{model_name, finished_reason} with finished_reason ∈ {stop, length, abort, ...} mirroring upstream vllm:request_success_total. The previous "fail" path now maps to finished_reason="abort" so the Pipeline-level success-rate dashboard becomes one query (sum by (finished_reason) ...) instead of needing two metrics joined. Breaking change: dashboards / alerts that reference the old num_requests_success / num_requests_fail counter names must migrate to the new requests_success_total{finished_reason} family. Four edits: 1. definitions.py — drop NUM_REQUESTS_SUCCESS, NUM_REQUESTS_FAIL. Add REQUESTS_SUCCESS (passed without _total since Counter auto-suffixes at exposition). 2. prometheus.py — replace _success_family + _fail_family with a single _completion_family using SUCCESS_LABELS = ("model_name", "finished_reason"). The label is bound per-call rather than at __init__ time. - OmniPrometheusMetrics.request_succeeded() gains a finished_reason="stop" default kwarg. - OmniPrometheusMetrics.request_failed() keeps its no-arg signature for back-compat with the cleanup call site, internally mapping to finished_reason="abort". 3. omni_base.py — at the request-finalize hook, extract finish_reason from engine_outputs.outputs[0].finish_reason (vLLM CompletionOutput convention) and pass it through to request_succeeded(). Falls back to "stop" when no completion output is present (defensive — e.g. diffusion stages). 4. tests/metrics/test_prometheus.py — refresh _PIPELINE_METRICS, exercise three reason buckets (stop x 2 + length x 1 + abort x 1) in the scrape fixture, replace the old success/fail-count assertions with per-reason ones, and adjust the histogram-count assertions (e2e/queue go from 2 to 3 since the abort path now only inc's the counter and doesn't observe latency). --- tests/metrics/test_prometheus.py | 38 ++++++++++++++++++++---------- vllm_omni/entrypoints/omni_base.py | 12 +++++++++- vllm_omni/metrics/definitions.py | 10 ++++---- vllm_omni/metrics/prometheus.py | 36 +++++++++++++++++----------- 4 files changed, 64 insertions(+), 32 deletions(-) diff --git a/tests/metrics/test_prometheus.py b/tests/metrics/test_prometheus.py index 94b50aaeea6..adfcaa7f687 100644 --- a/tests/metrics/test_prometheus.py +++ b/tests/metrics/test_prometheus.py @@ -14,8 +14,7 @@ _PIPELINE_METRICS = [ "vllm:omni_num_requests_running", "vllm:omni_num_requests_waiting", - "vllm:omni_num_requests_success", - "vllm:omni_num_requests_fail", + "vllm:omni_requests_success", "vllm:omni_e2e_request_latency_seconds", "vllm:omni_request_queue_time_seconds", ] @@ -40,9 +39,12 @@ def prom() -> OmniPrometheusMetrics: @pytest.fixture(scope="module") def scrape_output(prom: OmniPrometheusMetrics, registry: CollectorRegistry) -> str: - prom.request_succeeded(e2e_seconds=1.5, queue_seconds=0.3) - prom.request_succeeded(e2e_seconds=2.0, queue_seconds=0.5) - prom.request_failed() + # Two natural completions (stop) + one length-cap + one failure (abort) + # exercise three distinct finished_reason buckets in the merged Counter. + prom.request_succeeded(e2e_seconds=1.5, queue_seconds=0.3, finished_reason="stop") + prom.request_succeeded(e2e_seconds=2.0, queue_seconds=0.5, finished_reason="stop") + prom.request_succeeded(e2e_seconds=3.0, queue_seconds=0.4, finished_reason="length") + prom.request_failed() # → finished_reason="abort" prom.set_running(5) prom.set_waiting(2) prom.observe_diffusion_metrics( @@ -70,17 +72,24 @@ def test_all_metric_families_present(self, scrape_output: str) -> None: assert f"# HELP {name}" in scrape_output, f"missing metric family: {name}" def test_counter_values(self, scrape_output: str) -> None: - success = _sample_value( + # Per-reason buckets sourced from the merged completion Counter (G6). + stop = _sample_value( scrape_output, - f'vllm:omni_num_requests_success_total{{model_name="{_MODEL}"}}', + f'vllm:omni_requests_success_total{{finished_reason="stop",model_name="{_MODEL}"}}', ) - assert success == 2.0 + assert stop == 2.0 - fail = _sample_value( + length = _sample_value( scrape_output, - f'vllm:omni_num_requests_fail_total{{model_name="{_MODEL}"}}', + f'vllm:omni_requests_success_total{{finished_reason="length",model_name="{_MODEL}"}}', ) - assert fail == 1.0 + assert length == 1.0 + + abort = _sample_value( + scrape_output, + f'vllm:omni_requests_success_total{{finished_reason="abort",model_name="{_MODEL}"}}', + ) + assert abort == 1.0 def test_gauge_values(self, scrape_output: str) -> None: running = _sample_value( @@ -96,17 +105,20 @@ def test_gauge_values(self, scrape_output: str) -> None: assert waiting == 2.0 def test_histogram_counts(self, scrape_output: str) -> None: + # 3 successful completions (stop x2 + length x1) all observe e2e/queue; + # the 1 failed completion only increments the Counter without observing + # the latency histogram, so the count stays at 3. e2e_count = _sample_value( scrape_output, f'vllm:omni_e2e_request_latency_seconds_count{{model_name="{_MODEL}"}}', ) - assert e2e_count == 2.0 + assert e2e_count == 3.0 queue_count = _sample_value( scrape_output, f'vllm:omni_request_queue_time_seconds_count{{model_name="{_MODEL}"}}', ) - assert queue_count == 2.0 + assert queue_count == 3.0 def test_diffusion_histogram_counts(self, scrape_output: str) -> None: for name in _DIFFUSION_METRICS: diff --git a/vllm_omni/entrypoints/omni_base.py b/vllm_omni/entrypoints/omni_base.py index d5c9284c895..dbadda51ea1 100644 --- a/vllm_omni/entrypoints/omni_base.py +++ b/vllm_omni/entrypoints/omni_base.py @@ -462,7 +462,17 @@ def _process_single_result( _pt = getattr(_fin_m, "pipeline_timings", None) or {} queue_ms = _pt.get("queue_wait_ms") queue_seconds = queue_ms / 1000.0 if queue_ms is not None else None - self.prom_metrics.request_succeeded(e2e_seconds, queue_seconds=queue_seconds) + # G6: extract finished_reason from upstream CompletionOutput + # so the per-reason completion Counter is labelled correctly. + completion_outputs = getattr(engine_outputs, "outputs", None) or [] + fr = ( + getattr(completion_outputs[0], "finish_reason", None) + if completion_outputs + else None + ) or "stop" + self.prom_metrics.request_succeeded( + e2e_seconds, queue_seconds=queue_seconds, finished_reason=fr, + ) # Modality observe (Phase 3.2). Inside the same finalize guard so # it fires once per request and inherits the try/except isolation. diff --git a/vllm_omni/metrics/definitions.py b/vllm_omni/metrics/definitions.py index e1e3c518a06..4a485043d68 100644 --- a/vllm_omni/metrics/definitions.py +++ b/vllm_omni/metrics/definitions.py @@ -33,13 +33,15 @@ # ============================================================================ NUM_REQUESTS_RUNNING = METRIC_PREFIX + "num_requests_running" NUM_REQUESTS_WAITING = METRIC_PREFIX + "num_requests_waiting" -NUM_REQUESTS_SUCCESS = METRIC_PREFIX + "num_requests_success" -NUM_REQUESTS_FAIL = METRIC_PREFIX + "num_requests_fail" E2E_REQUEST_LATENCY_SECONDS = METRIC_PREFIX + "e2e_request_latency_seconds" REQUEST_QUEUE_TIME_SECONDS = METRIC_PREFIX + "request_queue_time_seconds" -# G6: requests_success_total{finished_reason} — Pipeline 全局 Counter -REQUESTS_SUCCESS_TOTAL = METRIC_PREFIX + "requests_success_total" +# G6: per-finished_reason Counter that replaces the original +# num_requests_success / num_requests_fail pair from PR #3362. Single source +# of completion-state counting with finished_reason ∈ {stop, length, abort, ...} +# (aborts include the "fail" path that previously used a separate counter). +# Counter auto-suffixes _total at exposition time; pass without _total here. +REQUESTS_SUCCESS = METRIC_PREFIX + "requests_success" # ============================================================================ diff --git a/vllm_omni/metrics/prometheus.py b/vllm_omni/metrics/prometheus.py index b0dad9c80d3..6e471d3b4db 100644 --- a/vllm_omni/metrics/prometheus.py +++ b/vllm_omni/metrics/prometheus.py @@ -37,15 +37,12 @@ "Number of requests waiting to be scheduled.", labelnames=_labelnames, ) -_success_family = Counter( - defs.NUM_REQUESTS_SUCCESS, - "Number of requests that completed without error.", - labelnames=_labelnames, -) -_fail_family = Counter( - defs.NUM_REQUESTS_FAIL, - "Number of requests that returned an error.", - labelnames=_labelnames, +_completion_family = Counter( + defs.REQUESTS_SUCCESS, + "Total requests by completion reason " + "(stop / length / abort / ...). Aborts include the 'fail' path " + "that previously had its own num_requests_fail counter (G6).", + labelnames=list(defs.SUCCESS_LABELS), ) _e2e_latency_family = Histogram( defs.E2E_REQUEST_LATENCY_SECONDS, @@ -75,8 +72,6 @@ def __init__(self, model_name: str) -> None: self._model_name = model_name self._running = _running_family.labels(model_name=model_name) self._waiting = _waiting_family.labels(model_name=model_name) - self._success = _success_family.labels(model_name=model_name) - self._fail = _fail_family.labels(model_name=model_name) self._e2e_latency = _e2e_latency_family.labels(model_name=model_name) self._queue_time = _queue_time_family.labels(model_name=model_name) self._diffusion_by_stage: dict[tuple[str, int], Histogram] = {} @@ -87,14 +82,27 @@ def set_running(self, n: int) -> None: def set_waiting(self, n: int) -> None: self._waiting.set(n) - def request_succeeded(self, e2e_seconds: float, queue_seconds: float | None = None) -> None: - self._success.inc() + def request_succeeded( + self, + e2e_seconds: float, + queue_seconds: float | None = None, + finished_reason: str = "stop", + ) -> None: + _completion_family.labels( + model_name=self._model_name, + finished_reason=finished_reason, + ).inc() self._e2e_latency.observe(e2e_seconds) if queue_seconds is not None: self._queue_time.observe(queue_seconds) def request_failed(self) -> None: - self._fail.inc() + # Pipeline-level "fail" maps to the upstream FinishReason.ABORT bucket; + # a single counter family now covers both normal stops and aborts. + _completion_family.labels( + model_name=self._model_name, + finished_reason="abort", + ).inc() def observe_diffusion_metrics(self, stage_id: int, metrics: dict[str, float]) -> None: for key, parent in _diffusion_families.items(): From 56afe670b88a17a90bc0b97ce36e9b9cbb35aebb Mon Sep 17 00:00:00 2001 From: LHXuuu Date: Wed, 13 May 2026 18:54:51 +0800 Subject: [PATCH 12/13] [Docs] Update metrics docs to cover G1/G2/G3/G6/G7 work PR #3362 introduced docs/usage/metrics.md and docs/design/metrics.md covering its own pipeline-level + diffusion families. The follow-up work in this branch (G1 audio, G2 image/video, G3 cross-stage transfer, G6 success/fail merge into requests_success_total{finished_reason}, G7 OmniPrometheusStatLogger per-replica wrap) wasn't reflected. Refresh both docs to the final state. usage/metrics.md (+117 / -38): - Request Tracking table swaps num_requests_success / num_requests_fail for the merged requests_success_total{finished_reason} family. - New sections for Modality (audio / image / video) and Cross-Stage Transfer with the per-modality / per-edge metric tables. - vLLM Engine Metrics section gains a before/after example showing how the G7 wrap reshapes engine -> stage + replica, with a note that the ~37 upstream families gain per-replica visibility automatically. - Diffusion Engine Metrics section clarifies that omni-side diffusion families bypass the wrap (engine label = stage_id, not relabelled). - Pipeline Type availability matrix gains modality / transfer / per-replica rows. - Naming Convention section explains the RFC "co-position, different name" three-modality TTFP convention. design/metrics.md (+260 / -74): - Objectives section calls out the per-replica + modality + transfer goals introduced after PR #3362. - Architecture diagram replaces the original two-component picture with the four-component one (OmniPrometheusMetrics + OmniModalityMetrics + OmniTransferMetrics + OmniPrometheusStatLogger). - Data Flow grows from two paths to four: Pipeline-level, Modality (finalize hook + audio_ttfp streaming hook), Cross-stage transfer (sticky-routing replica_resolver), and the wrapped per-engine path. The transfer subsection notes the TX-side hook is wired but currently inactive pending try_send_via_connector being called from the main code path. - New "OmniPrometheusStatLogger Wrap (G7)" section explains the four upstream impedance points and the three coordinated mechanisms (class-level metric class slots, per_engine_labelvalues property descriptor, _RelabelMixin.labels() override) used to handle them. Also documents the stage_replica_map construction and the deferred dynamic add/remove decision. - Metric Definitions table refreshed end-to-end: pipeline-level row now lists requests_success_total{finished_reason}; new modality and transfer subtables added; transfer subtable carries a note that model_name was added on top of the RFC-listed four stage/replica labels for cross-omni consistency. - Logging vs. Prometheus section expanded to mention OmniModality / OmniTransfer alongside OmniPrometheusMetrics, and notes that OrchestratorAggregator.record_transfer_tx/rx fans out to both the log accumulator and the Prometheus emit hook in one method body. --- docs/design/metrics.md | 334 ++++++++++++++++++++++++++++++++--------- docs/usage/metrics.md | 117 ++++++++++++--- 2 files changed, 363 insertions(+), 88 deletions(-) diff --git a/docs/design/metrics.md b/docs/design/metrics.md index dcf8b2c04d8..8f43223f1d4 100644 --- a/docs/design/metrics.md +++ b/docs/design/metrics.md @@ -10,9 +10,18 @@ metrics. - Expose pipeline-level request and latency metrics that span the full multi-stage execution (orchestrator scope). - Preserve all upstream vLLM per-engine metrics (`vllm:*`) for stages - backed by an AR LLM engine. + backed by an AR LLM engine, and reshape their `engine` label into + `stage` + `replica` so multi-replica deployments gain per-replica + visibility automatically. - Expose per-stage diffusion timing breakdowns for pipelines that include a diffusion engine. +- Expose per-modality SLO metrics (audio TTFP / RTF / duration / frames, + image TTFP / generation time / num, video generation time) that the + upstream `vllm:*` families do not capture (e.g. `audio_ttfp` is the + first audio packet, distinct from upstream's first audio token). +- Expose per-replica-edge cross-stage transfer metrics so the slack + between E2E latency and the sum of per-stage `gen_time` (queueing, + serialization, network) becomes attributable. - Keep the metrics collection overhead low enough that it does not regress TTFA or throughput. @@ -20,13 +29,13 @@ metrics. ### Upstream vLLM Metrics -Upstream vLLM defines 44 Prometheus metrics under the `vllm:` prefix. -These are registered by `PrometheusStatLogger` and cover engine-level -state: KV cache usage, running/waiting request counts, token -throughput, TTFT, inter-token latency, e2e latency, and so on. They -are served via the `/metrics` HTTP endpoint provided by -`prometheus_fastapi_instrumentator` and the default -`prometheus_client` WSGI handler. +Upstream vLLM defines ~37 Prometheus metric families under the `vllm:` +prefix. These are registered by `PrometheusStatLogger` and cover +engine-level state: KV cache usage, running/waiting request counts, +token throughput, TTFT, inter-token latency, e2e latency, and so on. +They are served via the `/metrics` HTTP endpoint provided by +`prometheus_fastapi_instrumentator` and the default `prometheus_client` +WSGI handler. vLLM's `unregister_vllm_metrics()` function strips every `prometheus_client` collector whose `_name` attribute contains the @@ -35,51 +44,60 @@ stale collectors from prior instantiations within the same process. ### The Problem -vLLM-Omni runs multiple engine instances (stages) within a single -process, coordinated by an Orchestrator. The pipeline needs its own -metrics — aggregate request counts, end-to-end latency across all -stages, and diffusion timing breakdowns — that do not exist in upstream -vLLM. All pipeline-level metrics use the `vllm:omni_` prefix to -distinguish them from upstream per-engine metrics. The -`unregister_vllm_metrics()` function is monkey-patched to a no-op at -import time (see `vllm_omni/patch.py`) so that these metrics are not -destroyed during engine initialization. - -Upstream per-engine metrics retain the `vllm:` prefix and are -registered by a `PrometheusStatLogger` instance that the Orchestrator -creates and feeds directly. +vLLM-Omni runs multiple engine instances (stages × replicas) within a +single process, coordinated by an Orchestrator. The pipeline needs its +own metrics — aggregate request counts, end-to-end latency across all +stages, diffusion timing breakdowns, per-modality SLO signals, and +cross-stage transfer attribution — that do not exist in upstream vLLM. +All pipeline-level metrics use the `vllm:omni_` prefix to distinguish +them from upstream per-engine metrics. The `unregister_vllm_metrics()` +function is monkey-patched to a no-op at import time (see +`vllm_omni/patch.py`) so that these metrics are not destroyed during +engine initialization. + +Upstream per-engine metrics retain the `vllm:` prefix but are now +registered by `OmniPrometheusStatLogger`, a thin subclass of upstream's +`PrometheusStatLogger` that reshapes the single `engine` label into a +`stage` + `replica` pair (see "OmniPrometheusStatLogger wrap" below). ## Architecture ### Component Overview -``` - +-----------------------+ - | API Server (FastAPI)| +```text + +------------------------+ + | API Server (FastAPI) | | GET /metrics | - +----------+------------+ - | - prometheus_client default handler - | - +-------------+-------------+ - | | - vllm:omni_* collectors vllm:* collectors - | | - +-----------+-----------+ +--------+---------+ - | OmniPrometheusMetrics | | PrometheusStatLogger | - +-----------+-----------+ +--------+---------+ - | | - OmniBase Orchestrator - (request lifecycle, (feeds SchedulerStats - diffusion timing) + IterationStats - per engine step) + +-----------+------------+ + | + prometheus_client default registry + | + +--------+--------+--------+--------+--------+ + | | + vllm:omni_* vllm:* + collectors collectors + | | + +----+--------+ +-----------+ +----------+ +-----------+ + | OmniPromet- | | OmniMod- | | OmniTra- | | OmniProm- | + | heusMetrics | | alityMet- | | nsferMe- | | etheusSt- | + | (PR#3362) | | rics (G1+ | | trics | | atLogger | + | | | G2) | | (G3) | | (G7 wrap) | + +----+--------+ +-----+-----+ +----+-----+ +----+------+ + | | | | + OmniBase OmniBase Orchestrator Orchestrator + (request life- (finalize + (record_trans- (per-(stage, + cycle, success/ streaming fer_tx/rx replica) + fail counter, hooks via hooks via scheduler/ + diffusion observe_* emit hook in iteration + timing) APIs) OrchestratorAg- stats) + gregator) ``` ### Data Flow -There are two independent paths for metric collection. +There are four independent paths for metric collection. -**Path 1: Pipeline-level metrics (`vllm:omni_*`)** +**Path 1: Pipeline-level metrics (`vllm:omni_*`, PR #3362 + G6)** `OmniPrometheusMetrics` registers Gauge, Counter, and Histogram collectors at init time. It is instantiated once per entrypoint, @@ -91,25 +109,98 @@ requests progress: simple counter incremented/decremented by the Orchestrator as it tracks requests. Waiting is derived as `total - running`. -- `request_succeeded(e2e_seconds, queue_seconds)` — recorded when a - request finishes at the final stage. +- `request_succeeded(e2e_seconds, queue_seconds=None, + finished_reason="stop")` — recorded when a request finishes at the + final stage. `finished_reason` is extracted from + `engine_outputs.outputs[0].finish_reason` (vLLM `CompletionOutput` + convention) and increments + `vllm:omni_requests_success_total{finished_reason}`. -- `request_failed()` — recorded when a request errors. +- `request_failed()` — recorded by the cleanup path when a request + exits without natural completion. Internally maps to + `finished_reason="abort"` so a single Counter family covers both + natural and aborted completion (G6). - `observe_diffusion_metrics(stage_id, metrics)` — recorded when a diffusion stage finishes. The metrics dict contains timing breakdowns (preprocess, exec, postprocess, total step time) accumulated from engine output. -**Path 2: Per-engine metrics (`vllm:*`)** - -The Orchestrator instantiates upstream vLLM's `PrometheusStatLogger` +**Path 2: Modality metrics (`vllm:omni_audio_* / image_* / video_*`, G1 + G2)** + +`OmniModalityMetrics` registers eight per-modality Histogram + Counter +families with `{model_name, stage, replica}` labels. Two observation +sites: + +- `observe_modality_at_finalize(...)` — called from + `omni_base._process_single_result` inside the existing `e2e_done` + finalize guard. Routes by `final_output_type`: + - `audio`: emits `audio_frames_total` (Counter), `audio_duration_seconds`, + `audio_rtf` (Histograms). Sample rate is resolved from + `engine_outputs.multimodal_output["audio_sample_rate"]` via + `definitions.resolve_audio_sample_rate(...)` (fallback chain mirrors + `serving_chat.py`). + - `image`: emits `image_num_total`, `image_generation_time_seconds`, + `image_ttfp_seconds`. (`image_ttfp` is observed at finalize because + the diffusion path has no intermediate image streaming — first + image equals final image.) + - `video`: emits `video_generation_time_seconds`. Note that + `video_duration_seconds` and `video_rtf` are deferred — diffusion + video pipelines (i2v / t2v / cogvideo / hunyuan / wan) expose + `num_frames` + `fps` in heterogeneous shapes and a clean abstraction + is out of scope for this iteration. + +- `observe_audio_first_packet(...)` — called from the OpenAI streaming + paths (`serving_chat.py` HTTP-SSE audio branch and + `serving_video_stream.py` WebSocket audio branch) on the first audio + packet emerging for a request. The once-per-request guard is held by + `ClientRequestState.first_audio_ts` (set on first emit). The + `request_arrival_ts` anchor is also stored in `ClientRequestState` + by `async_omni.generate()`, computed as the wall-clock time at + request entry. + +**Path 3: Cross-stage transfer metrics (`vllm:omni_transfer_*`, G3)** + +`OmniTransferMetrics` registers four Histogram families with +`{model_name, from_stage, from_replica, to_stage, to_replica}` labels. +Each observation corresponds to one physical transfer hop (one chunk +between adjacent stages), not the per-request accumulated total — so +the histograms track per-transfer distribution. + +The hook lives in `OrchestratorAggregator.record_transfer_tx` and +`record_transfer_rx`. After the existing `TransferEdgeStats` +accumulation, the aggregator calls `_emit_transfer_tx` / +`_emit_transfer_rx` which look up `from_replica` / `to_replica` via a +`replica_resolver` callback supplied by `async_omni.py`. The resolver +delegates to `stage_pool.get_bound_replica_id(request_id)` — +i.e. the orchestrator's existing sticky-routing binding (PR #2396) is +the source of truth for the per-edge replica labels. No plumbing +through `TransferEdgeStats`, `StageRequestStats`, or the connector +adapter is needed. + +Defensive fail-safe: if `transfer_emitter` or `replica_resolver` is +missing, or the resolver returns `None` for either side, the emit is +skipped silently (the underlying `TransferEdgeStats` accumulation is +unaffected). + +> The TX-side hook (`record_transfer_tx`) is wired up but only fires +> once `try_send_via_connector` is invoked from the main code path; +> until then only the RX-side families (`rx_decode_time_ms` + +> `in_flight_time_ms`) accumulate observations. + +**Path 4: Per-engine metrics (`vllm:*`, G7 wrap)** + +The Orchestrator instantiates `OmniPrometheusStatLogger` (a thin +subclass of upstream `vllm.v1.metrics.loggers.PrometheusStatLogger`) and feeds it scheduler stats and iteration stats after processing -each batch of engine outputs. This populates the standard vLLM -metrics (TTFT, token throughput, cache usage, etc.) using the same -code path as standalone vLLM. For diffusion-only pipelines that have -no AR engine, `SchedulerStats` is never produced and `vllm:*` metrics -are absent. +each batch of engine outputs. This populates the standard ~37 vLLM +metric families (TTFT, ITL, TPOT, KV cache usage, etc.) using the same +upstream code path — but with the `engine` label reshaped into +`stage` + `replica` so multi-replica deployments produce distinct +series per replica. See the next section for the wrap mechanics. + +For diffusion-only pipelines that have no AR engine, +`SchedulerStats` is never produced and `vllm:*` metrics are absent. ### Shared State Between Threads @@ -124,18 +215,80 @@ passed to the Orchestrator at construction time. ### Metric Registration and Lifecycle -All `vllm:omni_*` collectors are registered once when -`OmniPrometheusMetrics.__init__()` runs. Per-stage labels -(`model_name`, `engine`) are bound lazily on first observation to -avoid registering labels for stages that never produce data (e.g., a -diffusion pipeline has no AR stage stats). +All `vllm:omni_*` collectors are registered once when their owning +class (`OmniPrometheusMetrics` / `OmniModalityMetrics` / +`OmniTransferMetrics`) is imported. Per-`(stage, replica)` labels are +bound lazily on first observation to avoid registering label sets for +combinations that never produce data (e.g. a diffusion pipeline has +no audio metrics). The `prometheus_client` default registry holds all collectors. -FastAPI's `/metrics` endpoint serves the default registry, so both -`vllm:omni_*` and `vllm:*` metrics appear in the same scrape -response alongside `http_*` and `process_*` metrics from the +FastAPI's `/metrics` endpoint serves the default registry, so +`vllm:omni_*` and the wrapped `vllm:*` metrics appear in the same +scrape response alongside `http_*` and `process_*` metrics from the instrumentator and the Python client runtime. +## OmniPrometheusStatLogger Wrap (G7) + +Upstream `PrometheusStatLogger.__init__` hard-codes +`labelnames = ["model_name", "engine"]` as a local variable, references +it across ~37 metric-family construction sites, and uses the `engine` +label value in five different `.labels()` call shapes (kwarg with int +engine, kwarg with str engine, positional with str engine in the +middle, plus a `metrics_info["engine"] = str(...)` dict pattern). To +reshape `engine` into `stage` + `replica` without forking the entire +upstream `__init__`, the wrap uses three coordinated mechanisms: + +1. **Class-level metric class slot overrides.** + `OmniPrometheusStatLogger` overrides `_gauge_cls`, `_counter_cls`, + `_histogram_cls` (which upstream calls via `self._gauge_cls(...)` + etc.) with `_RelabelGauge` / `_RelabelCounter` / `_RelabelHistogram` + wrapper classes. These intercept the `labelnames` kwarg at metric + family creation time and replace `engine` with `("stage", "replica")`. + +2. **Property descriptor for `per_engine_labelvalues`.** Upstream + builds `self.per_engine_labelvalues = {idx: [model_name, str(idx)]}` + inside `__init__` and then captures it into a local variable for + `create_metric_per_engine` calls. By making + `per_engine_labelvalues` a Python property on the subclass, the + setter intercepts upstream's assignment and rewrites each 2-tuple + into a 3-tuple `[model_name, stage, replica]` using the + `stage_replica_map` supplied at construction time. The captured + local then sees the rewritten dict. + +3. **Override of `.labels()` on the wrapper classes.** For the five + call sites that pass `engine` directly (kwarg or positional, int or + str), `_RelabelMixin.labels()` translates the engine value back to + `(stage, replica)` via a process-level `_ENGINE_INDEX_MAP` populated + by `OmniPrometheusStatLogger.__init__`. This handles + `gauge_engine_sleep_state.labels(engine=idx, ...)`, + `counter_request_success_base.labels(model_name, str(idx), + str(reason))`, `info_gauge.labels(**metrics_info)`, etc. + +The `Orchestrator` constructs `stage_replica_map` from the static +`stage_pools` configuration at startup: + +```python +stage_replica_map = { + flat_idx: (str(stage_id), str(replica_id)) + for flat_idx, (stage_id, replica_id) in enumerate( + (s, r) + for s, pool in enumerate(stage_pools) + for r in range(pool.num_replicas) + ) +} +``` + +A reverse map `(stage_id, replica_id) -> flat_idx` is maintained on +the Orchestrator so the per-replica `record(engine_idx=...)` call site +can look up the right flat index. + +> Dynamic add/remove of replicas at runtime is intentionally out of +> scope — the upstream `PrometheusStatLogger` materializes +> per-engine_idx child metrics at init time, and supporting hot-add +> would require non-trivial intervention into upstream's per-family +> child dictionaries. + ## Throttling: `make_stats()` Override Upstream vLLM's `Scheduler.make_stats()` runs on every AR generation step, @@ -159,11 +312,40 @@ eliminating the per-step overhead. |--------|------|--------|-------------| | `vllm:omni_num_requests_running` | Gauge | `model_name` | Requests currently executing across all stages | | `vllm:omni_num_requests_waiting` | Gauge | `model_name` | Requests queued but not yet scheduled | -| `vllm:omni_num_requests_success` | Counter | `model_name` | Requests completed without error | -| `vllm:omni_num_requests_fail` | Counter | `model_name` | Requests that returned an error | +| `vllm:omni_requests_success_total` | Counter | `model_name`, `finished_reason` | Total requests by completion reason ({stop, length, abort, ...}); aborts include the previous "fail" path (G6) | | `vllm:omni_e2e_request_latency_seconds` | Histogram | `model_name` | End-to-end request latency across all stages | | `vllm:omni_request_queue_time_seconds` | Histogram | `model_name` | Time spent waiting in the request queue | +### Modality (G1 + G2) + +| Metric | Type | Labels | Description | +|--------|------|--------|-------------| +| `vllm:omni_audio_ttfp_seconds` | Histogram | `model_name`, `stage`, `replica` | Time from request arrival to first audio packet (streaming hook) | +| `vllm:omni_audio_duration_seconds` | Histogram | same | Audio content duration (`audio_frames / sample_rate`) | +| `vllm:omni_audio_rtf` | Histogram | same | Real-time factor `stage_gen_time_s / audio_duration_s` (RFC SLO `< 1`) | +| `vllm:omni_audio_frames_total` | Counter | same | Cumulative audio frames generated | +| `vllm:omni_image_ttfp_seconds` | Histogram | same | Time from request arrival to image emission | +| `vllm:omni_image_num_total` | Counter | same | Cumulative images generated | +| `vllm:omni_image_generation_time_seconds` | Histogram | same | Per-request image stage generation time | +| `vllm:omni_video_generation_time_seconds` | Histogram | same | Per-request video stage generation time | + +### Cross-Stage Transfer (G3) + +Labels: `{model_name, from_stage, from_replica, to_stage, to_replica}`. + +> `model_name` is included on the transfer family for consistency with +> the rest of the omni surface (audio_*, image_*, video_*, num_requests_*), +> even though RFC §3.2.6 originally listed only the four +> stage/replica labels. PromQL joins on `model_name` work uniformly +> across modality and transfer families. + +| Metric | Type | Description | +|--------|------|-------------| +| `vllm:omni_transfer_size_bytes` | Histogram | Per-transfer payload size in bytes | +| `vllm:omni_transfer_tx_time_ms` | Histogram | Sender-side time (serialize + submit to connector) | +| `vllm:omni_transfer_rx_decode_time_ms` | Histogram | Receiver-side time (recv + deserialize) | +| `vllm:omni_transfer_in_flight_time_ms` | Histogram | Network in-flight time (TX done → RX recv start) | + ### Diffusion Stage-Level | Metric | Type | Labels | Description | @@ -173,11 +355,19 @@ eliminating the per-step overhead. | `vllm:omni_diffusion_postprocess_time_ms` | Histogram | `model_name`, `engine` | Diffusion output postprocessing time | | `vllm:omni_diffusion_step_time_ms` | Histogram | `model_name`, `engine` | Total diffusion step time | -### LLM Stage-Level +> The diffusion families bypass the `OmniPrometheusStatLogger` wrap, so +> their `engine` label is the diffusion stage_id (not relabelled to +> `stage` + `replica`). -Reference [vLLM docs](https://github.com/vllm-project/vllm/blob/main/docs/usage/metrics.md) +### LLM Stage-Level (wrapped `vllm:*`) -Note that metrics that depend upon features that are not supported in vLLM-Omni (e.g. speculative decoding, LoRA) will not be available as well. +After the G7 wrap, every upstream `vllm:*` family — TTFT, ITL, TPOT, +e2e latency, KV cache usage, scheduler running/waiting, request +success counts, etc. — carries `{model_name, stage, replica}` labels. +For the full upstream catalog see +[the vLLM docs](https://github.com/vllm-project/vllm/blob/main/docs/usage/metrics.md); +note that metrics depending on features unsupported in vLLM-Omni +(e.g. speculative decoding, LoRA) will not be available. ## Logging vs. Prometheus @@ -187,10 +377,14 @@ per-stage, and per-transfer statistics and prints formatted tables to the `INFO` log. This is designed for development and debugging — individual request traces, transfer bandwidth, inter-stage timing. -`OmniPrometheusMetrics` is the Prometheus-oriented path. It records -aggregate counters, gauges, and histograms suitable for time-series -monitoring and alerting. The two paths are independent; both can run -simultaneously. +`OmniPrometheusMetrics` / `OmniModalityMetrics` / `OmniTransferMetrics` +form the Prometheus-oriented path. They record aggregate counters, +gauges, and histograms suitable for time-series monitoring and +alerting. Both paths share the same source data (`StageRequestStats`, +`TransferEdgeStats`) — `OrchestratorAggregator.record_transfer_tx/rx` +in particular calls both the existing accumulator code and the +Prometheus emit hook in the same method body. The two consumption +models can run simultaneously without coupling. The separation follows upstream vLLM's pattern of `LoggingStatLogger` vs. `PrometheusStatLogger` — same underlying data, different diff --git a/docs/usage/metrics.md b/docs/usage/metrics.md index 60e7193288b..25c0daac57d 100644 --- a/docs/usage/metrics.md +++ b/docs/usage/metrics.md @@ -13,8 +13,8 @@ curl http://localhost:8000/metrics | Prefix | Source | Present when | |--------|--------|--------------| -| `vllm:omni_` | vLLM-Omni orchestrato / diffusion stages | Always / Pipeline includes a diffusion stage | -| `vllm:` | Upstream vLLM engine | Pipeline includes an LLM (AR) stage | +| `vllm:omni_` | vLLM-Omni orchestrator / diffusion stages / modality / transfer | Always / pipeline-dependent | +| `vllm:` | Upstream vLLM engine, wrapped by `OmniPrometheusStatLogger` to expose `{stage, replica}` | Pipeline includes an LLM (AR) stage | | `http_` / `process_` | Uvicorn / Python runtime | Always | ## Pipeline-Level Metrics (`vllm:omni_`) @@ -28,8 +28,7 @@ request lifecycle across the full multi-stage pipeline. |--------|------|--------|-------------| | `vllm:omni_num_requests_running` | Gauge | `model_name` | Requests currently running across all pipeline stages | | `vllm:omni_num_requests_waiting` | Gauge | `model_name` | Requests waiting to be scheduled | -| `vllm:omni_num_requests_success` | Counter | `model_name` | Requests that completed without error | -| `vllm:omni_num_requests_fail` | Counter | `model_name` | Requests that returned an error | +| `vllm:omni_requests_success_total` | Counter | `model_name`, `finished_reason` | Total requests by completion reason. `finished_reason` ∈ {`stop`, `length`, `abort`, ...} mirroring upstream `vllm:request_success_total`; aborts include the previous "fail" path | ### Latency @@ -38,10 +37,67 @@ request lifecycle across the full multi-stage pipeline. | `vllm:omni_e2e_request_latency_seconds` | Histogram | `model_name` | End-to-end request latency in seconds | | `vllm:omni_request_queue_time_seconds` | Histogram | `model_name` | Time spent waiting in the request queue | +## Modality Metrics (`vllm:omni_`) + +Per-modality business-semantic histograms emitted at request finalize (or at +first-packet time for `audio_ttfp_seconds`). All carry +`{model_name, stage, replica}` labels. + +### Audio (talker stage) + +| Metric | Type | Description | +|--------|------|-------------| +| `vllm:omni_audio_ttfp_seconds` | Histogram | Time from request arrival to first audio packet (streaming hook) | +| `vllm:omni_audio_duration_seconds` | Histogram | Generated audio content duration (`audio_frames / sample_rate`) | +| `vllm:omni_audio_rtf` | Histogram | Real-time factor `stage_gen_time_s / audio_duration_s`; SLO red line `< 1` | +| `vllm:omni_audio_frames_total` | Counter | Cumulative audio frames generated; throughput via `rate()` | + +### Image (diffusion stage) + +| Metric | Type | Description | +|--------|------|-------------| +| `vllm:omni_image_ttfp_seconds` | Histogram | Time from request arrival to image emission (degenerates to `image_generation_time` when no intermediate streaming) | +| `vllm:omni_image_num_total` | Counter | Cumulative images generated | +| `vllm:omni_image_generation_time_seconds` | Histogram | Per-request image stage generation time (image has no RTF — no content duration) | + +### Video (diffusion stage) + +| Metric | Type | Description | +|--------|------|-------------| +| `vllm:omni_video_generation_time_seconds` | Histogram | Per-request video stage generation time | + +> `video_duration_seconds` and `video_rtf` are deferred — diffusion video +> pipelines (i2v / t2v / cogvideo / hunyuan / wan) expose `num_frames` + `fps` +> in heterogeneous shapes and a clean abstraction is out of scope for this +> iteration. + +## Cross-Stage Transfer Metrics (`vllm:omni_`) + +Per-physical-transfer histograms tracking the data hop between adjacent +stages. Labels `{model_name, from_stage, from_replica, to_stage, to_replica}` +let dashboards attribute latency to specific replica edges. `from_replica` / +`to_replica` are resolved from the orchestrator's sticky-routing binding +(`stage_pool.get_bound_replica_id(request_id)`), so no extra plumbing through +`TransferEdgeStats` is needed. + +| Metric | Type | Description | +|--------|------|-------------| +| `vllm:omni_transfer_size_bytes` | Histogram | Per-transfer payload size in bytes | +| `vllm:omni_transfer_tx_time_ms` | Histogram | Sender-side time (serialize + submit to connector) | +| `vllm:omni_transfer_rx_decode_time_ms` | Histogram | Receiver-side time (recv + deserialize) | +| `vllm:omni_transfer_in_flight_time_ms` | Histogram | Network in-flight time (TX done → RX recv start) | + +> The TX-side observe path (`record_transfer_tx`) is already wired but only +> fires once the connector adapter (`try_send_via_connector`) is invoked from +> the main code path; until then only the RX-side families +> (`rx_decode_time_ms` + `in_flight_time_ms`) are populated. + ## Diffusion Engine Metrics (`vllm:omni_`) These histograms are populated only when the pipeline includes a diffusion -stage (e.g. image or video generation models). +stage. The `engine` label here is the diffusion stage_id (omni-side +families bypass the `OmniPrometheusStatLogger` wrap, so they retain the +original `engine` label rather than being relabelled to `stage` + `replica`). | Metric | Type | Labels | Description | |--------|------|--------|-------------| @@ -53,27 +109,52 @@ stage (e.g. image or video generation models). ## vLLM Engine Metrics (`vllm:`) When the pipeline includes an LLM stage, the upstream vLLM engine exposes its -full set of metrics under the `vllm:` prefix. These are registered by -`vllm.v1.metrics.loggers.PrometheusStatLogger` and cover scheduler state, -token throughput, cache utilization, and request latencies. +full set of ~37 metric families under the `vllm:` prefix. + +vLLM-Omni wraps the upstream `vllm.v1.metrics.loggers.PrometheusStatLogger` +with `OmniPrometheusStatLogger` so that the original `engine` single label +is reshaped into `stage` + `replica`. Every `vllm:*` family — TTFT, ITL, +TPOT, e2e latency, KV cache usage, scheduler running/waiting, request +success counts, etc. — therefore gains per-`(stage, replica)` visibility +automatically. No omni-side duplicate is needed for the text path. + +```text +# Before wrap (PR #3362): +vllm:num_requests_running{model_name="...", engine="1"} 3.0 + +# After wrap (this branch): +vllm:num_requests_running{model_name="...", stage="1", replica="0"} 2.0 +vllm:num_requests_running{model_name="...", stage="1", replica="1"} 1.0 +``` -For a full overview of vLLM metrics, consult [the vLLM docs](https://github.com/vllm-project/vllm/blob/main/docs/usage/metrics.md) +For the full list of upstream metrics, see +[the vLLM docs](https://github.com/vllm-project/vllm/blob/main/docs/usage/metrics.md). ## Metric Availability by Pipeline Type | Metric group | Multi-stage LLM (Qwen3-Omni) | Diffusion-only (Z-Image-Turbo) | |---|---|---| -| `vllm:omni_` request tracking | Yes | Yes | -| `vllm:omni_` latency | Yes | Yes | -| `vllm:omni_` KV cache | Yes | No | -| `vllm:omni_` diffusion timing | Only if pipeline has a diffusion stage | Yes | -| `vllm:` engine metrics | Yes | No | +| `vllm:omni_` request tracking + latency | Yes | Yes | +| `vllm:omni_` audio modality | If pipeline has a talker stage | No | +| `vllm:omni_` image / video modality | If pipeline has a diffusion stage | Yes | +| `vllm:omni_` transfer | If pipeline has ≥ 2 stages | No | +| `vllm:omni_` diffusion timing | If pipeline has a diffusion stage | Yes | +| `vllm:` engine metrics (per `(stage, replica)`) | Yes | No | | `vllm:` MFU metrics | With `--enable-mfu-metrics` | No | ## Naming Convention -vLLM-Omni pipeline metrics use the `vllm:omni_` prefix to distinguish -them from upstream per-engine `vllm:` metrics. The upstream +vLLM-Omni pipeline metrics use the `vllm:omni_` prefix to distinguish them +from upstream per-engine `vllm:` metrics. The upstream `unregister_vllm_metrics()` function is monkey-patched to a no-op (see -`vllm_omni/patch.py`) so that these metrics are not destroyed during -engine initialization. +`vllm_omni/patch.py`) so that these metrics are not destroyed during engine +initialization. + +For the audio / image / video families, the RFC convention is "co-position, +different name": each modality's time-to-first-output uses a distinct name +(`vllm:time_to_first_token_seconds` for text — reused from upstream; +`vllm:omni_audio_ttfp_seconds` for audio; `vllm:omni_image_ttfp_seconds` +for image) rather than a single metric with a `modality` label. The three +modalities differ in unit semantics (text token vs. audio packet vs. image +emission) and typical latency magnitudes, so independent histogram buckets +fit each modality better. From f593f4a7d8e783ead410e06394b51b358657e1f1 Mon Sep 17 00:00:00 2001 From: LHXuuu Date: Fri, 15 May 2026 12:00:09 +0800 Subject: [PATCH 13/13] [Metrics] Fix OmniPrometheusStatLogger crash from helper-class label mismatch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 2.4 — two bugs in the G7 wrap surfaced when bringing the server up against a real PrometheusStatLogger / qwen3-omni deployment, both caused by paths the unit tests didn't reach against the stub upstream: 1. Helper-class label-count mismatch (orchestrator init crash with `ValueError: Incorrect label count`). Upstream PrometheusStatLogger.__init__ instantiates three sub-helper collectors at loggers.py:438-446: self.spec_decoding_prom = self._spec_decoding_cls( vllm_config.speculative_config, labelnames, per_engine_labelvalues) self.kv_connector_prom = self._kv_connector_cls( vllm_config, labelnames, per_engine_labelvalues) self.perf_metrics_prom = self._perf_metrics_cls( vllm_config, labelnames, per_engine_labelvalues) Each helper takes raw `labelnames` as a constructor argument and builds its internal Counter/Gauge/Histogram families via its own class-level `_counter_cls` / `_gauge_cls` / `_histogram_cls` slots. The slot overrides on OmniPrometheusStatLogger only reach families created via *its* slots, so the helpers still build 2-label families and then crash when create_metric_per_engine fans the rewritten 3-element per_engine_labelvalues into them. Fix: subclass each helper (`_OmniPerfMetricsProm`, `_OmniSpecDecodingProm`, `_OmniKVConnectorProm`) and override their cls slots to the relabel-mixin wrappers, then assign these subclasses to `_perf_metrics_cls` / `_spec_decoding_cls` / `_kv_connector_cls` on OmniPrometheusStatLogger so the helpers construct families through the same labelname-rewrite path. 2. Double-rewrite in mixin's positional-args path (also surfaces as `Incorrect label count` after the helpers were patched). Phase 2.2b's per_engine_labelvalues property setter already rewrites each entry from [model_name, str(idx)] to [model_name, stage, replica]. create_metric_per_engine then fans those into `metric.labels(*values)` — i.e. the wrapper's positional .labels() path receives 3 already-rewritten values. The old splice logic would then re-interpret args[engine_label_index=1] (now "stage") as an engine_idx and splice (stage, replica) again, blowing the label count to 4. Fix: detect `len(args) == len(self._labelnames)` at the start of the positional path and short-circuit to passthrough — the caller has already shaped the values to the rewritten family. The legacy path (e.g. counter_request_success.labels(model_name, str(idx), reason) at loggers.py:679) keeps working because its arg count is short by one relative to the rewritten labelnames, so the splice still triggers as before. Tests: - New `TestHelperClassWraps` asserts the four cls-slot overrides on OmniPrometheusStatLogger and the Counter/Gauge/Histogram slots on each helper subclass. - New `TestDoubleRewriteGuard` covers both branches: a 2-label original family receiving 3 already-rewritten values (passthrough) and a 3-label original (engine in middle) receiving 3 ORIGINAL-shape values (splice path still fires). - Updated the docstring on `test_labels_positional_passthrough` to call out that it specifically exercises the new guard. --- tests/metrics/test_stat_logger.py | 111 +++++++++++++++++++++++++++++- vllm_omni/metrics/stat_logger.py | 63 ++++++++++++++++- 2 files changed, 171 insertions(+), 3 deletions(-) diff --git a/tests/metrics/test_stat_logger.py b/tests/metrics/test_stat_logger.py index 16bfe3e4697..ae6ecd6063b 100644 --- a/tests/metrics/test_stat_logger.py +++ b/tests/metrics/test_stat_logger.py @@ -118,8 +118,13 @@ def test_labels_kwarg_translated(self, registry): ) def test_labels_positional_passthrough(self, registry): - # Phase 2.2's per_engine_labelvalues setter feeds positional 3-tuples; - # our mixin must not mangle positional .labels() calls. + # Phase 2.4 double-rewrite guard: Phase 2.2b's per_engine_labelvalues + # setter rewrites the values to 3-tuple [model_name, stage, replica] + # BEFORE create_metric_per_engine fans them into .labels(*values). The + # mixin must detect that args length already matches the rewritten + # labelnames and pass through, otherwise it would re-interpret + # args[engine_label_index] as an engine_idx and splice (stage, replica) + # again, blowing label count to 4. g = _RelabelGauge( name="omni_test_gauge_pos", documentation="test", @@ -379,3 +384,105 @@ def test_init_populates_engine_index_map(self): assert dict(_ENGINE_INDEX_MAP) == srm assert 99 not in _ENGINE_INDEX_MAP # old entry was cleared + + +# --------------------------------------------------------------------------- +# Phase 2.4 — helper-class wraps for upstream's spec_decoding / kv_connector +# / perf_metrics sub-collectors. Without these, OmniPrometheusStatLogger +# crashes at startup with `Incorrect label count` because each helper builds +# its internal Counter/Gauge/Histogram families with raw 2-element labelnames +# (passed via constructor arg) while consuming the rewritten 3-element +# per_engine_labelvalues from the property descriptor. +# --------------------------------------------------------------------------- + + +from vllm_omni.metrics.stat_logger import ( + _OmniKVConnectorProm, + _OmniPerfMetricsProm, + _OmniSpecDecodingProm, +) + + +class TestHelperClassWraps: + def test_perf_metrics_wrap_routes_through_relabel_counter(self): + assert _OmniPerfMetricsProm._counter_cls is _RelabelCounter + + def test_spec_decoding_wrap_routes_through_relabel_counter(self): + assert _OmniSpecDecodingProm._counter_cls is _RelabelCounter + + def test_kv_connector_wrap_routes_through_all_three_relabel_classes(self): + # KVConnector lets each connector build any of Gauge/Counter/Histogram, + # so all three slots must be intercepted. + assert _OmniKVConnectorProm._gauge_cls is _RelabelGauge + assert _OmniKVConnectorProm._counter_cls is _RelabelCounter + assert _OmniKVConnectorProm._histogram_cls is _RelabelHistogram + + def test_omni_logger_slots_point_to_helper_subclasses(self): + # Upstream's PrometheusStatLogger.__init__ instantiates each sub-helper + # via `self.__cls(...)`, so the slot overrides on the omni + # subclass are what routes through to the relabel mixin. + assert OmniPrometheusStatLogger._perf_metrics_cls is _OmniPerfMetricsProm + assert OmniPrometheusStatLogger._spec_decoding_cls is _OmniSpecDecodingProm + assert OmniPrometheusStatLogger._kv_connector_cls is _OmniKVConnectorProm + + +# --------------------------------------------------------------------------- +# Phase 2.4 double-rewrite guard. The mixin's positional-args path used to +# unconditionally splice (stage, replica) at engine_label_index. After Phase +# 2.2b started rewriting per_engine_labelvalues to 3-tuples *before* feeding +# them into create_metric_per_engine, that splice ran a second time on the +# already-rewritten values, blowing the label count to 4. The guard now +# detects len(args) == len(self._labelnames) and short-circuits to passthrough. +# --------------------------------------------------------------------------- + + +@pytest.fixture +def fresh_registry(): + from prometheus_client import CollectorRegistry as _R + + return _R() + + +class TestDoubleRewriteGuard: + def test_pre_rewritten_3tuple_passes_through(self, fresh_registry): + # 2-label original → 3-label rewritten family. Caller passes 3 values + # (the rewritten shape) and they should land verbatim, not get + # re-spliced. + _ENGINE_INDEX_MAP.clear() + _ENGINE_INDEX_MAP[0] = ("0", "0") + _ENGINE_INDEX_MAP[1] = ("1", "0") + g = _RelabelGauge( + name="dr_pre_rewritten", + documentation="t", + labelnames=["model_name", "engine"], + registry=fresh_registry, + ) + # 3 positional args matching the rewritten 3-label family. + g.labels("m", "1", "0").set(42) + out = generate_latest(fresh_registry).decode() + assert ( + 'dr_pre_rewritten{model_name="m",replica="0",stage="1"} 42.0' in out + ) + + def test_legacy_2tuple_with_extra_label_still_splices(self, fresh_registry): + # 3-label original (engine in middle) → 4-label rewritten family. + # Caller passes 3 values matching the ORIGINAL labelnames (the + # gauge_waiting_by_reason / counter_request_success pattern from + # upstream loggers.py:646, 679). The mixin must splice + # (stage, replica) at engine's position to reach the 4-label family. + _ENGINE_INDEX_MAP.clear() + _ENGINE_INDEX_MAP[1] = ("1", "0") + c = _RelabelCounter( + name="dr_legacy_with_extra", + documentation="t", + labelnames=["model_name", "engine", "reason"], + registry=fresh_registry, + ) + # 3 positional args matching ORIGINAL labelnames (model_name, + # engine_str, reason). + c.labels("m", "1", "stop").inc(3) + out = generate_latest(fresh_registry).decode() + assert ( + 'dr_legacy_with_extra_total{model_name="m",reason="stop",replica="0",stage="1"} 3.0' + in out + ) diff --git a/vllm_omni/metrics/stat_logger.py b/vllm_omni/metrics/stat_logger.py index 0154856883c..c1f7fc38edf 100644 --- a/vllm_omni/metrics/stat_logger.py +++ b/vllm_omni/metrics/stat_logger.py @@ -13,7 +13,10 @@ from prometheus_client import Counter, Gauge, Histogram from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorProm from vllm.v1.metrics.loggers import PrometheusStatLogger +from vllm.v1.metrics.perf import PerfMetricsProm +from vllm.v1.spec_decode.metrics import SpecDecodingProm # Process-wide translation table written by OmniPrometheusStatLogger at init. # Keys are flat engine_idx values (as upstream PrometheusStatLogger sees them); @@ -94,7 +97,23 @@ def __init__(self, *args, **kwargs): def labels(self, *args, **kwargs): if self._engine_label_index >= 0: if args: - # Positional form: replace args[engine_idx] with (stage, replica). + # Positional form. There are TWO upstream patterns: + # + # (a) Pre-rewritten path: create_metric_per_engine fans + # `per_engine_labelvalues` (already a 3-tuple + # [model_name, stage, replica] thanks to the property- + # descriptor setter on OmniPrometheusStatLogger) into + # `metric.labels(*values)`. len(args) matches the + # rewritten label set already, so just pass through. + # + # (b) Legacy 2-tuple path: upstream sites like + # `counter_request_success.labels(model_name, str(idx), + # str(reason))` pass values shaped to the *original* + # labelnames (engine still present at idx). Here + # len(args) is short by 1 — splice (stage, replica) + # in place of the engine value at engine_label_index. + if len(args) == len(self._labelnames): + return super().labels(*args, **kwargs) idx = self._engine_label_index if idx < len(args): stage, replica = _engine_to_stage_replica(args[idx]) @@ -118,6 +137,41 @@ class _RelabelHistogram(_RelabelMixin, Histogram): pass +# ---------------------------------------------------------------------------- +# Helper-class wraps for the three sub-metric collectors that upstream +# PrometheusStatLogger constructs in its __init__ (loggers.py:438-446): +# +# self.spec_decoding_prom = self._spec_decoding_cls(...) +# self.kv_connector_prom = self._kv_connector_cls(...) +# self.perf_metrics_prom = self._perf_metrics_cls(...) +# +# Each helper receives raw `labelnames` as a constructor argument and uses +# its own class-level `_counter_cls` / `_gauge_cls` / `_histogram_cls` slots +# to build internal Counter/Gauge/Histogram families. The slot overrides on +# OmniPrometheusStatLogger only reach families created via *its* slots, so +# the helpers would otherwise still construct 2-label families and then hit +# `Incorrect label count` when create_metric_per_engine feeds the rewritten +# 3-element per_engine_labelvalues. Subclassing each helper and overriding +# its slots routes the relabel mixin through to the helper-internal families +# too. The helper kept seeing the OLD 2-element labelnames param, but that +# is fine because the wrapper rewrites it at family-creation time. +# ---------------------------------------------------------------------------- + + +class _OmniPerfMetricsProm(PerfMetricsProm): + _counter_cls = _RelabelCounter + + +class _OmniSpecDecodingProm(SpecDecodingProm): + _counter_cls = _RelabelCounter + + +class _OmniKVConnectorProm(KVConnectorProm): + _gauge_cls = _RelabelGauge + _counter_cls = _RelabelCounter + _histogram_cls = _RelabelHistogram + + class OmniPrometheusStatLogger(PrometheusStatLogger): """Wrap upstream PrometheusStatLogger to expose per-(stage, replica) labels. @@ -136,6 +190,13 @@ class OmniPrometheusStatLogger(PrometheusStatLogger): _gauge_cls = _RelabelGauge _counter_cls = _RelabelCounter _histogram_cls = _RelabelHistogram + # Inject helper-class wraps too so the perf / spec-decoding / kv-connector + # sub-collectors get the same labelname rewrite and don't crash with + # `Incorrect label count` when create_metric_per_engine fans out the + # rewritten 3-element per_engine_labelvalues over their internal families. + _perf_metrics_cls = _OmniPerfMetricsProm + _spec_decoding_cls = _OmniSpecDecodingProm + _kv_connector_cls = _OmniKVConnectorProm def __init__( self,