diff --git a/tests/conftest.py b/tests/conftest.py index c3ff1024607..0b6e0a399d1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -951,15 +951,16 @@ def _start_server(self) -> None: max_wait = 1200 # 20 minutes start_time = time.time() while time.time() - start_time < max_wait: - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.settimeout(1) - result = sock.connect_ex((self.host, self.port)) - if result == 0: - print(f"Server ready on {self.host}:{self.port}") - return - except Exception: - pass + # Check for process status + ret = self.proc.poll() + if ret is not None: + raise RuntimeError(f"Server processes exited with code {ret} before becoming ready.") + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.settimeout(1) + result = sock.connect_ex((self.host, self.port)) + if result == 0: + print(f"Server ready on {self.host}:{self.port}") + return time.sleep(2) raise RuntimeError(f"Server failed to start within {max_wait} seconds") diff --git a/vllm_omni/distributed/ray_utils/utils.py b/vllm_omni/distributed/ray_utils/utils.py index 07513b6601e..ec65cc9fed5 100644 --- a/vllm_omni/distributed/ray_utils/utils.py +++ b/vllm_omni/distributed/ray_utils/utils.py @@ -176,6 +176,15 @@ def run(self, func, *args, **kwargs): runtime_env={"env_vars": {"PYTHONPATH": os.environ.get("PYTHONPATH", "")}, "CUDA_LAUNCH_BLOCKING": "1"}, ).remote() - worker_actor.run.remote(worker_entry_fn, *args, **kwargs) + task_ref = worker_actor.run.remote(worker_entry_fn, *args, **kwargs) - return worker_actor + return worker_actor, task_ref + + +def is_ray_task_alive(task_ref: Any, **kwargs): + """Checks ray task status. Returns FALSE if ray task has exited for any reason.""" + if not RAY_AVAILABLE: + raise ImportError("ray is required to query ray tasks") + + ready, _ = ray.wait([task_ref], **kwargs) + return not bool(ready) diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py index 59be966c335..818a55758f5 100644 --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -487,7 +487,9 @@ def _wait_for_stages_ready(self, timeout: int = 120) -> None: formatted_suggestions = "\n".join(f" {i + 1}) {msg}" for i, msg in enumerate(suggestions)) - logger.warning(f"[{self._name}] Stage initialization timeout. Troubleshooting Steps:\n{formatted_suggestions}") + logger.error(f"[{self._name}] Stage initialization timeout. Troubleshooting Steps:\n{formatted_suggestions}") + self.close() + raise TimeoutError def start_profile(self, stages: list[int] | None = None) -> None: """Start profiling for specified stages. diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py index bfd60f61920..a8f951bb927 100644 --- a/vllm_omni/entrypoints/omni_stage.py +++ b/vllm_omni/entrypoints/omni_stage.py @@ -35,7 +35,7 @@ from vllm_omni.distributed.omni_connectors import build_stage_connectors from vllm_omni.distributed.omni_connectors.adapter import try_recv_via_connector from vllm_omni.distributed.omni_connectors.connectors.base import OmniConnectorBase -from vllm_omni.distributed.ray_utils.utils import kill_ray_actor, start_ray_actor +from vllm_omni.distributed.ray_utils.utils import is_ray_task_alive, kill_ray_actor, start_ray_actor from vllm_omni.engine.arg_utils import AsyncOmniEngineArgs, OmniEngineArgs from vllm_omni.entrypoints.async_omni_diffusion import AsyncOmniDiffusion from vllm_omni.entrypoints.async_omni_llm import AsyncOmniLLM @@ -301,6 +301,8 @@ def __init__(self, stage_config: Any, stage_init_timeout: int = 300): self._in_q: mp.queues.Queue | ZmqQueue | str | None = None self._out_q: mp.queues.Queue | ZmqQueue | str | None = None self._proc: mp.Process | None = None + self._ray_actor: Any | None = None + self._ray_task_ref: Any | None = None self._shm_threshold_bytes: int = 65536 self._stage_init_timeout: int = stage_init_timeout @@ -472,7 +474,7 @@ def init_stage_worker( os.environ["VLLM_LOGGING_PREFIX"] = new_env if worker_backend == "ray": if is_async: - self._ray_actor = start_ray_actor( + self._ray_actor, self._ray_task_ref = start_ray_actor( _stage_worker_async_entry, ray_placement_group, self.stage_id, @@ -484,7 +486,7 @@ def init_stage_worker( stage_init_timeout=self._stage_init_timeout, ) else: - self._ray_actor = start_ray_actor( + self._ray_actor, self._ray_task_ref = start_ray_actor( _stage_worker, ray_placement_group, self.stage_id, @@ -547,9 +549,10 @@ def stop_stage_worker(self) -> None: if callable(close_fn): close_fn() - if hasattr(self, "_ray_actor") and self._ray_actor: + if self._ray_actor is not None: kill_ray_actor(self._ray_actor) self._ray_actor = None + self._ray_task_ref = None elif self._proc is not None: try: self._proc.join(timeout=5) @@ -609,10 +612,19 @@ def try_collect(self) -> dict[str, Any] | None: request_id, engine_outputs (or engine_outputs_shm), and metrics. """ assert self._out_q is not None + if self._proc is not None and not self._proc.is_alive(): + raise RuntimeError("OmniStage Worker process died unexpectedly") + if self._ray_task_ref is not None and not is_ray_task_alive(self._ray_task_ref, timeout=0): + raise RuntimeError("OmniStage Ray actor died unexpectedly") + try: return self._out_q.get_nowait() - except Exception: + except queue.Empty: return None + except Exception as e: + logger.error("Unexpected error when collecting OmniStage output queue:", exc_info=e) + self.stop_stage_worker() + raise def process_engine_inputs( self, stage_list: list[Any], prompt: OmniTokensPrompt | TextPrompt = None