Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions tests/entrypoints/openai_api/test_serving_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from pytest_mock import MockerFixture
from vllm.entrypoints.openai.engine.protocol import ErrorInfo, ErrorResponse

from vllm_omni.entrypoints.omni_base import OmniEngineDeadError
from vllm_omni.entrypoints.openai import api_server as api_server_module
from vllm_omni.entrypoints.openai.audio_utils_mixin import AudioMixin
from vllm_omni.entrypoints.openai.protocol.audio import (
Expand Down Expand Up @@ -1707,6 +1708,77 @@ def test_api_server_create_speech_wraps_error_response_status(mocker: MockerFixt
assert response.status_code == 400


def test_api_server_create_speech_engine_error_response_includes_request_and_stage_id(mocker: MockerFixture):
    """A dead engine during ``create_speech`` yields a 500 JSON error with context.

    The speech handler is mocked to raise ``OmniEngineDeadError`` with
    ``error_stage_id=1``. The response body must carry both the request id
    read from ``raw_request.state.request_metadata`` and the stage id, and
    ``terminate_if_errored`` must be called exactly once.
    """
    speech_handler = mocker.MagicMock()
    speech_handler.create_speech = mocker.AsyncMock(
        side_effect=OmniEngineDeadError("engine dead", error_stage_id=1),
    )

    # Patched so the test can observe the call without running the real hook.
    terminate_mock = mocker.patch.object(api_server_module, "terminate_if_errored")

    # Minimal application state: an engine client that reports itself as dead.
    app = FastAPI()
    app.state.args = SimpleNamespace(log_error_stack=False)
    app.state.openai_serving_speech = speech_handler
    app.state.engine_client = SimpleNamespace(
        engine=SimpleNamespace(is_alive=lambda: False),
        errored=True,
    )
    app.state.server = SimpleNamespace()

    # Hand-built ASGI scope so the endpoint coroutine can be driven directly.
    http_scope = {
        "type": "http",
        "app": app,
        "method": "POST",
        "path": "/v1/audio/speech",
        "headers": [],
        "query_string": b"",
        "client": ("127.0.0.1", 12345),
        "server": ("testserver", 80),
        "scheme": "http",
    }
    raw_request = Request(http_scope)
    raw_request.state.request_metadata = SimpleNamespace(request_id="speech-req-1")
    speech_request = OpenAICreateSpeechRequest(input="Hello")

    response = asyncio.run(api_server_module.create_speech(speech_request, raw_request))

    expected_body = (
        '{"error":{"message":"engine dead","type":"InternalServerError","param":null,'
        '"code":500,"request_id":"speech-req-1","error_stage_id":1}}'
    )
    assert isinstance(response, JSONResponse)
    assert response.status_code == 500
    assert response.body.decode("utf-8") == expected_body
    terminate_mock.assert_called_once()


def test_omni_engine_error_handler_includes_request_and_stage_id(mocker: MockerFixture):
    """The registered exception handler reports request id and stage id.

    A throwaway ``/boom`` route raises ``OmniEngineDeadError`` after attaching
    request metadata; the resulting 500 response must expose both values under
    ``error`` and ``terminate_if_errored`` must be called exactly once.
    """
    app = FastAPI()
    app.state.args = SimpleNamespace(log_error_stack=False)
    app.state.engine_client = SimpleNamespace(
        engine=SimpleNamespace(is_alive=lambda: False),
        errored=True,
    )
    app.state.server = SimpleNamespace()

    # Patch before registering the handlers, as in the original ordering.
    terminate_mock = mocker.patch.object(api_server_module, "terminate_if_errored")
    api_server_module._register_omni_exception_handlers(app)

    @app.get("/boom")
    async def boom(request: Request):
        request.state.request_metadata = SimpleNamespace(request_id="speech-req-1")
        raise OmniEngineDeadError("engine dead", error_stage_id=1)

    client = TestClient(app)
    response = client.get("/boom")
    error_payload = response.json()["error"]

    assert response.status_code == 500
    assert error_payload["request_id"] == "speech-req-1"
    assert error_payload["error_stage_id"] == 1
    terminate_mock.assert_called_once()


class TestWAVHeaderGeneration:
"""Unit tests for WAV header generation with placeholder values."""

Expand Down
23 changes: 20 additions & 3 deletions tests/entrypoints/test_omni_entrypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from vllm_omni.entrypoints.async_omni import AsyncOmni
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.entrypoints.omni_base import OmniEngineDeadError

pytestmark = [pytest.mark.core_model, pytest.mark.cpu]

Expand Down Expand Up @@ -335,6 +336,18 @@ def _enqueue_error_message(engine: FakeAsyncOmniEngine, msg: dict[str, Any]) ->
)


def _enqueue_fatal_error_message(engine: FakeAsyncOmniEngine, msg: dict[str, Any]) -> None:
    """Put a fatal stage-2 error for ``msg["request_id"]`` on ``engine.output_q``."""
    enqueue = engine.output_q.put_nowait
    fatal_error = {
        "type": "error",
        "fatal": True,
        "request_id": msg["request_id"],
        "stage_id": 2,
        "error": "engine dead",
    }
    enqueue(fatal_error)


@pytest.mark.asyncio
async def test_get_supported_tasks_returns_engine_supported_tasks():
omni = object.__new__(AsyncOmni)
Expand Down Expand Up @@ -546,18 +559,22 @@ async def test_async_omni_abort_forwards_to_engine(monkeypatch: pytest.MonkeyPat


@pytest.mark.asyncio
async def test_async_omni_propagates_engine_error(monkeypatch: pytest.MonkeyPatch):
engine = FakeAsyncOmniEngine(stage_metadata=THREE_STAGE_META, on_add_request=_enqueue_error_message)
async def test_async_omni_propagates_fatal_error_context(monkeypatch: pytest.MonkeyPatch):
engine = FakeAsyncOmniEngine(stage_metadata=THREE_STAGE_META, on_add_request=_enqueue_fatal_error_message)
_patch_engine(monkeypatch, engine)

app = AsyncOmni("dummy-model")
try:
with pytest.raises(RuntimeError, match="engine boom"):
with pytest.raises(EngineDeadError, match="engine dead") as exc_info:
async for _ in app.generate(prompt="hello", request_id="req-1"):
pass
finally:
app.shutdown()

assert isinstance(exc_info.value, OmniEngineDeadError)
assert str(exc_info.value) == "engine dead"
assert getattr(exc_info.value, "error_stage_id") == 2


def test_omni_generate_py_generator_yields_final_outputs_for_each_request(monkeypatch: pytest.MonkeyPatch):
sampling_params = [SamplingParams(max_tokens=8) for _ in range(3)]
Expand Down
5 changes: 5 additions & 0 deletions vllm_omni/engine/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def __init__(
self._shutdown_event = asyncio.Event()
self._stages_shutdown = False
self._fatal_error: str | None = None
self._fatal_error_stage_id: int | None = None

async def run(self) -> None:
"""Main entry point for the Orchestrator event loop."""
Expand Down Expand Up @@ -331,6 +332,7 @@ async def _orchestration_loop(self) -> None:
e,
)
self._fatal_error = str(e)
self._fatal_error_stage_id = stage_id
for req_id, req_state in list(self.request_states.items()):
if stage_id in req_state.stage_submit_ts:
await self.output_async_queue.put(
Expand All @@ -339,6 +341,7 @@ async def _orchestration_loop(self) -> None:
"error": str(e),
"fatal": True,
"request_id": req_id,
"stage_id": stage_id,
}
)
self.request_states.pop(req_id, None)
Expand Down Expand Up @@ -1135,6 +1138,7 @@ async def _drain_pending_requests_on_fatal(self) -> None:
"error": self._fatal_error,
"fatal": True,
"request_id": req_id,
"stage_id": self._fatal_error_stage_id,
}
)
notified.add(req_id)
Expand All @@ -1150,6 +1154,7 @@ async def _drain_pending_requests_on_fatal(self) -> None:
"error": self._fatal_error,
"fatal": True,
"request_id": req_id,
"stage_id": self._fatal_error_stage_id,
}
)
self.request_states.pop(req_id, None)
Expand Down
24 changes: 21 additions & 3 deletions vllm_omni/entrypoints/async_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@

