Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 0 additions & 56 deletions tests/engine/test_async_omni_engine_do_log_stats.py

This file was deleted.

2 changes: 0 additions & 2 deletions tests/engine/test_async_omni_engine_stage_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def test_initialize_stages_restores_device_visibility_after_diffusion_init(monke
from vllm_omni.platforms import current_omni_platform

engine = object.__new__(AsyncOmniEngine)
engine.log_stats = False
engine.model = "dummy-model"
engine.config_path = "dummy-config"
engine.num_stages = 1
Expand Down Expand Up @@ -283,7 +282,6 @@ def __init__(self, vllm_config, renderer=None):
)

engine = object.__new__(AsyncOmniEngine)
engine.log_stats = False

_stage_client, _out_proc, _vllm_cfg, input_processor = engine._attach_llm_stage(started)

Expand Down
3 changes: 0 additions & 3 deletions tests/engine/test_single_stage_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,6 @@ def _build_engine_skeleton(
engine.stage_configs = stage_cfgs
engine.num_stages = len(stage_cfgs)
engine.async_chunk = False
engine.log_stats = False
engine.single_stage_mode = single_stage_mode
engine._single_stage_id_filter = stage_id_filter
engine._omni_master_address = omni_master_address
Expand Down Expand Up @@ -1367,7 +1366,6 @@ class TestLaunchLlmStageSingleStageMode:
def _build_engine_with_oms(self) -> AsyncOmniEngine:
engine = object.__new__(AsyncOmniEngine)
engine.model = "fake-model"
engine.log_stats = False
engine.single_stage_mode = True
engine._single_stage_id_filter = 0
engine._llm_stage_launch_lock = threading.Lock()
Expand Down Expand Up @@ -1448,7 +1446,6 @@ def test_spawn_stage_core_used_in_normal_mode(self):
"""~single_stage_mode → spawn_stage_core + complete_stage_handshake."""
engine = object.__new__(AsyncOmniEngine)
engine.model = "fake-model"
engine.log_stats = False
engine.single_stage_mode = False
engine._omni_master_server = None
engine._llm_stage_launch_lock = threading.Lock()
Expand Down
58 changes: 2 additions & 56 deletions vllm_omni/engine/async_omni_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from vllm.tokenizers import cached_tokenizer_from_config
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.input_processor import InputProcessor
from vllm.v1.metrics.loggers import StatLoggerManager

from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.diffusion.stage_diffusion_client import StageDiffusionClient
Expand Down Expand Up @@ -285,7 +284,6 @@ def __init__(
self.num_stages = len(self.stage_configs)
stage0_args = getattr(self.stage_configs[0], "engine_args", None) if self.num_stages > 0 else None
self.async_chunk = bool(getattr(stage0_args, "async_chunk", False))
self.log_stats = not bool(getattr(stage0_args, "disable_log_stats", False))
self.stage_clients: list[Any] = []
self.stage_vllm_configs: list[Any] = []
self.output_processors: list[MultimodalOutputProcessor | None] = []
Expand Down Expand Up @@ -415,7 +413,7 @@ def _launch_llm_stage(
addresses, proc, handshake_address = spawn_stage_core(
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=self.log_stats,
log_stats=False,
)
started_stage = StartedLlmStage(
stage_id=metadata.stage_id,
Expand Down Expand Up @@ -617,7 +615,7 @@ def _attach_llm_stage(
)
output_processor = MultimodalOutputProcessor(
tokenizer=tokenizer,
log_stats=self.log_stats,
log_stats=False,
engine_core_output_type=started.metadata.engine_output_type,
)
input_processor = None
Expand Down Expand Up @@ -872,30 +870,6 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
self.default_sampling_params_list = default_sampling_params_list
self.stage_metadata = stage_metadata

# Single StatLoggerManager for the whole pipeline, mirroring how
# vLLM AsyncLLM uses one manager with multiple engine indices for DP.
# We treat each stage as a separate "engine_idx" so logs are
# distinguishable as "Engine 000/001/002/...". Using a single manager
# also avoids PrometheusStatLogger registry collisions.
self.logger_manager: StatLoggerManager | None = None
if self.log_stats:
base_vllm_config = next(
(cfg for cfg in self.stage_vllm_configs if cfg is not None),
None,
)
if base_vllm_config is not None:
try:
self.logger_manager = StatLoggerManager(
vllm_config=base_vllm_config,
engine_idxs=list(range(self.num_stages)),
custom_stat_loggers=None,
enable_default_loggers=True,
)
self.logger_manager.log_engine_initialized()
except Exception:
logger.exception("[AsyncOmniEngine] Failed to build StatLoggerManager")
self.logger_manager = None

def _initialize_janus_queues(self) -> None:
"""Initialize janus queues inside orchestrator thread loop context."""
self.request_queue = janus.Queue()
Expand All @@ -912,10 +886,6 @@ def _bootstrap_orchestrator(

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Expose the orchestrator loop so other threads (API server) can
# schedule coroutines onto it via run_coroutine_threadsafe, keeping
# single-threaded access to StatLoggerManager (mirrors AsyncLLM).
self.orchestrator_loop = loop

async def _run_orchestrator() -> None:
self._initialize_janus_queues()
Expand All @@ -929,7 +899,6 @@ async def _run_orchestrator() -> None:
stage_clients=self.stage_clients,
output_processors=self.output_processors,
stage_vllm_configs=self.stage_vllm_configs,
logger_manager=self.logger_manager,
)
if not startup_future.done():
startup_future.set_result(asyncio.get_running_loop())
Expand Down Expand Up @@ -1552,29 +1521,6 @@ async def abort_async(self, request_ids: list[str]) -> None:
"""Async abort API."""
self.abort(request_ids)

async def do_log_stats(self) -> None:
"""Flush the StatLoggerManager on the orchestrator thread.

``StatLoggerManager`` is only safe to access from the orchestrator
loop (where ``record()`` runs). Schedule ``log()`` onto that loop
via ``run_coroutine_threadsafe`` so all access stays single-threaded,
matching upstream vLLM ``AsyncLLM``.
"""
manager = self.logger_manager
if manager is None:
return
loop = getattr(self, "orchestrator_loop", None)
if loop is None or not loop.is_running():
return

async def _log() -> None:
manager.log()

try:
await asyncio.wrap_future(asyncio.run_coroutine_threadsafe(_log(), loop))
except Exception:
logger.exception("[AsyncOmniEngine] do_log_stats failed")

def collective_rpc(
self,
method: str,
Expand Down
26 changes: 1 addition & 25 deletions vllm_omni/engine/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.v1.engine import EngineCoreOutputs
from vllm.v1.metrics.loggers import StatLoggerManager
from vllm.v1.metrics.stats import IterationStats

from vllm_omni.distributed.omni_connectors.adapter import compute_talker_prompt_ids_length
from vllm_omni.engine import (
Expand Down Expand Up @@ -124,7 +122,6 @@ def __init__(
stage_vllm_configs: list[Any],
*,
async_chunk: bool = False,
logger_manager: StatLoggerManager | None = None,
) -> None:
self.request_async_queue = request_async_queue
self.output_async_queue = output_async_queue
Expand All @@ -136,8 +133,6 @@ def __init__(
self.stage_clients: list[Any] = stage_clients
self.output_processors: list[Any] = output_processors
self.stage_vllm_configs: list[Any] = stage_vllm_configs
self.logger_manager: StatLoggerManager | None = logger_manager
self.log_stats = self.logger_manager is not None

# Per-request state
self.request_states: dict[str, OrchestratorRequestState] = {}
Expand Down Expand Up @@ -629,13 +624,10 @@ async def _process_stage_outputs(self, stage_id: int, raw_outputs: EngineCoreOut
"""
processor = self.output_processors[stage_id]

num_outputs = len(raw_outputs.outputs)
iteration_stats = IterationStats() if (self.log_stats and num_outputs) else None

processed = processor.process_outputs(
raw_outputs.outputs,
raw_outputs.timestamp,
iteration_stats,
None,
)

if processed.reqs_to_abort:
Expand All @@ -644,22 +636,6 @@ async def _process_stage_outputs(self, stage_id: int, raw_outputs: EngineCoreOut
if raw_outputs.scheduler_stats is not None:
processor.update_scheduler_stats(raw_outputs.scheduler_stats)

# Mirror vLLM AsyncLLM output_handler: feed stats to the logger
# manager so LoggingStatLogger can periodically print KV cache /
# prefix cache hit rate, and PrometheusStatLogger can publish.
if self.logger_manager is not None:
try:
self.logger_manager.record(
engine_idx=stage_id,
scheduler_stats=raw_outputs.scheduler_stats,
iteration_stats=iteration_stats,
)
except Exception:
logger.exception(
"[Orchestrator] stat logger record failed for stage-%s",
stage_id,
)

return processed.request_outputs

async def _handle_add_request(self, msg: dict[str, Any]) -> None:
Expand Down
7 changes: 5 additions & 2 deletions vllm_omni/entrypoints/async_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,8 +743,11 @@ async def is_tracing_enabled(self) -> bool:
return False

async def do_log_stats(self) -> None:
"""Log statistics via the engine, mirroring vLLM ``AsyncLLM``."""
await self.engine.do_log_stats()
"""Log statistics.

TODO: Forward to Orchestrator process via message.
"""
pass

async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
"""Return the task set exposed by the orchestrator-backed engine."""
Expand Down
Loading