Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions tests/engine/test_async_omni_engine_do_log_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Guard tests for AsyncOmniEngine.do_log_stats edge cases.

These are pure-Python tests that bypass __init__ and only exercise the
no-op branches of do_log_stats, so no stage cores / threads are needed.
"""

import asyncio

import pytest

from vllm_omni.engine.async_omni_engine import AsyncOmniEngine

pytestmark = [pytest.mark.core_model, pytest.mark.cpu]


def _make_bare_engine() -> AsyncOmniEngine:
# Bypass __init__ so we don't spin up stage cores; we only need the
# attributes do_log_stats touches.
return AsyncOmniEngine.__new__(AsyncOmniEngine)


@pytest.mark.asyncio
async def test_do_log_stats_noop_when_manager_missing():
engine = _make_bare_engine()
engine.logger_manager = None
engine.orchestrator_loop = None
await engine.do_log_stats() # should silently return


@pytest.mark.asyncio
async def test_do_log_stats_noop_when_loop_missing():
engine = _make_bare_engine()

class _Manager:
def log(self) -> None: # pragma: no cover - must not be called
raise AssertionError("log() should not be called without a loop")

engine.logger_manager = _Manager()
engine.orchestrator_loop = None
await engine.do_log_stats()


@pytest.mark.asyncio
async def test_do_log_stats_noop_when_loop_not_running():
engine = _make_bare_engine()

class _Manager:
def log(self) -> None: # pragma: no cover - must not be called
raise AssertionError("log() should not be called on a stopped loop")

dead_loop = asyncio.new_event_loop()
dead_loop.close()

engine.logger_manager = _Manager()
engine.orchestrator_loop = dead_loop
await engine.do_log_stats()
2 changes: 2 additions & 0 deletions tests/engine/test_async_omni_engine_stage_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def test_initialize_stages_restores_device_visibility_after_diffusion_init(monke
from vllm_omni.platforms import current_omni_platform

engine = object.__new__(AsyncOmniEngine)
engine.log_stats = False
engine.model = "dummy-model"
engine.config_path = "dummy-config"
engine.num_stages = 1
Expand Down Expand Up @@ -141,6 +142,7 @@ def __init__(self, vllm_config, renderer=None):
)

engine = object.__new__(AsyncOmniEngine)
engine.log_stats = False

_stage_client, _out_proc, _vllm_cfg, input_processor = engine._attach_llm_stage(started)

Expand Down
3 changes: 3 additions & 0 deletions tests/engine/test_single_stage_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,7 @@ def _build_engine_skeleton(
engine.stage_configs = stage_cfgs
engine.num_stages = len(stage_cfgs)
engine.async_chunk = False
engine.log_stats = False
engine.single_stage_mode = single_stage_mode
engine._single_stage_id_filter = stage_id_filter
engine._omni_master_address = omni_master_address
Expand Down Expand Up @@ -1366,6 +1367,7 @@ class TestLaunchLlmStageSingleStageMode:
def _build_engine_with_oms(self) -> AsyncOmniEngine:
engine = object.__new__(AsyncOmniEngine)
engine.model = "fake-model"
engine.log_stats = False
engine.single_stage_mode = True
engine._single_stage_id_filter = 0
engine._llm_stage_launch_lock = threading.Lock()
Expand Down Expand Up @@ -1446,6 +1448,7 @@ def test_spawn_stage_core_used_in_normal_mode(self):
"""~single_stage_mode → spawn_stage_core + complete_stage_handshake."""
engine = object.__new__(AsyncOmniEngine)
engine.model = "fake-model"
engine.log_stats = False
engine.single_stage_mode = False
engine._omni_master_server = None
engine._llm_stage_launch_lock = threading.Lock()
Expand Down
58 changes: 56 additions & 2 deletions vllm_omni/engine/async_omni_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from vllm.tokenizers import cached_tokenizer_from_config
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.input_processor import InputProcessor
from vllm.v1.metrics.loggers import StatLoggerManager

from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.diffusion.stage_diffusion_client import StageDiffusionClient
Expand Down Expand Up @@ -283,6 +284,7 @@ def __init__(
self.num_stages = len(self.stage_configs)
stage0_args = getattr(self.stage_configs[0], "engine_args", None) if self.num_stages > 0 else None
self.async_chunk = bool(getattr(stage0_args, "async_chunk", False))
self.log_stats = not bool(getattr(stage0_args, "disable_log_stats", False))
Comment on lines 284 to +287

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall this looks fine to me. If the StatLoggerManager concurrency issue has been properly resolved, I don't have other blockers.

One small nit: this seems to rely too heavily on the stage0 configuration, which feels somewhat awkward. Probably okay for now, but worth cleaning up later. cc @yinpeiqi

Also, it may be worth taking another look at the logging/stat system for the diffusion path in a follow-up as well, since it seems not fully covered by the current branch yet. @chickeyton

self.stage_clients: list[Any] = []
self.stage_vllm_configs: list[Any] = []
self.output_processors: list[MultimodalOutputProcessor | None] = []
Expand Down Expand Up @@ -412,7 +414,7 @@ def _launch_llm_stage(
addresses, proc, handshake_address = spawn_stage_core(
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False,
log_stats=self.log_stats,
)
started_stage = StartedLlmStage(
stage_id=metadata.stage_id,
Expand Down Expand Up @@ -612,7 +614,7 @@ def _attach_llm_stage(
)
output_processor = MultimodalOutputProcessor(
tokenizer=tokenizer,
log_stats=False,
log_stats=self.log_stats,
engine_core_output_type=started.metadata.engine_output_type,
)
input_processor = None
Expand Down Expand Up @@ -866,6 +868,30 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
self.default_sampling_params_list = default_sampling_params_list
self.stage_metadata = stage_metadata

# Single StatLoggerManager for the whole pipeline, mirroring how
# vLLM AsyncLLM uses one manager with multiple engine indices for DP.
# We treat each stage as a separate "engine_idx" so logs are
# distinguishable as "Engine 000/001/002/...". Using a single manager
# also avoids PrometheusStatLogger registry collisions.
self.logger_manager: StatLoggerManager | None = None
if self.log_stats:
base_vllm_config = next(
(cfg for cfg in self.stage_vllm_configs if cfg is not None),
None,
)
if base_vllm_config is not None:
try:
self.logger_manager = StatLoggerManager(
vllm_config=base_vllm_config,
engine_idxs=list(range(self.num_stages)),
custom_stat_loggers=None,
enable_default_loggers=True,
)
self.logger_manager.log_engine_initialized()
except Exception:
logger.exception("[AsyncOmniEngine] Failed to build StatLoggerManager")
self.logger_manager = None

def _initialize_janus_queues(self) -> None:
"""Initialize janus queues inside orchestrator thread loop context."""
self.request_queue = janus.Queue()
Expand All @@ -882,6 +908,10 @@ def _bootstrap_orchestrator(

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Expose the orchestrator loop so other threads (API server) can
# schedule coroutines onto it via run_coroutine_threadsafe, keeping
# single-threaded access to StatLoggerManager (mirrors AsyncLLM).
self.orchestrator_loop = loop

async def _run_orchestrator() -> None:
self._initialize_janus_queues()
Expand All @@ -895,6 +925,7 @@ async def _run_orchestrator() -> None:
stage_clients=self.stage_clients,
output_processors=self.output_processors,
stage_vllm_configs=self.stage_vllm_configs,
logger_manager=self.logger_manager,
)
if not startup_future.done():
startup_future.set_result(asyncio.get_running_loop())
Expand Down Expand Up @@ -1450,6 +1481,29 @@ async def abort_async(self, request_ids: list[str]) -> None:
"""Async abort API."""
self.abort(request_ids)

async def do_log_stats(self) -> None:
"""Flush the StatLoggerManager on the orchestrator thread.