from vllm_omni.diffusion.data import OmniACK, OmniSleepTask, OmniWakeTask
from vllm_omni.entrypoints.client_request_state import ClientRequestState
from vllm_omni.entrypoints.omni_base import OmniBase
from vllm_omni.entrypoints.omni_base import (
OmniBase,
OmniEngineDeadError,
)
from vllm_omni.metrics.stats import OrchestratorAggregator as OrchestratorMetrics
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
Expand Down Expand Up @@ -494,7 +497,10 @@ async def _process_orchestrator_results(
stage_id = result.get("stage_id", 0)

if result.get("type") == "error" and result.get("fatal"):
raise EngineDeadError(result.get("error", ""))
raise OmniEngineDeadError(
result.get("error", ""),
error_stage_id=result.get("stage_id"),
)

# Check for errors
if "error" in result:
Expand Down Expand Up @@ -580,6 +586,18 @@ async def _final_output_loop():

except asyncio.CancelledError:
raise
except OmniEngineDeadError as e:
logger.error("[AsyncOmni] Engine dead: %s", e)
for req_state in list(self.request_states.values()):
error_msg = {
"type": "error",
"error": str(e),
"fatal": True,
"request_id": req_state.request_id,
}
if e.error_stage_id is not None:
error_msg["stage_id"] = e.error_stage_id
await req_state.queue.put(error_msg)
except EngineDeadError as e:
logger.error("[AsyncOmni] Engine dead: %s", e)
for req_state in list(self.request_states.values()):
Expand Down Expand Up @@ -884,7 +902,7 @@ def is_stopped(self) -> bool:
@property
def dead_error(self) -> BaseException:
"""EngineClient abstract property implementation."""
return EngineDeadError()
return OmniEngineDeadError()

# ==================== EngineClient Interface ====================

Expand Down
28 changes: 26 additions & 2 deletions vllm_omni/entrypoints/omni_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,23 @@
logger = init_logger(__name__)


class OmniEngineDeadError(EngineDeadError):
    """``EngineDeadError`` subclass that also records a stage identifier.

    Args:
        message: Error text. A falsy value (``None`` or ``""``) is replaced by
            ``_DEFAULT_MESSAGE``.
        error_stage_id: Stage identifier stored on the instance, or ``None``.
        suppress_context: Value assigned to ``__suppress_context__``.
    """

    # First positional argument of a default-constructed parent exception,
    # evaluated once when the class body runs.
    _DEFAULT_MESSAGE = EngineDeadError().args[0]
    error_stage_id: int | None

    def __init__(
        self,
        message: str | None = None,
        *,
        error_stage_id: int | None = None,
        suppress_context: bool = False,
    ) -> None:
        self.error_stage_id = error_stage_id
        text = message if message else self._DEFAULT_MESSAGE
        # Calls Exception.__init__ directly (not EngineDeadError.__init__), so
        # ``args`` ends up as the single-element tuple ``(text,)``.
        Exception.__init__(self, text)
        self.__suppress_context__ = suppress_context


def _weak_shutdown_engine(engine: AsyncOmniEngine) -> None:
"""Best-effort engine cleanup for GC finalization."""
try:
Expand Down Expand Up @@ -296,8 +313,12 @@ def _handle_output_message(

if msg_type == "error":
error_text = msg.get("error", "Orchestrator returned an error message")
stage_id = msg.get("stage_id")
if msg.get("fatal"):
raise EngineDeadError(error_text)
raise OmniEngineDeadError(
error_text,
error_stage_id=stage_id,
)
raise RuntimeError(error_text)

if msg_type != "output":
Expand Down Expand Up @@ -352,7 +373,10 @@ def _check_engine_output_error(
)
# NOTE: O(n_stages) check for every error.
if self.errored:
raise EngineDeadError(error_text)
raise OmniEngineDeadError(
error_text,
error_stage_id=stage_id,
)
raise EngineGenerateError(error_text)

def _process_single_result(
Expand Down
Loading
Loading