Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions tests/entrypoints/openai_api/test_serving_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from pytest_mock import MockerFixture
from vllm.entrypoints.openai.engine.protocol import ErrorInfo, ErrorResponse

+ from vllm_omni.entrypoints.omni import BackgroundResources
from vllm_omni.entrypoints.openai import api_server as api_server_module
from vllm_omni.entrypoints.openai.audio_utils_mixin import AudioMixin
from vllm_omni.entrypoints.openai.protocol.audio import CreateAudio, OpenAICreateSpeechRequest
Expand Down Expand Up @@ -1042,7 +1043,7 @@ async def test_tts_only_no_generate_task(self):
omni.output_modalities = [None, "audio"]
stage = MagicMock()
stage.is_comprehension = False
- omni.stage_list = [stage]
+ omni.resources = BackgroundResources(stage_list=[stage])
tasks = await omni.get_supported_tasks()
assert "generate" not in tasks
assert "speech" in tasks
Expand All @@ -1058,7 +1059,7 @@ async def test_omni_model_includes_generate(self):
omni.output_modalities = ["text", None, "audio"]
stage = MagicMock()
stage.is_comprehension = True
- omni.stage_list = [stage]
+ omni.resources = BackgroundResources(stage_list=[stage])
tasks = await omni.get_supported_tasks()
assert "generate" in tasks

Expand Down
31 changes: 25 additions & 6 deletions tests/entrypoints/test_omni_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1192,10 +1192,9 @@ def init_stage_worker(self, *args, **kwargs):

from vllm_omni.entrypoints.omni import Omni

- # Use very short timeout
- omni = Omni(model=MODEL, init_timeout=0.01)
- # Verify that no stages are ready
- assert len(omni._stages_ready) == 0
+ with pytest.raises(TimeoutError):
+     # Use very short timeout
+     Omni(model=MODEL, init_timeout=0.01)


def test_generate_handles_error_messages(monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture, fake_stage_config):
Expand Down Expand Up @@ -1371,15 +1370,25 @@ def _fake_loader(model: str, base_engine_args=None):
_setup_multiprocessing_mocks(monkeypatch, mocker)
_setup_ipc_mocks(monkeypatch)
_setup_log_mocks(monkeypatch)
+ _setup_connector_mocks(monkeypatch, mocker)

- from vllm_omni.entrypoints.omni import Omni
+ monkeypatch.setattr(
+ "vllm_omni.entrypoints.omni_stage.OmniStage",
+ lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs),
+ raising=False,
+ )

+ import vllm_omni.entrypoints.omni as omni_module

+ monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs))
monkeypatch.setattr(
"vllm_omni.entrypoints.utils.load_stage_configs_from_model",
_fake_loader,
raising=False,
)

+ from vllm_omni.entrypoints.omni import Omni

omni = Omni(model="any", init_timeout=1, dtype=dtype)

# Dtype parsing being checked is on the diffusion path
Expand Down Expand Up @@ -1417,15 +1426,25 @@ class NotATorchDtype:
_setup_multiprocessing_mocks(monkeypatch, mocker)
_setup_ipc_mocks(monkeypatch)
_setup_log_mocks(monkeypatch)
+ _setup_connector_mocks(monkeypatch, mocker)

- from vllm_omni.entrypoints.omni import Omni
+ monkeypatch.setattr(
+ "vllm_omni.entrypoints.omni_stage.OmniStage",
+ lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs),
+ raising=False,
+ )

+ import vllm_omni.entrypoints.omni as omni_module

+ monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(mocker, cfg, **kwargs))
monkeypatch.setattr(
"vllm_omni.entrypoints.utils.load_stage_configs_from_model",
_fake_loader,
raising=False,
)

+ from vllm_omni.entrypoints.omni import Omni

# Raise TypeError if we get an unrecognized type
with pytest.raises(TypeError):
Omni(model="any", init_timeout=1, dtype=NotATorchDtype)
7 changes: 3 additions & 4 deletions tests/entrypoints/test_omni_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1042,10 +1042,9 @@ def init_stage_worker(self, *args, **kwargs):

from vllm_omni.entrypoints.omni import Omni

- # Use very short timeout
- omni = Omni(model=MODEL, init_timeout=0.01)
- # Verify that no stages are ready
- assert len(omni._stages_ready) == 0
+ with pytest.raises(TimeoutError):
+     # Use very short timeout
+     Omni(model=MODEL, init_timeout=0.01)


def test_generate_handles_error_messages(monkeypatch: pytest.MonkeyPatch, mocker: MockerFixture, fake_stage_config):
Expand Down
Loading
Loading