``StatLoggerManager`` is only safe to access from the orchestrator
loop (where ``record()`` runs). Schedule ``log()`` onto that loop
via ``run_coroutine_threadsafe`` so all access stays single-threaded,
matching upstream vLLM ``AsyncLLM``.
"""
manager = self.logger_manager
if manager is None:
return
loop = getattr(self, "orchestrator_loop", None)
if loop is None or not loop.is_running():
return

async def _log() -> None:
manager.log()

try:
await asyncio.wrap_future(asyncio.run_coroutine_threadsafe(_log(), loop))
except Exception:
logger.exception("[AsyncOmniEngine] do_log_stats failed")

def collective_rpc(
self,
method: str,
Expand Down
26 changes: 25 additions & 1 deletion vllm_omni/engine/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.v1.engine import EngineCoreOutputs
from vllm.v1.metrics.loggers import StatLoggerManager
from vllm.v1.metrics.stats import IterationStats

from vllm_omni.distributed.omni_connectors.adapter import compute_talker_prompt_ids_length
from vllm_omni.engine import (
Expand Down Expand Up @@ -122,6 +124,7 @@ def __init__(
stage_vllm_configs: list[Any],
*,
async_chunk: bool = False,
logger_manager: StatLoggerManager | None = None,
) -> None:
self.request_async_queue = request_async_queue
self.output_async_queue = output_async_queue
Expand All @@ -133,6 +136,8 @@ def __init__(
self.stage_clients: list[Any] = stage_clients
self.output_processors: list[Any] = output_processors
self.stage_vllm_configs: list[Any] = stage_vllm_configs
self.logger_manager: StatLoggerManager | None = logger_manager

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is log_stats still useful in orchestrator? Could we just

self.log_stats = (self.logger_manager != None)

self.log_stats = self.logger_manager is not None

# Per-request state
self.request_states: dict[str, OrchestratorRequestState] = {}
Expand Down Expand Up @@ -624,10 +629,13 @@ async def _process_stage_outputs(self, stage_id: int, raw_outputs: EngineCoreOut
"""
processor = self.output_processors[stage_id]

