diff --git a/tests/entrypoints/openai_api/test_serving_speech.py b/tests/entrypoints/openai_api/test_serving_speech.py index e909923172a..1854afbe9ba 100644 --- a/tests/entrypoints/openai_api/test_serving_speech.py +++ b/tests/entrypoints/openai_api/test_serving_speech.py @@ -18,6 +18,7 @@ from pytest_mock import MockerFixture from vllm.entrypoints.openai.engine.protocol import ErrorInfo, ErrorResponse +from vllm_omni.entrypoints.omni import BackgroundResources from vllm_omni.entrypoints.openai import api_server as api_server_module from vllm_omni.entrypoints.openai.audio_utils_mixin import AudioMixin from vllm_omni.entrypoints.openai.protocol.audio import CreateAudio, OpenAICreateSpeechRequest @@ -1042,7 +1043,7 @@ async def test_tts_only_no_generate_task(self): omni.output_modalities = [None, "audio"] stage = MagicMock() stage.is_comprehension = False - omni.stage_list = [stage] + omni.resources = BackgroundResources(stage_list=[stage]) tasks = await omni.get_supported_tasks() assert "generate" not in tasks assert "speech" in tasks @@ -1058,7 +1059,7 @@ async def test_omni_model_includes_generate(self): omni.output_modalities = ["text", None, "audio"] stage = MagicMock() stage.is_comprehension = True - omni.stage_list = [stage] + omni.resources = BackgroundResources(stage_list=[stage]) tasks = await omni.get_supported_tasks() assert "generate" in tasks diff --git a/tests/entrypoints/test_omni_diffusion.py b/tests/entrypoints/test_omni_diffusion.py index 9e555fa85c8..18f8b410d8d 100644 --- a/tests/entrypoints/test_omni_diffusion.py +++ b/tests/entrypoints/test_omni_diffusion.py @@ -1192,10 +1192,9 @@ def init_stage_worker(self, *args, **kwargs): from vllm_omni.entrypoints.omni import Omni - # Use very short timeout - omni = Omni(model=MODEL, init_timeout=0.01) - # Verify that no stages are ready - assert len(omni._stages_ready) == 0 + with pytest.raises(TimeoutError): + # Use very short timeout + Omni(model=MODEL, init_timeout=0.01) def test_generate_handles_error_messages(monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture, fake_stage_config): @@ -1371,15 +1370,25 @@ def _fake_loader(model: str, base_engine_args=None): _setup_multiprocessing_mocks(monkeypatch, mocker) _setup_ipc_mocks(monkeypatch) _setup_log_mocks(monkeypatch) + _setup_connector_mocks(monkeypatch, mocker) - from vllm_omni.entrypoints.omni import Omni + monkeypatch.setattr( + "vllm_omni.entrypoints.omni_stage.OmniStage", + lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs), + raising=False, + ) + + import vllm_omni.entrypoints.omni as omni_module + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs)) monkeypatch.setattr( "vllm_omni.entrypoints.utils.load_stage_configs_from_model", _fake_loader, raising=False, ) + from vllm_omni.entrypoints.omni import Omni + omni = Omni(model="any", init_timeout=1, dtype=dtype) # Dtype parsing being checked is on the diffusion path @@ -1417,15 +1426,25 @@ class NotATorchDtype: _setup_multiprocessing_mocks(monkeypatch, mocker) _setup_ipc_mocks(monkeypatch) _setup_log_mocks(monkeypatch) + _setup_connector_mocks(monkeypatch, mocker) - from vllm_omni.entrypoints.omni import Omni + monkeypatch.setattr( + "vllm_omni.entrypoints.omni_stage.OmniStage", + lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs), + raising=False, + ) + + import vllm_omni.entrypoints.omni as omni_module + monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs)) monkeypatch.setattr( "vllm_omni.entrypoints.utils.load_stage_configs_from_model", _fake_loader, raising=False, ) + from vllm_omni.entrypoints.omni import Omni + # Raise TypeError if we get an unrecognized type with pytest.raises(TypeError): Omni(model="any", init_timeout=1, dtype=NotATorchDtype) diff --git a/tests/entrypoints/test_omni_llm.py b/tests/entrypoints/test_omni_llm.py index b11b14c8939..22b3956dac1 100644 --- a/tests/entrypoints/test_omni_llm.py +++ b/tests/entrypoints/test_omni_llm.py @@ -1042,10 +1042,9 @@ def init_stage_worker(self, *args, **kwargs): from vllm_omni.entrypoints.omni import Omni - # Use very short timeout - omni = Omni(model=MODEL, init_timeout=0.01) - # Verify that no stages are ready - assert len(omni._stages_ready) == 0 + with pytest.raises(TimeoutError): + # Use very short timeout + Omni(model=MODEL, init_timeout=0.01) def test_generate_handles_error_messages(monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture, fake_stage_config): diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index b9867903345..cb9b012ec55 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -3,7 +3,6 @@ import asyncio import copy import time -import weakref from collections.abc import AsyncGenerator, Callable, Iterable, Sequence from typing import Any, TypeVar @@ -18,13 +17,12 @@ from vllm_omni.config import OmniModelConfig from vllm_omni.diffusion.data import DiffusionParallelConfig from vllm_omni.distributed.omni_connectors.adapter import compute_talker_prompt_ids_length, try_send_via_connector -from vllm_omni.distributed.ray_utils.utils import try_close_ray from vllm_omni.engine.input_processor import OmniInputProcessor from vllm_omni.entrypoints.cfg_companion_tracker import CfgCompanionTracker from vllm_omni.entrypoints.client_request_state import ClientRequestState from vllm_omni.entrypoints.omni import OmniBase from vllm_omni.entrypoints.omni_stage import OmniStage -from vllm_omni.entrypoints.stage_utils import SHUTDOWN_TASK, OmniStageTaskType +from vllm_omni.entrypoints.stage_utils import OmniStageTaskType from vllm_omni.entrypoints.stage_utils import maybe_load_from_ipc as _load from vllm_omni.entrypoints.utils import ( get_final_stage_id_for_e2e, @@ -41,41 +39,6 @@ logger = init_logger(__name__) -def _weak_close_cleanup_async( - stage_list, stage_in_queues, stage_out_queues, ray_pg, output_handler, zmq_ctx=None, inline_engine=None -): - """Weak reference cleanup function for AsyncOmni instances.""" - if inline_engine is not None: - try: - inline_engine.close() - except Exception as e: - logger.warning("Failed to close inline diffusion engine: %s", e) - if stage_list: - for q in stage_in_queues: - try: - q.put_nowait(SHUTDOWN_TASK) - except Exception as e: - logger.warning(f"Failed to send shutdown signal to stage input queue: {e}") - close_fn = getattr(q, "close", None) - if callable(close_fn): - close_fn() - for q in stage_out_queues: - close_fn = getattr(q, "close", None) - if callable(close_fn): - close_fn() - for stage in stage_list: - try: - stage.stop_stage_worker() - except Exception as e: - logger.warning(f"Failed to stop stage worker: {e}") - try_close_ray(ray_pg) - # Cancel output handler - if output_handler is not None: - output_handler.cancel() - if zmq_ctx is not None: - zmq_ctx.term() - - class AsyncOmni(OmniBase): """Asynchronous unified entry point supporting multi-stage pipelines for LLM and Diffusion models. @@ -120,7 +83,6 @@ def __init__(self, model: str, **kwargs: dict[str, Any]) -> None: # Request state tracking self.request_states: dict[str, ClientRequestState] = {} - self.output_handler: asyncio.Task | None = None # RPC results storage: {stage_id: {rpc_id: result}} # Used to avoid race condition between output_handler and collective_rpc @@ -131,23 +93,10 @@ def __init__(self, model: str, **kwargs: dict[str, Any]) -> None: super().__init__(model, **kwargs) - # Register weak reference cleanup (called on garbage collection) - self._weak_finalizer = weakref.finalize( - self, - _weak_close_cleanup_async, - self.stage_list, - self._stage_in_queues, - self._stage_out_queues, - self._ray_pg, - self.output_handler, - self._zmq_ctx, - getattr(self, "_inline_engine", None), - ) - async def get_supported_tasks(self) -> set[str]: """Return supported tasks based on stage output modalities and capabilities.""" tasks: set[str] = set() - if "text" in self.output_modalities or any(stage.is_comprehension for stage in self.stage_list): + if "text" in self.output_modalities or any(stage.is_comprehension for stage in self.resources.stage_list): tasks.add("generate") if "audio" in self.output_modalities: tasks.add("speech") @@ -253,7 +202,7 @@ def _process_stage_ready(self, stage: OmniStage, stage_id: int, result: dict[str def _wait_for_stages_ready(self, timeout: int = 120) -> None: """Wait for all stages to report readiness.""" super()._wait_for_stages_ready(timeout) - for stage in self.stage_list: + for stage in self.resources.stage_list: if stage.vllm_config is not None and stage.tokenizer is not None: try: vllm_config = stage.vllm_config @@ -367,11 +316,13 @@ async def generate( if sampling_params_list is None: sampling_params_list = self.default_sampling_params_list - if len(sampling_params_list) != len(self.stage_list): - raise ValueError(f"Expected {len(self.stage_list)} sampling params, got {len(sampling_params_list)}") + if len(sampling_params_list) != len(self.resources.stage_list): + raise ValueError( + f"Expected {len(self.resources.stage_list)} sampling params, got {len(sampling_params_list)}" + ) # Orchestrator keeps stage objects for input derivation - num_stages = len(self.stage_list) + num_stages = len(self.resources.stage_list) # Track per-request start time for end-to-end timing _req_start_ts: dict[int, float] = {} _wall_start_ts: float = time.time() @@ -380,7 +331,7 @@ async def generate( # Determine the final stage for E2E stats (highest stage_id with # final_output=True; fallback to last stage) final_stage_id_for_e2e = get_final_stage_id_for_e2e( - output_modalities, self.output_modalities, self.stage_list + output_modalities, self.output_modalities, self.resources.stage_list ) # Metrics/aggregation helper @@ -412,7 +363,7 @@ async def generate( "engine_inputs": prompt, "sampling_params": sp0, } - self.stage_list[0].submit(task) + self.resources.stage_list[0].submit(task) # Submit CFG companion requests to stage-0 if cfg.is_active: @@ -423,7 +374,7 @@ async def generate( "engine_inputs": companion_prompt, "sampling_params": cfg.stage0_sampling_params, } - self.stage_list[0].submit(companion_task) + self.resources.stage_list[0].submit(companion_task) metrics.stage_first_ts[0] = metrics.stage_first_ts[0] or time.time() _req_start_ts[request_id] = time.time() @@ -518,7 +469,7 @@ async def _generate_inline( loop = asyncio.get_running_loop() results = await loop.run_in_executor( None, - self._inline_engine.generate, + self.resources._inline_engine.generate, prompt, sp0, [request_id], @@ -584,7 +535,7 @@ async def _process_async_results( _last_progress_ts = time.time() while not all(all_stages_finished.values()): _loop_iter += 1 - for stage_id, stage in enumerate(self.stage_list[: final_stage_id_for_e2e + 1]): + for stage_id, stage in enumerate(self.resources.stage_list[: final_stage_id_for_e2e + 1]): if all_stages_finished[stage_id]: continue try: @@ -615,13 +566,13 @@ async def _process_async_results( engine_input.pop(_mm_key, None) if engine_input.get("type") == "multimodal": engine_input["type"] = "token" - for i in range(1, len(self.stage_list)): + for i in range(1, len(self.resources.stage_list)): task = { "request_id": request_id, "engine_inputs": engine_input, "sampling_params": sampling_params_list[i], } - self.stage_list[i].submit(task) + self.resources.stage_list[i].submit(task) metrics.stage_first_ts[i] = time.time() all_stages_finished[stage_id] = finished @@ -638,7 +589,7 @@ async def _process_sequential_results( prompt: Any, cfg: CfgCompanionTracker | None = None, ) -> AsyncGenerator[OmniRequestOutput, None]: - for stage_id, stage in enumerate(self.stage_list[: final_stage_id_for_e2e + 1]): + for stage_id, stage in enumerate(self.resources.stage_list[: final_stage_id_for_e2e + 1]): cfg_stage0 = stage_id == 0 and cfg is not None and cfg.is_active finished = False @@ -673,10 +624,10 @@ async def _process_sequential_results( # Forward to next stage if there is one next_stage_id = stage_id + 1 if next_stage_id <= final_stage_id_for_e2e: - next_stage: OmniStage = self.stage_list[next_stage_id] + next_stage: OmniStage = self.resources.stage_list[next_stage_id] # Derive inputs for the next stage, record postprocess time with metrics.stage_postprocess_timer(stage_id, request_id): - next_inputs = next_stage.process_engine_inputs(self.stage_list, prompt) + next_inputs = next_stage.process_engine_inputs(self.resources.stage_list, prompt) sp_next: SamplingParams = sampling_params_list[next_stage_id] if cfg is not None and cfg.is_active and not cfg.is_parent_failed(request_id): @@ -703,7 +654,7 @@ async def _process_sequential_results( next_inputs=next_inputs, sampling_params=sp_next, original_prompt=prompt, - next_stage_queue_submit_fn=self.stage_list[next_stage_id].submit, + next_stage_queue_submit_fn=self.resources.stage_list[next_stage_id].submit, metrics=metrics, ) @@ -802,10 +753,10 @@ def _process_single_result( return engine_outputs, finished, output_to_yield def _run_output_handler(self) -> None: - if self.output_handler is not None: + if self.resources.output_handler is not None: return - stage_list = self.stage_list + stage_list = self.resources.stage_list request_states = self.request_states companion_to_parent = self._companion_to_parent @@ -865,15 +816,15 @@ async def output_handler(): else: await req_state.queue.put(error_msg) error_msg = {"request_id": req_state.request_id, "error": str(e)} - self.output_handler = None # Make possible for restart + self.resources.output_handler = None # Make possible for restart - self.output_handler = asyncio.create_task(output_handler()) + self.resources.output_handler = asyncio.create_task(output_handler()) @property def is_running(self) -> bool: if self._inline_diffusion: - return self._inline_engine is not None - return len(self._stage_in_queues) > 0 + return self.resources._inline_engine is not None + return len(self.resources._stage_in_queues) > 0 @property def is_stopped(self) -> bool: @@ -897,16 +848,16 @@ def dead_error(self) -> BaseException: async def abort(self, request_id: str | Iterable[str]) -> None: if self._inline_diffusion: - if self._inline_engine is not None: - self._inline_engine.engine.abort(request_id) + if self.resources._inline_engine is not None: + self.resources._inline_engine.engine.abort(request_id) return None abort_task = {"type": OmniStageTaskType.ABORT, "request_id": request_id} - for stage in self.stage_list: + for stage in self.resources.stage_list: stage.submit(abort_task) return None async def get_vllm_config(self) -> VllmConfig: - for stage in self.stage_list: + for stage in self.resources.stage_list: if stage.is_comprehension: # Use the vllm_config received from worker process if stage.vllm_config is not None: @@ -914,7 +865,7 @@ async def get_vllm_config(self) -> VllmConfig: return None async def get_model_config(self) -> OmniModelConfig: - for stage in self.stage_list: + for stage in self.resources.stage_list: if stage.is_comprehension: # Use the vllm_config received from worker process if stage.vllm_config is not None: @@ -925,13 +876,13 @@ async def get_input_preprocessor(self) -> InputPreprocessor: return None async def get_tokenizer(self) -> TokenizerLike: - for stage in self.stage_list: + for stage in self.resources.stage_list: if stage.is_comprehension: return stage.tokenizer return None async def is_tracing_enabled(self) -> bool: - for stage in self.stage_list: + for stage in self.resources.stage_list: if stage.is_comprehension: return stage.is_tracing_enabled return False diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py index 488b986d8ee..71149cf1d73 100644 --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio import json import multiprocessing as mp import os @@ -9,6 +10,7 @@ import weakref from collections.abc import Callable, Generator, Sequence from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass, field from typing import Any, Literal, TypeVar, overload import huggingface_hub @@ -37,6 +39,7 @@ try_close_ray, ) from vllm_omni.entrypoints.cfg_companion_tracker import CfgCompanionTracker +from vllm_omni.entrypoints.omni_diffusion import OmniDiffusion from vllm_omni.entrypoints.omni_stage import OmniStage from vllm_omni.entrypoints.stage_utils import SHUTDOWN_TASK, OmniStageTaskType from vllm_omni.entrypoints.stage_utils import maybe_load_from_ipc as _load @@ -60,52 +63,6 @@ logger = init_logger(__name__) -def _weak_close_cleanup( - stage_list, - stage_in_queues, - stage_out_queues, - ray_pg, - zmq_ctx=None, - handshake_stop: threading.Event | None = None, - zmq_handshake_socket: zmq.Socket | None = None, - handshake_thread: threading.Thread | None = None, -): - """Weak reference cleanup function for OmniBase instances.""" - if stage_list: - for q in stage_in_queues: - try: - q.put_nowait(SHUTDOWN_TASK) - except Exception as e: - logger.warning(f"Failed to send shutdown signal to stage input queue: {e}") - close_fn = getattr(q, "close", None) - if callable(close_fn): - close_fn() - for q in stage_out_queues: - close_fn = getattr(q, "close", None) - if callable(close_fn): - close_fn() - for stage in stage_list: - try: - stage.stop_stage_worker() - except Exception as e: - logger.warning(f"Failed to stop stage worker: {e}") - try_close_ray(ray_pg) - - # Gracefully shutdown handshake server thread - if handshake_stop is not None: - handshake_stop.set() - if handshake_thread is not None: - handshake_thread.join(timeout=2.0) - if handshake_thread.is_alive(): - logger.warning("Handshake server thread did not terminate gracefully within timeout") - - # Close ZMQ resources after thread has exited - if zmq_handshake_socket is not None: - zmq_handshake_socket.close(0) - if zmq_ctx is not None: - zmq_ctx.term() - - def _dummy_snapshot_download(model_id): return model_id @@ -138,6 +95,74 @@ def omni_snapshot_download(model_id) -> str: return model_id +@dataclass +class BackgroundResources: + stage_list: list[OmniStage] = field(default_factory=list) + _stage_in_queues: list[Any] = field(default_factory=list) + _stage_out_queues: list[Any] = field(default_factory=list) + _ray_pg = None + _zmq_ctx: zmq.Context | None = None + _zmq_handshake_socket: zmq.Socket | zmq.asyncio.Socket | None = None # type: ignore[name-defined] + _handshake_stop: threading.Event | None = None + _handshake_thread: threading.Thread | None = None + _inline_engine: OmniDiffusion | None = None + + # Async Omni Resources + output_handler: asyncio.Task | None = None + + # Set if any stages are dead + stage_dead: bool = False + + def __call__(self): + """Cleanup all background resources""" + self.stage_dead = True + + if self._inline_engine is not None: + try: + self._inline_engine.close() + except Exception as e: + logger.warning("Failed to close inline diffusion engine: %s", e) + + for q in self._stage_in_queues: + try: + q.put_nowait(SHUTDOWN_TASK) + except Exception as e: + logger.warning("Failed to send shutdown signal to stage input queue: ", exc_info=e) + close_fn = getattr(q, "close", None) + if callable(close_fn): + close_fn() + + for q in self._stage_out_queues: + close_fn = getattr(q, "close", None) + if callable(close_fn): + close_fn() + + for stage in self.stage_list: + try: + stage.stop_stage_worker() + except Exception as e: + logger.warning("Failed to stop stage worker: ", exc_info=e) + + try_close_ray(self._ray_pg) + + # Gracefully shutdown handshake server thread + if self._handshake_stop is not None: + self._handshake_stop.set() + if self._handshake_thread is not None: + self._handshake_thread.join(timeout=2.0) + if self._handshake_thread.is_alive(): + logger.warning("Handshake server thread did not terminate gracefully within timeout") + + if self.output_handler is not None: + self.output_handler.cancel() + + # Close ZMQ resources after thread has exited + if self._zmq_handshake_socket is not None: + self._zmq_handshake_socket.close(0) + if self._zmq_ctx is not None: + self._zmq_ctx.term() + + class OmniBase: """Base class for serving Omni models. @@ -165,19 +190,11 @@ def __init__(self, model: str, **kwargs: Any) -> None: kwargs["model"] = model # Stage management attributes - self.stage_list: list[OmniStage] = [] - self._stage_in_queues: list[Any] = [] - self._stage_out_queues: list[Any] = [] self._stages_ready: set[int] = set() - self._ray_pg = None self._queue_cls = None self._ctx = None - self._zmq_ctx: zmq.Context | None = None self._zmq_master_address: str | None = None self._zmq_master_port: int | None = None - self._zmq_handshake_socket: zmq.Socket | None = None - self._handshake_thread: threading.Thread | None = None - self._handshake_stop: threading.Event | None = None self._handshake_endpoints: dict[int, tuple[str, str]] = {} self._handshake_seen: set[int] = set() # Track which stage IDs have completed ZMQ handshake self._single_stage_id: int | None = None # Optional: deploy only a specific stage ID @@ -189,11 +206,29 @@ def __init__(self, model: str, **kwargs: Any) -> None: # Used by collective_rpc to retrieve results collected from the output queue self._rpc_results: dict[int, dict[str, dict[str, Any]]] = {} + # Set up cleanup finalizer + self.resources = BackgroundResources() + self._weak_finalizer = weakref.finalize(self, self.resources) + # Initialize stages - each stage will create appropriate instance based on stage_type # Stage workers will automatically create OmniLLM or OmniDiffusion instances # based on stage_type in YAML config (handled in omni_stage.py) - logger.info(f"Initializing stages for model: {model}") - self._initialize_stages(model, kwargs) + success = False + try: + logger.info(f"Initializing stages for model: {model}") + self._initialize_stages(model, kwargs) + success = True + finally: + if not success: + self._weak_finalizer() + + @property + def stage_list(self) -> list[OmniStage]: + return self.resources.stage_list + + @property + def output_handler(self) -> asyncio.Task | None: + return self.resources.output_handler def _get_default_cache_config(self, cache_backend: str | None) -> dict[str, Any] | None: if cache_backend == "cache_dit": @@ -306,7 +341,6 @@ def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[st def _initialize_stages(self, model: str, kwargs: dict[str, Any]) -> None: """Initialize stage list management.""" self._inline_diffusion = False - self._inline_engine = None stage_init_timeout = kwargs.get("stage_init_timeout", 20) shm_threshold_bytes = kwargs.get("shm_threshold_bytes", 65536) @@ -350,10 +384,10 @@ def _build_stage(idx_cfg: tuple[int, Any]) -> tuple[int, OmniStage]: for fut in as_completed(futures): results.append(fut.result()) results.sort(key=lambda x: x[0]) - self.stage_list = [st for _, st in results] - self.default_sampling_params_list = [st.default_sampling_params for st in self.stage_list] - self.output_modalities = [st.final_output_type for st in self.stage_list] - logger.info(f"[{self._name}] Loaded {len(self.stage_list)} stages") + self.resources.stage_list = [st for _, st in results] + self.default_sampling_params_list = [st.default_sampling_params for st in self.resources.stage_list] + self.output_modalities = [st.final_output_type for st in self.resources.stage_list] + logger.info(f"[{self._name}] Loaded {len(self.resources.stage_list)} stages") # Phase 1 optimization: for a single diffusion stage in async mode, # run the engine directly in the orchestrator process to eliminate @@ -418,7 +452,7 @@ def _init_inline_diffusion_engine( cfg_kv_collect_func = load_func_from_config(getattr(stage_config, "cfg_kv_collect_func", None)) - self._inline_engine = OmniDiffusion( + self.resources._inline_engine = OmniDiffusion( model=model, stage_id=stage_id, engine_input_source=getattr(stage_config, "engine_input_source", []), @@ -448,13 +482,13 @@ def _start_stages(self, model: str) -> None: """Start all stage processes.""" if self.worker_backend == "ray": # Initialize Ray Cluster - self._ray_pg = create_placement_group( - number_of_stages=len(self.stage_list), address=self.ray_address, strategy="PACK" + self.resources._ray_pg = create_placement_group( + number_of_stages=len(self.resources.stage_list), address=self.ray_address, strategy="PACK" ) else: # Initialize ZMQ context - if self._zmq_ctx is None: - self._zmq_ctx = zmq.Context() + if self.resources._zmq_ctx is None: + self.resources._zmq_ctx = zmq.Context() # Allocate endpoints for each stage total_stages = len(self.stage_configs) @@ -474,17 +508,17 @@ def _start_stages(self, model: str) -> None: # Start handshake server self.start_handshake_server() - for stage_id, stage in enumerate[OmniStage](self.stage_list): + for stage_id, stage in enumerate[OmniStage](self.resources.stage_list): if self.worker_backend == "ray": in_q = self._queue_cls() out_q = self._queue_cls() else: in_endpoint, out_endpoint = self._handshake_endpoints[stage_id] - in_q = ZmqQueue(self._zmq_ctx, zmq.PUSH, bind=in_endpoint) - out_q = ZmqQueue(self._zmq_ctx, zmq.PULL, bind=out_endpoint) + in_q = ZmqQueue(self.resources._zmq_ctx, zmq.PUSH, bind=in_endpoint) + out_q = ZmqQueue(self.resources._zmq_ctx, zmq.PULL, bind=out_endpoint) - self._stage_in_queues.append(in_q) - self._stage_out_queues.append(out_q) + self.resources._stage_in_queues.append(in_q) + self.resources._stage_out_queues.append(out_q) stage.attach_queues(in_q, out_q) stage_connectors_config = get_stage_connector_config( @@ -518,7 +552,7 @@ def _start_stages(self, model: str) -> None: batch_timeout=self.batch_timeout, connectors_config=stage_connectors_config, worker_backend=self.worker_backend, - ray_placement_group=self._ray_pg, + ray_placement_group=self.resources._ray_pg, ignore_runtime_config=True if self._single_stage_id is not None else False, ) @@ -533,14 +567,14 @@ def _wait_for_stages_ready(self, timeout: int = 120) -> None: if self._single_stage_id is not None and self.worker_backend != "ray": timeout = self._wait_for_handshakes(timeout) - num_stages = len(self.stage_list) + num_stages = len(self.resources.stage_list) deadline = time.time() + max(0, int(timeout)) logger.info(f"[{self._name}] Waiting for {num_stages} stages to initialize (timeout: {timeout}s)") while len(self._stages_ready) < num_stages and time.time() < deadline: progressed = False - for stage_id, stage in enumerate(self.stage_list): + for stage_id, stage in enumerate(self.resources.stage_list): if stage_id in self._stages_ready: continue @@ -566,7 +600,8 @@ def _wait_for_stages_ready(self, timeout: int = 120) -> None: ) suggestions = [ - f"Ignore this warning if the model weight download / load from disk time is longer than {timeout}s.", + f"If model weight download / load from disk takes longer than {timeout}s, " + "please increase initialization wait time (stage_init_timeout or call-site timeout)", "Verify GPU/device assignment in config (runtime.devices) is correct.", "Check GPU/host memory availability; reduce model or batch size if needed.", "Check model weights path and network reachability (if loading remotely).", @@ -575,7 +610,11 @@ def _wait_for_stages_ready(self, timeout: int = 120) -> None: formatted_suggestions = "\n".join(f" {i + 1}) {msg}" for i, msg in enumerate(suggestions)) - logger.warning(f"[{self._name}] Stage initialization timeout. Troubleshooting Steps:\n{formatted_suggestions}") + logger.error(f"[{self._name}] Stage initialization timeout. Troubleshooting Steps:\n{formatted_suggestions}") + raise TimeoutError( + f"{self._name}: {len(self._stages_ready)}/{num_stages} stages ready after {timeout}s. " + f"Missing stages: {not_ready}" + ) def _is_profiler_enabled(self, stage_id: int) -> bool: """Check if profiler config is set for a given stage.""" @@ -724,10 +763,10 @@ def start_profile(self, stages: list[int] | None = None) -> None: >>> omni.start_profile(stages=[0, 2]) """ if stages is None: - stages = list(range(len(self.stage_list))) + stages = list(range(len(self.resources.stage_list))) for stage_id in stages: - if stage_id < len(self.stage_list): + if stage_id < len(self.resources.stage_list): if not self._is_profiler_enabled(stage_id): logger.info( "[%s] Skipping start_profile for stage-%s: profiler config not set", @@ -736,7 +775,7 @@ def start_profile(self, stages: list[int] | None = None) -> None: ) continue try: - self.stage_list[stage_id].submit({"type": OmniStageTaskType.PROFILER_START}) + self.resources.stage_list[stage_id].submit({"type": OmniStageTaskType.PROFILER_START}) logger.info("[%s] Sent start_profile to stage-%s", self._name, stage_id) except Exception as e: logger.warning( @@ -752,12 +791,12 @@ def stop_profile(self, stages: list[int] | None = None) -> dict: the file paths for traces and tables. """ if stages is None: - stages = list(range(len(self.stage_list))) + stages = list(range(len(self.resources.stage_list))) all_results = {"traces": [], "tables": []} for stage_id in stages: - if stage_id < len(self.stage_list): + if stage_id < len(self.resources.stage_list): if not self._is_profiler_enabled(stage_id): logger.info( "[%s] Skipping stop_profile for stage-%s: profiler config not set", @@ -864,20 +903,22 @@ def _process_handshake_message(self, msg: Any) -> dict[str, Any]: def _run_handshake_server_loop(self) -> None: """Main loop for handshake server - polls for messages and responds.""" poller = zmq.Poller() - poller.register(self._zmq_handshake_socket, zmq.POLLIN) + poller.register(self.resources._zmq_handshake_socket, zmq.POLLIN) try: - while not self._handshake_stop.is_set(): + while not self.resources._handshake_stop.is_set(): events = poller.poll(1000) - has_message = any(sock == self._zmq_handshake_socket and event == zmq.POLLIN for sock, event in events) + has_message = any( + sock == self.resources._zmq_handshake_socket and event == zmq.POLLIN for sock, event in events + ) if not has_message: continue - msg = msgspec.msgpack.decode(self._zmq_handshake_socket.recv()) + msg = msgspec.msgpack.decode(self.resources._zmq_handshake_socket.recv()) response = msgspec.msgpack.encode(self._process_handshake_message(msg)) - self._zmq_handshake_socket.send(response) + self.resources._zmq_handshake_socket.send(response) finally: - poller.unregister(self._zmq_handshake_socket) + poller.unregister(self.resources._zmq_handshake_socket) def start_handshake_server(self) -> None: """Start the ZMQ handshake server. @@ -887,7 +928,7 @@ def start_handshake_server(self) -> None: Skips starting if the server is already running or ZMQ is not initialized. """ # Skip if already running or ZMQ not initialized - if self._handshake_thread is not None or self._zmq_ctx is None: + if self.resources._handshake_thread is not None or self.resources._zmq_ctx is None: return # Skip if master address/port not configured @@ -899,14 +940,16 @@ def start_handshake_server(self) -> None: local_only=False, host=self._zmq_master_address, port=int(self._zmq_master_port) ) - self._handshake_stop = threading.Event() - self._zmq_handshake_socket = make_zmq_socket(self._zmq_ctx, endpoint, zmq.REP, bind=True, linger=5000) + self.resources._handshake_stop = threading.Event() + self.resources._zmq_handshake_socket = make_zmq_socket( + self.resources._zmq_ctx, endpoint, zmq.REP, bind=True, linger=5000 + ) # Start server thread - self._handshake_thread = threading.Thread( + self.resources._handshake_thread = threading.Thread( target=self._run_handshake_server_loop, daemon=True, name="zmq-handshake-server" ) - self._handshake_thread.start() + self.resources._handshake_thread.start() def _wait_for_handshakes(self, timeout: int = 120) -> int: """Wait for handshakes from all expected stages. @@ -983,23 +1026,6 @@ class Omni(OmniBase): >>> print(outputs) """ - def __init__(self, model: str, **kwargs: Any) -> None: - super().__init__(model, **kwargs) - - # Register weak reference cleanup (called on garbage collection) - self._weak_finalizer = weakref.finalize( - self, - _weak_close_cleanup, - self.stage_list, - self._stage_in_queues, - self._stage_out_queues, - self._ray_pg, - self._zmq_ctx, - self._handshake_stop, - self._zmq_handshake_socket, - self._handshake_thread, - ) - @overload def generate( self, @@ -1099,10 +1125,12 @@ def _run_generation( if sampling_params_list is None: raise ValueError("sampling_params_list is required for pipelined generation") - if len(sampling_params_list) != len(self.stage_list): - raise ValueError(f"Expected {len(self.stage_list)} sampling params, got {len(sampling_params_list)}") + if len(sampling_params_list) != len(self.resources.stage_list): + raise ValueError( + f"Expected {len(self.resources.stage_list)} sampling params, got {len(sampling_params_list)}" + ) - for i, (stage, sp) in enumerate(zip(self.stage_list, sampling_params_list)): + for i, (stage, sp) in enumerate(zip(self.resources.stage_list, sampling_params_list)): ExpectedSPType = OmniDiffusionSamplingParams if stage.stage_type == "diffusion" else SamplingParams if not isinstance(sp, ExpectedSPType): raise ValueError( @@ -1117,7 +1145,7 @@ def _run_generation( request_prompts = list(prompts) # Orchestrator keeps stage objects for input derivation - num_stages = len(self.stage_list) + num_stages = len(self.resources.stage_list) # Generate globally unique request IDs and map them to original prompts request_ids = [f"{i}_{uuid.uuid4()}" for i in range(len(request_prompts))] @@ -1142,7 +1170,7 @@ def _run_generation( else: prompt_modalities = None final_stage_id_for_e2e = get_final_stage_id_for_e2e( - prompt_modalities, self.output_modalities, self.stage_list + prompt_modalities, self.output_modalities, self.resources.stage_list ) final_stage_id_to_prompt[rid] = final_stage_id_for_e2e @@ -1171,7 +1199,7 @@ def _run_generation( "engine_inputs": prompt, "sampling_params": sp0, } - self.stage_list[0].submit(task) + self.resources.stage_list[0].submit(task) _req_start_ts[req_id] = time.time() logger.debug(f"[{self._name}] Enqueued request {req_id} to stage-0") @@ -1207,7 +1235,7 @@ def _run_generation( ) while completed_requests < total_requests: made_progress = False - for stage_id, stage in enumerate(self.stage_list): + for stage_id, stage in enumerate(self.resources.stage_list): result = stage.try_collect() if result is None: continue @@ -1362,7 +1390,7 @@ def _run_generation( success = cfg.forward_parent_with_cfg( req_id, {"engine_outputs": engine_outputs, "stage_id": stage_id}, - self.stage_list, + self.resources.stage_list, self.connectors, sampling_params_list, request_id_to_prompt, @@ -1381,12 +1409,12 @@ def _run_generation( cfg.defer_parent(req_id, engine_outputs, stage_id) continue - next_stage: OmniStage = self.stage_list[next_stage_id] + next_stage: OmniStage = self.resources.stage_list[next_stage_id] try: # Derive inputs for the next stage, record preprocess time with metrics.stage_postprocess_timer(stage_id, req_id): next_inputs = next_stage.process_engine_inputs( - self.stage_list, [request_id_to_prompt[req_id]] + self.resources.stage_list, [request_id_to_prompt[req_id]] ) except Exception as e: completed_requests += 1 @@ -1410,7 +1438,7 @@ def _run_generation( next_inputs=next_inputs, sampling_params=sp_next, original_prompt=request_id_to_prompt[req_id], - next_stage_queue_submit_fn=self.stage_list[next_stage_id].submit, + next_stage_queue_submit_fn=self.resources.stage_list[next_stage_id].submit, metrics=metrics, )