Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions tests/metrics/test_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from __future__ import annotations

from vllm_omni.metrics.loggers import OmniStatLoggerBase, OmniStatLoggerManager
from vllm_omni.metrics.stats import OrchestratorAggregator


def test_orchestrator_aggregator_builds_summary() -> None:
agg = OrchestratorAggregator(num_stages=2, enable_debug_events=False, wall_start_ts=0.0)
agg.set_final_stage_map({"r1": 1})
agg.stage_first_ts[0] = 0.0
agg.stage_last_ts[0] = 0.03
agg.stage_first_ts[1] = 0.05
agg.stage_last_ts[1] = 0.07

agg.on_forward(0, 1, "r1", size_bytes=1024, tx_ms=5.0, used_shm=False)
agg.on_stage_metrics(
0,
"r1",
{
"num_tokens_in": 3,
"num_tokens_out": 3,
"stage_gen_time_ms": 30.0,
"batch_id": 1,
"batch_size": 1,
"rx_transfer_bytes": 0,
"rx_decode_time_ms": 0.0,
},
)
agg.on_stage_metrics(
1,
"r1",
{
"num_tokens_out": 4,
"stage_gen_time_ms": 20.0,
"batch_id": 1,
"batch_size": 1,
"rx_transfer_bytes": 1024,
"rx_decode_time_ms": 5.0,
"rx_in_flight_time_ms": 2.0,
},
)
agg.on_finalize_request(1, "r1", req_start_ts=0.0)

summary = agg.build_run_summary()
data = summary.to_dict()
assert data["e2e_requests"] == 1
assert len(data["stages"]) == 2
assert data["stages"][0]["requests"] == 1
assert data["transfers"][0]["samples"] == 1
assert data["transfers"][0]["total_mbps"] >= 0.0


class _DummyLogger(OmniStatLoggerBase):
def __init__(self, interval_s: float = 0.0) -> None:
super().__init__(interval_s=interval_s)
self.logged: list[dict] = []

def log(self, summary) -> None: # type: ignore[override]
self.logged.append(summary.to_dict())


def test_logger_manager_triggers_logging_on_interval() -> None:
agg = OrchestratorAggregator(num_stages=1, enable_debug_events=False, wall_start_ts=0.0)
agg.set_final_stage_map({"r": 0})
dummy_logger = _DummyLogger(interval_s=0.0)
mgr = OmniStatLoggerManager(
aggregator=agg,
loggers=[dummy_logger],
final_stage_map_provider=lambda: agg.final_stage_map,
)
agg.set_logger_manager(mgr)
agg.stage_first_ts[0] = 0.0
agg.stage_last_ts[0] = 0.01
agg.on_stage_metrics(
0,
"r",
{"num_tokens_out": 1, "stage_gen_time_ms": 1.0, "batch_id": 1, "batch_size": 1, "rx_transfer_bytes": 0},
)
agg.on_finalize_request(0, "r", req_start_ts=0.0)
assert dummy_logger.logged, "logger manager should emit summary when interval is reached"
30 changes: 24 additions & 6 deletions vllm_omni/entrypoints/async_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
get_final_stage_id_for_e2e,
)
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.metrics.loggers import OmniLoggingStatLogger, OmniStatLoggerManager
from vllm_omni.metrics.prometheus import OmniPrometheusStatLogger

logger = init_logger(__name__)

Expand Down Expand Up @@ -308,19 +310,30 @@ async def generate(self, *args: Any, **kwargs: dict[str, Any]) -> AsyncGenerator
_req_start_ts: dict[int, float] = {}
_wall_start_ts: float = time.time()
# _last_finish_ts: float = _wall_start_ts
stat_logger_manager: OmniStatLoggerManager | None = None

# Determine the final stage for E2E stats (highest stage_id with
# final_output=True; fallback to last stage)
final_stage_id_for_e2e = get_final_stage_id_for_e2e(
final_stage_id = get_final_stage_id_for_e2e(
output_modalities, self.output_modalities, self.stage_list
)
final_stage_id_to_prompt = {str(request_id): final_stage_id}

# Metrics/aggregation helper
metrics = OrchestratorMetrics(
num_stages,
self._enable_stats,
_wall_start_ts,
)
metrics.set_final_stage_map(final_stage_id_to_prompt)
stat_logger_manager = OmniStatLoggerManager(
aggregator=metrics,
loggers=[
OmniLoggingStatLogger(interval_s=10.0, enabled=True),
OmniPrometheusStatLogger(interval_s=10.0, enabled=True),
],
final_stage_map_provider=lambda: metrics.final_stage_map,
)
# Seed stage-0 queue with all requests
logger.debug(f"[{self._name}] Seeding request into stage-0")
req_state = ClientRequestState(request_id)
Expand All @@ -340,7 +353,7 @@ async def generate(self, *args: Any, **kwargs: dict[str, Any]) -> AsyncGenerator
logger.debug(f"[{self._name}] Enqueued request {request_id} to stage-0")

logger.debug(f"[{self._name}] Entering scheduling loop: stages={num_stages}")
for stage_id, stage in enumerate(self.stage_list[: final_stage_id_for_e2e + 1]):
for stage_id, stage in enumerate(self.stage_list[: final_stage_id + 1]):
finished = False
while not finished:
result = await req_state.queue.get()
Expand Down Expand Up @@ -387,7 +400,7 @@ async def generate(self, *args: Any, **kwargs: dict[str, Any]) -> AsyncGenerator
# (only once per request at the designated final stage)
try:
rid_key = str(req_id)
if stage_id == final_stage_id_for_e2e and rid_key not in metrics.e2e_done and finished:
if stage_id == final_stage_id and rid_key not in metrics.e2e_done and finished:
metrics.on_finalize_request(
stage_id,
req_id,
Expand Down Expand Up @@ -423,7 +436,7 @@ async def generate(self, *args: Any, **kwargs: dict[str, Any]) -> AsyncGenerator
stage.set_engine_outputs(engine_outputs)
# Forward to next stage if there is one
next_stage_id = stage_id + 1
if next_stage_id <= final_stage_id_for_e2e and finished:
if next_stage_id <= final_stage_id and finished:
next_stage: OmniStage = self.stage_list[next_stage_id]
next_inputs = next_stage.process_engine_inputs(self.stage_list, prompt)
sp_next: SamplingParams = sampling_params_list[next_stage_id]
Expand Down Expand Up @@ -465,8 +478,13 @@ async def generate(self, *args: Any, **kwargs: dict[str, Any]) -> AsyncGenerator

# Summarize and print stats
try:
summary = metrics.build_and_log_summary(final_stage_id_for_e2e)
logger.info("[Summary] %s", pformat(summary, sort_dicts=False))
summary_dict = None
if stat_logger_manager:
summary_obj = stat_logger_manager.force_log()
summary_dict = summary_obj.to_dict()
if summary_dict is None:
summary_dict = metrics.build_summary(final_stage_id_to_prompt)
logger.info("[Summary] %s", pformat(summary_dict, sort_dicts=False))
except Exception as e:
logger.exception(f"[{self._name}] Failed to build/log summary: {e}")
finally:
Expand Down
Loading
Loading