num_outputs = len(raw_outputs.outputs)
iteration_stats = IterationStats() if (self.log_stats and num_outputs) else None

processed = processor.process_outputs(
raw_outputs.outputs,
raw_outputs.timestamp,
None,
iteration_stats,
)

if processed.reqs_to_abort:
Expand All @@ -636,6 +644,22 @@ async def _process_stage_outputs(self, stage_id: int, raw_outputs: EngineCoreOut
if raw_outputs.scheduler_stats is not None:
processor.update_scheduler_stats(raw_outputs.scheduler_stats)

# Mirror vLLM AsyncLLM output_handler: feed stats to the logger
# manager so LoggingStatLogger can periodically print KV cache /
# prefix cache hit rate, and PrometheusStatLogger can publish.
if self.logger_manager is not None:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The diffusion engine don't go into this branch. Do we have any plan for diffusion?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, no idea about diffusion logger. Reusing vLLM's logger makes this PR simple. But something like KV cache isn't appropriate to diffusion.

try:
self.logger_manager.record(
engine_idx=stage_id,
scheduler_stats=raw_outputs.scheduler_stats,
iteration_stats=iteration_stats,
)
except Exception:
logger.exception(
"[Orchestrator] stat logger record failed for stage-%s",
stage_id,
)

return processed.request_outputs

async def _handle_add_request(self, msg: dict[str, Any]) -> None:
Expand Down
7 changes: 2 additions & 5 deletions vllm_omni/entrypoints/async_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,11 +743,8 @@ async def is_tracing_enabled(self) -> bool:
return False

async def do_log_stats(self) -> None:
"""Log statistics.

TODO: Forward to Orchestrator process via message.
"""
pass
"""Log statistics via the engine, mirroring vLLM ``AsyncLLM``."""
await self.engine.do_log_stats()

async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
"""Return the task set exposed by the orchestrator-backed engine."""
Expand Down
Loading