diff --git a/tests/entrypoints/openai_api/test_serving_speech.py b/tests/entrypoints/openai_api/test_serving_speech.py index b388b18606b..b78d62d9eda 100644 --- a/tests/entrypoints/openai_api/test_serving_speech.py +++ b/tests/entrypoints/openai_api/test_serving_speech.py @@ -18,6 +18,7 @@ from pytest_mock import MockerFixture from vllm.entrypoints.openai.engine.protocol import ErrorInfo, ErrorResponse +from vllm_omni.entrypoints.omni_base import OmniEngineDeadError from vllm_omni.entrypoints.openai import api_server as api_server_module from vllm_omni.entrypoints.openai.audio_utils_mixin import AudioMixin from vllm_omni.entrypoints.openai.protocol.audio import ( @@ -1707,6 +1708,77 @@ def test_api_server_create_speech_wraps_error_response_status(mocker: MockerFixt assert response.status_code == 400 +def test_api_server_create_speech_engine_error_response_includes_request_and_stage_id(mocker: MockerFixture): + handler = mocker.MagicMock() + handler.create_speech = mocker.AsyncMock( + side_effect=OmniEngineDeadError( + "engine dead", + error_stage_id=1, + ) + ) + + terminate_mock = mocker.patch.object(api_server_module, "terminate_if_errored") + + app = FastAPI() + app.state.args = SimpleNamespace(log_error_stack=False) + app.state.openai_serving_speech = handler + app.state.engine_client = SimpleNamespace( + engine=SimpleNamespace(is_alive=lambda: False), + errored=True, + ) + app.state.server = SimpleNamespace() + scope = { + "type": "http", + "app": app, + "method": "POST", + "path": "/v1/audio/speech", + "headers": [], + "query_string": b"", + "client": ("127.0.0.1", 12345), + "server": ("testserver", 80), + "scheme": "http", + } + raw_request = Request(scope) + raw_request.state.request_metadata = SimpleNamespace(request_id="speech-req-1") + request = OpenAICreateSpeechRequest(input="Hello") + + response = asyncio.run(api_server_module.create_speech(request, raw_request)) + + assert isinstance(response, JSONResponse) + assert response.status_code == 500 + assert response.body.decode("utf-8") == ( + '{"error":{"message":"engine dead","type":"InternalServerError","param":null,' + '"code":500,"request_id":"speech-req-1","error_stage_id":1}}' + ) + terminate_mock.assert_called_once() + + +def test_omni_engine_error_handler_includes_request_and_stage_id(mocker: MockerFixture): + app = FastAPI() + app.state.args = SimpleNamespace(log_error_stack=False) + app.state.engine_client = SimpleNamespace( + engine=SimpleNamespace(is_alive=lambda: False), + errored=True, + ) + app.state.server = SimpleNamespace() + + terminate_mock = mocker.patch.object(api_server_module, "terminate_if_errored") + api_server_module._register_omni_exception_handlers(app) + + @app.get("/boom") + async def boom(request: Request): + request.state.request_metadata = SimpleNamespace(request_id="speech-req-1") + exc = OmniEngineDeadError("engine dead", error_stage_id=1) + raise exc + + response = TestClient(app).get("/boom") + + assert response.status_code == 500 + assert response.json()["error"]["request_id"] == "speech-req-1" + assert response.json()["error"]["error_stage_id"] == 1 + terminate_mock.assert_called_once() + + class TestWAVHeaderGeneration: """Unit tests for WAV header generation with placeholder values.""" diff --git a/tests/entrypoints/test_omni_entrypoints.py b/tests/entrypoints/test_omni_entrypoints.py index 289cf2673fb..341bf6bc09c 100644 --- a/tests/entrypoints/test_omni_entrypoints.py +++ b/tests/entrypoints/test_omni_entrypoints.py @@ -14,6 +14,7 @@ from vllm_omni.entrypoints.async_omni import AsyncOmni from vllm_omni.entrypoints.omni import Omni +from vllm_omni.entrypoints.omni_base import OmniEngineDeadError pytestmark = [pytest.mark.core_model, pytest.mark.cpu] @@ -335,6 +336,18 @@ def _enqueue_error_message(engine: FakeAsyncOmniEngine, msg: dict[str, Any]) -> ) +def _enqueue_fatal_error_message(engine: FakeAsyncOmniEngine, msg: dict[str, Any]) -> None: + engine.output_q.put_nowait( + { + "type": "error", + "fatal": True, + "request_id": msg["request_id"], + "stage_id": 2, + "error": "engine dead", + } + ) + + @pytest.mark.asyncio async def test_get_supported_tasks_returns_engine_supported_tasks(): omni = object.__new__(AsyncOmni) @@ -546,18 +559,22 @@ async def test_async_omni_abort_forwards_to_engine(monkeypatch: pytest.MonkeyPat @pytest.mark.asyncio -async def test_async_omni_propagates_engine_error(monkeypatch: pytest.MonkeyPatch): - engine = FakeAsyncOmniEngine(stage_metadata=THREE_STAGE_META, on_add_request=_enqueue_error_message) +async def test_async_omni_propagates_fatal_error_context(monkeypatch: pytest.MonkeyPatch): + engine = FakeAsyncOmniEngine(stage_metadata=THREE_STAGE_META, on_add_request=_enqueue_fatal_error_message) _patch_engine(monkeypatch, engine) app = AsyncOmni("dummy-model") try: - with pytest.raises(RuntimeError, match="engine boom"): + with pytest.raises(EngineDeadError, match="engine dead") as exc_info: async for _ in app.generate(prompt="hello", request_id="req-1"): pass finally: app.shutdown() + assert isinstance(exc_info.value, OmniEngineDeadError) + assert str(exc_info.value) == "engine dead" + assert getattr(exc_info.value, "error_stage_id") == 2 + def test_omni_generate_py_generator_yields_final_outputs_for_each_request(monkeypatch: pytest.MonkeyPatch): sampling_params = [SamplingParams(max_tokens=8) for _ in range(3)] diff --git a/vllm_omni/engine/orchestrator.py b/vllm_omni/engine/orchestrator.py index ca60c753125..ff5117934ac 100644 --- a/vllm_omni/engine/orchestrator.py +++ b/vllm_omni/engine/orchestrator.py @@ -181,6 +181,7 @@ def __init__( self._shutdown_event = asyncio.Event() self._stages_shutdown = False self._fatal_error: str | None = None + self._fatal_error_stage_id: int | None = None async def run(self) -> None: """Main entry point for the Orchestrator event loop.""" @@ -331,6 +332,7 @@ async def _orchestration_loop(self) -> None: e, ) self._fatal_error = str(e) + self._fatal_error_stage_id = stage_id for req_id, req_state in list(self.request_states.items()): if stage_id in req_state.stage_submit_ts: await self.output_async_queue.put( @@ -339,6 +341,7 @@ async def _orchestration_loop(self) -> None: "error": str(e), "fatal": True, "request_id": req_id, + "stage_id": stage_id, } ) self.request_states.pop(req_id, None) @@ -1135,6 +1138,7 @@ async def _drain_pending_requests_on_fatal(self) -> None: "error": self._fatal_error, "fatal": True, "request_id": req_id, + "stage_id": self._fatal_error_stage_id, } ) notified.add(req_id) @@ -1150,6 +1154,7 @@ async def _drain_pending_requests_on_fatal(self) -> None: "error": self._fatal_error, "fatal": True, "request_id": req_id, + "stage_id": self._fatal_error_stage_id, } ) self.request_states.pop(req_id, None) diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index 67425ea2e65..119232ce767 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -27,7 +27,10 @@ from vllm_omni.diffusion.data import OmniACK, OmniSleepTask, OmniWakeTask from vllm_omni.entrypoints.client_request_state import ClientRequestState -from vllm_omni.entrypoints.omni_base import OmniBase +from vllm_omni.entrypoints.omni_base import ( + OmniBase, + OmniEngineDeadError, +) from vllm_omni.metrics.stats import OrchestratorAggregator as OrchestratorMetrics from vllm_omni.outputs import OmniRequestOutput from vllm_omni.platforms import current_omni_platform @@ -494,7 +497,10 @@ async def _process_orchestrator_results( stage_id = result.get("stage_id", 0) if result.get("type") == "error" and result.get("fatal"): - raise EngineDeadError(result.get("error", "")) + raise OmniEngineDeadError( + result.get("error", ""), + error_stage_id=result.get("stage_id"), + ) # Check for errors if "error" in result: @@ -580,6 +586,18 @@ async def _final_output_loop(): except asyncio.CancelledError: raise + except OmniEngineDeadError as e: + logger.error("[AsyncOmni] Engine dead: %s", e) + for req_state in list(self.request_states.values()): + error_msg = { + "type": "error", + "error": str(e), + "fatal": True, + "request_id": req_state.request_id, + } + if e.error_stage_id is not None: + error_msg["stage_id"] = e.error_stage_id + await req_state.queue.put(error_msg) except EngineDeadError as e: logger.error("[AsyncOmni] Engine dead: %s", e) for req_state in list(self.request_states.values()): @@ -884,7 +902,7 @@ def is_stopped(self) -> bool: @property def dead_error(self) -> BaseException: """EngineClient abstract property implementation.""" - return EngineDeadError() + return OmniEngineDeadError() # ==================== EngineClient Interface ==================== diff --git a/vllm_omni/entrypoints/omni_base.py b/vllm_omni/entrypoints/omni_base.py index 3180c9c80c0..d922ab0723c 100644 --- a/vllm_omni/entrypoints/omni_base.py +++ b/vllm_omni/entrypoints/omni_base.py @@ -28,6 +28,23 @@ logger = init_logger(__name__) +class OmniEngineDeadError(EngineDeadError): + _DEFAULT_MESSAGE = EngineDeadError().args[0] + error_stage_id: int | None + + def __init__( + self, + message: str | None = None, + *, + error_stage_id: int | None = None, + suppress_context: bool = False, + ) -> None: + resolved_message = message or self._DEFAULT_MESSAGE + Exception.__init__(self, resolved_message) + self.__suppress_context__ = suppress_context + self.error_stage_id = error_stage_id + + def _weak_shutdown_engine(engine: AsyncOmniEngine) -> None: """Best-effort engine cleanup for GC finalization.""" try: @@ -296,8 +313,12 @@ def _handle_output_message( if msg_type == "error": error_text = msg.get("error", "Orchestrator returned an error message") + stage_id = msg.get("stage_id") if msg.get("fatal"): - raise EngineDeadError(error_text) + raise OmniEngineDeadError( + error_text, + error_stage_id=stage_id, + ) raise RuntimeError(error_text) if msg_type != "output": @@ -352,7 +373,10 @@ def _check_engine_output_error( ) # NOTE: O(n_stages) check for every error. if self.errored: - raise EngineDeadError(error_text) + raise OmniEngineDeadError( + error_text, + error_stage_id=stage_id, + ) raise EngineGenerateError(error_text) def _process_single_result( diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 41ccbcf1d7c..4454f5bda10 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -220,33 +220,64 @@ async def omni_engine_error_handler( req: Request, exc: EngineDeadError | EngineGenerateError, ): - request_id = req.state.request_metadata.request_id if hasattr(req.state, "request_metadata") else None + request_id = _get_request_id_from_request(req) if req.app.state.args.log_error_stack: logger.exception("Engine Exception caught. Request id: %s", request_id) - engine = req.app.state.engine_client - if isinstance(exc, EngineDeadError): - # Log Omni-specific diagnostic information for dead engines. - orchestrator_alive = engine.engine.is_alive() if hasattr(engine, "engine") else "N/A" - logger.error( - "EngineDeadError: orchestrator_alive=%s, errored=%s, request_id=%s", - orchestrator_alive, - engine.errored, - request_id, - ) - - terminate_if_errored( - server=req.app.state.server, - engine=engine, - ) - err = create_error_response(exc) - return JSONResponse(err.model_dump(), status_code=err.error.code) + return _create_engine_error_json_response(req, exc) app.exception_handler(EngineGenerateError)(omni_engine_error_handler) app.exception_handler(EngineDeadError)(omni_engine_error_handler) +def _get_request_id_from_request(req: Request) -> str | None: + return req.state.request_metadata.request_id if hasattr(req.state, "request_metadata") else None + + +def _build_engine_error_payload( + exc: EngineDeadError | EngineGenerateError, + *, + request_id: str | None, +) -> tuple[dict[str, Any], int]: + err = create_error_response(exc) + payload = err.model_dump() + error_body = payload.get("error", {}) + + error_body["request_id"] = request_id + error_body["error_stage_id"] = getattr(exc, "error_stage_id", None) + + return payload, err.error.code + + +def _create_engine_error_json_response( + req: Request, + exc: EngineDeadError | EngineGenerateError, +) -> JSONResponse: + request_id = _get_request_id_from_request(req) + error_stage_id = getattr(exc, "error_stage_id", None) + engine = req.app.state.engine_client + + if isinstance(exc, EngineDeadError): + # Log Omni-specific diagnostic information for dead engines. + orchestrator_alive = engine.engine.is_alive() if hasattr(engine, "engine") else "N/A" + logger.error( + "EngineDeadError: orchestrator_alive=%s, errored=%s, request_id=%s, error_stage_id=%s", + orchestrator_alive, + engine.errored, + request_id, + error_stage_id, + ) + + terminate_if_errored( + server=req.app.state.server, + engine=engine, + ) + + payload, status_code = _build_engine_error_payload(exc, request_id=request_id) + return JSONResponse(content=payload, status_code=status_code) + + class _DiffusionServingModels: """Minimal OpenAIServingModels implementation for diffusion-only servers. @@ -930,8 +961,8 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re return base_server.create_error_response(message="The model does not support Chat Completions API") try: generator = await handler.create_chat_completion(request, raw_request) - except (EngineGenerateError, EngineDeadError): - raise # Propagate to the global Omni exception handler + except (EngineGenerateError, EngineDeadError) as exc: + return _create_engine_error_json_response(raw_request, exc) except Exception as e: logger.exception("Chat completion failed: %s", e) raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)) from e @@ -1027,8 +1058,8 @@ async def create_speech(request: OpenAICreateSpeechRequest, raw_request: Request status_code=result.error.code if result.error else 400, ) return result - except (EngineGenerateError, EngineDeadError): - raise # Propagate to the global Omni exception handler + except (EngineGenerateError, EngineDeadError) as exc: + return _create_engine_error_json_response(raw_request, exc) except Exception as e: raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)) from e @@ -1063,8 +1094,8 @@ async def create_speech_batch(request: BatchSpeechRequest, raw_request: Request) status_code=result.error.code if result.error else 400, ) return JSONResponse(content=result.model_dump()) - except (EngineGenerateError, EngineDeadError): - raise # Propagate to the global Omni exception handler + except (EngineGenerateError, EngineDeadError) as exc: + return _create_engine_error_json_response(raw_request, exc) except ValueError as e: raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value, detail=str(e)) from e except Exception as e: @@ -1522,8 +1553,8 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request) data=image_data, ) - except (EngineGenerateError, EngineDeadError): - raise # Propagate to the global Omni exception handler + except (EngineGenerateError, EngineDeadError) as exc: + return _create_engine_error_json_response(raw_request, exc) except HTTPException: raise except ValueError as e: @@ -1759,8 +1790,8 @@ async def edit_images( size=size_str, ) - except (EngineGenerateError, EngineDeadError): - raise # Propagate to the global Omni exception handler + except (EngineGenerateError, EngineDeadError) as exc: + return _create_engine_error_json_response(raw_request, exc) except HTTPException: raise except ValueError as e: @@ -2481,8 +2512,8 @@ async def create_video_sync( status_code=HTTPStatus.GATEWAY_TIMEOUT.value, detail=f"Video generation timed out after {VIDEO_SYNC_TIMEOUT_S}s.", ) - except (EngineGenerateError, EngineDeadError): - raise # Propagate to the global Omni exception handler + except (EngineGenerateError, EngineDeadError) as exc: + return _create_engine_error_json_response(raw_request, exc) except HTTPException: raise except Exception as exc: