# Patch summary (free-form preamble; everything before the first "diff --git"
# header is descriptive text, not part of either hunk).
#
# tests/engine/test_async_omni_engine_outputs.py (new file)
#   Two pytest cases, one for try_get_output and one for try_get_output_async.
#   Each builds an AsyncOmniEngine with object.__new__ (so __init__ never runs),
#   attaches a MagicMock output queue and a MagicMock orchestrator thread,
#   reads one buffered result, flips the mocked thread's is_alive() to False,
#   and expects the next poll to raise RuntimeError matching
#   "Orchestrator died unexpectedly".
#
# vllm_omni/engine/async_omni_engine.py
#   In try_get_output and try_get_output_async: when the queue read raises
#   queue.Empty and self.is_alive() is false, raise RuntimeError instead of
#   returning None. An empty queue with a live engine still returns None.
#
# NOTE(review): is_alive() is defined outside this patch; the tests presume it
#   consults self.orchestrator_thread.is_alive() -- confirm.
# NOTE(review): a result enqueued after the read reports Empty but before the
#   liveness check would still be queued when RuntimeError is raised -- confirm
#   whether a final non-blocking drain is wanted before raising.
diff --git a/tests/engine/test_async_omni_engine_outputs.py b/tests/engine/test_async_omni_engine_outputs.py
new file mode 100644
index 0000000000..ccf9e8cb6b
--- /dev/null
+++ b/tests/engine/test_async_omni_engine_outputs.py
@@ -0,0 +1,65 @@
+"""Tests for AsyncOmniEngine.try_get_output and try_get_output_async.
+
+Focuses on the critical behavior: when the orchestrator thread dies,
+subsequent attempts to collect output raise RuntimeError.
+"""
+
+import queue
+from unittest.mock import MagicMock
+
+import pytest
+
+from vllm_omni.engine.async_omni_engine import AsyncOmniEngine
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+def _make_engine(output_queue, *, thread_alive: bool = True) -> AsyncOmniEngine:
+    """Create an AsyncOmniEngine bypassing __init__."""
+    engine = object.__new__(AsyncOmniEngine)
+    engine.output_queue = output_queue
+    engine.orchestrator_thread = MagicMock(
+        is_alive=MagicMock(return_value=thread_alive),
+    )
+    return engine
+
+
+def test_try_get_output_raises_after_orchestrator_dies():
+    """Draining remaining results then hitting an empty queue with a dead
+    orchestrator must raise RuntimeError so callers know the pipeline is gone."""
+    mock_queue = MagicMock()
+    # First call succeeds; second call finds the queue empty.
+    mock_queue.sync_q.get.side_effect = [
+        {"type": "output", "request_id": "r1"},
+        queue.Empty,
+    ]
+
+    engine = _make_engine(mock_queue, thread_alive=True)
+
+    # Collect the one buffered result.
+    assert engine.try_get_output()["request_id"] == "r1"
+
+    # Orchestrator thread crashes between polls.
+    engine.orchestrator_thread.is_alive.return_value = False
+
+    with pytest.raises(RuntimeError, match="Orchestrator died unexpectedly"):
+        engine.try_get_output()
+
+
+@pytest.mark.asyncio
+async def test_try_get_output_async_raises_after_orchestrator_dies():
+    """Same scenario as above but for the async variant."""
+    mock_queue = MagicMock()
+    mock_queue.sync_q.get_nowait.side_effect = [
+        {"type": "output", "request_id": "r1"},
+        queue.Empty,
+    ]
+
+    engine = _make_engine(mock_queue, thread_alive=True)
+
+    assert (await engine.try_get_output_async())["request_id"] == "r1"
+
+    engine.orchestrator_thread.is_alive.return_value = False
+
+    with pytest.raises(RuntimeError, match="Orchestrator died unexpectedly"):
+        await engine.try_get_output_async()
diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py
index 47b97295c6..c2cc78909f 100644
--- a/vllm_omni/engine/async_omni_engine.py
+++ b/vllm_omni/engine/async_omni_engine.py
@@ -987,6 +987,8 @@ def try_get_output(self, timeout: float = 0.001) -> dict[str, Any] | None:
         try:
             return self.output_queue.sync_q.get(timeout=timeout)
         except queue.Empty:
+            if not self.is_alive():
+                raise RuntimeError("Orchestrator died unexpectedly. See logs above.")
             return None
 
     async def try_get_output_async(self) -> dict[str, Any] | None:
@@ -996,6 +998,8 @@ async def try_get_output_async(self) -> dict[str, Any] | None:
         try:
             return self.output_queue.sync_q.get_nowait()
         except queue.Empty:
+            if not self.is_alive():
+                raise RuntimeError("Orchestrator died unexpectedly. See logs above.")
             return None
 
     def get_stage_metadata(self, stage_id: int) -> dict[str, Any]: