Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 121 additions & 4 deletions tests/entrypoints/openai_api/test_serving_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,11 +658,13 @@ def speech_server(self, mocker: MockerFixture):
mock_engine_client.tts_max_instructions_length = None
mock_models = mocker.MagicMock()
mock_models.is_base_model.return_value = True
return OmniOpenAIServingSpeech(
server = OmniOpenAIServingSpeech(
engine_client=mock_engine_client,
models=mock_models,
request_logger=mocker.MagicMock(),
)
yield server
server.shutdown()

def test_is_tts_detection_no_stage(self, speech_server):
"""Test TTS model detection when no TTS stage exists."""
Expand Down Expand Up @@ -1639,11 +1641,13 @@ def fish_speech_server(mocker: MockerFixture):
mock_models = mocker.MagicMock()
mock_models.is_base_model.return_value = True

return OmniOpenAIServingSpeech(
server = OmniOpenAIServingSpeech(
engine_client=mock_engine_client,
models=mock_models,
request_logger=mocker.MagicMock(),
)
yield server
server.shutdown()


class TestFishSpeechServing:
Expand Down Expand Up @@ -1717,7 +1721,7 @@ def test_build_fish_prompt_rejects_unsafe_control_tokens(self, fish_speech_serve
fish_speech_server._build_fish_speech_prompt(request)

def test_prepare_speech_generation_overrides_fish_default_max_tokens(self, fish_speech_server):
fish_speech_server._build_fish_speech_prompt = MagicMock(
fish_speech_server._build_fish_speech_prompt_async = AsyncMock(
return_value={
"prompt_token_ids": [1, 2, 3],
"additional_information": {},
Expand All @@ -1730,13 +1734,14 @@ def test_prepare_speech_generation_overrides_fish_default_max_tokens(self, fish_

assert request_id.startswith("speech-")
assert generator == "generator"
fish_speech_server._build_fish_speech_prompt_async.assert_awaited_once()
fish_speech_server.engine_client.generate.assert_called_once()
sampling_params_list = fish_speech_server.engine_client.generate.call_args.kwargs["sampling_params_list"]
assert sampling_params_list[0].max_tokens == 4096
assert fish_speech_server.engine_client.default_sampling_params_list[0].max_tokens == 2048

def test_prepare_speech_generation_uses_stage_default_max_tokens(self, fish_speech_server):
fish_speech_server._build_fish_speech_prompt = MagicMock(
fish_speech_server._build_fish_speech_prompt_async = AsyncMock(
return_value={
"prompt_token_ids": [1, 2, 3],
"additional_information": {},
Expand Down Expand Up @@ -1985,3 +1990,115 @@ def test_prepare_speech_generation_cosyvoice3(self, cosyvoice3_server):
assert generator == "generator"
assert tts_params == {}
cosyvoice3_server._build_cosyvoice3_prompt.assert_awaited_once()


class TestTTSAsyncOffloading:
    """Tests for event-loop-safe offloading of blocking TTS operations.

    Covers three things: the Voxtral prompt builder is a plain (sync)
    function, ``_prepare_speech_generation`` awaits the ``*_async`` wrappers
    for the Voxtral and Qwen3 TTS paths, and ``shutdown()`` can be called
    safely on both regular and diffusion instances.
    """

    @staticmethod
    def _make_server(mocker: MockerFixture, engine_client):
        """Construct an ``OmniOpenAIServingSpeech`` around ``engine_client``.

        ``_load_supported_speakers`` and ``_load_codec_frame_rate`` are
        patched to return fixed values so construction skips those loaders.
        The leading underscore keeps pytest from collecting this helper.
        """
        mocker.patch.object(OmniOpenAIServingSpeech, "_load_supported_speakers", return_value=set())
        mocker.patch.object(OmniOpenAIServingSpeech, "_load_codec_frame_rate", return_value=None)
        mock_models = mocker.MagicMock()
        mock_models.is_base_model.return_value = True
        return OmniOpenAIServingSpeech(
            engine_client=engine_client,
            models=mock_models,
            request_logger=mocker.MagicMock(),
        )

    def test_build_voxtral_prompt_is_sync(self):
        """_build_voxtral_prompt should be a regular function, not a coroutine."""
        # asyncio.iscoroutinefunction is deprecated since Python 3.14; the
        # inspect variant gives the same answer for plain and ``async def``
        # functions without emitting a DeprecationWarning.
        import inspect

        assert not inspect.iscoroutinefunction(OmniOpenAIServingSpeech._build_voxtral_prompt)

    @pytest.fixture
    def voxtral_server(self, mocker: MockerFixture):
        """Yield a server for model ``mistralai/Voxtral`` with one ``audio_generation`` stage."""
        mock_engine_client = mocker.MagicMock()
        mock_engine_client.errored = False
        mock_engine_client.model_config = mocker.MagicMock(model="mistralai/Voxtral")
        mock_engine_client.default_sampling_params_list = [SimpleNamespace(max_tokens=2048)]
        mock_engine_client.tts_batch_max_items = 32
        mock_engine_client.generate = mocker.MagicMock(return_value="generator")
        mock_engine_client.stage_configs = [
            SimpleNamespace(
                engine_args=SimpleNamespace(model_stage="audio_generation"),
                tts_args={},
            )
        ]
        server = self._make_server(mocker, mock_engine_client)
        yield server
        # Teardown: release the executor created in the server's __init__.
        server.shutdown()

    @pytest.fixture
    def qwen3_tts_server(self, mocker: MockerFixture):
        """Yield a server for model ``Qwen/Qwen3-TTS`` with one ``qwen3_tts`` stage."""
        mock_engine_client = mocker.MagicMock()
        mock_engine_client.errored = False
        mock_engine_client.model_config = mocker.MagicMock(model="Qwen/Qwen3-TTS", hf_config=mocker.MagicMock())
        mock_engine_client.default_sampling_params_list = [SimpleNamespace(max_tokens=2048)]
        mock_engine_client.tts_batch_max_items = 32
        mock_engine_client.generate = mocker.MagicMock(return_value="generator")
        mock_engine_client.tts_max_instructions_length = None
        mock_engine_client.stage_configs = [
            SimpleNamespace(
                engine_args=SimpleNamespace(model_stage="qwen3_tts"),
                tts_args={},
            )
        ]
        server = self._make_server(mocker, mock_engine_client)
        yield server
        # Teardown: release the executor created in the server's __init__.
        server.shutdown()

    def test_prepare_speech_generation_awaits_voxtral_async(self, voxtral_server):
        """Voxtral path in _prepare_speech_generation should call the async wrapper."""
        voxtral_server._build_voxtral_prompt_async = AsyncMock(
            return_value={
                "prompt_token_ids": [1, 2, 3],
                "additional_information": {"voice": ["test"]},
            }
        )
        request = OpenAICreateSpeechRequest(input="hello", voice="test")
        asyncio.run(voxtral_server._prepare_speech_generation(request))
        voxtral_server._build_voxtral_prompt_async.assert_awaited_once()

    def test_prepare_speech_generation_awaits_qwen3_tts_async(self, qwen3_tts_server):
        """Qwen3 TTS path should call _estimate_prompt_len_async."""
        # Validation and parameter building are stubbed so the test isolates
        # the prompt-length estimation step.
        qwen3_tts_server._validate_tts_request = MagicMock(return_value=None)
        qwen3_tts_server._build_tts_params = MagicMock(
            return_value={"text": ["hello"], "task_type": ["CustomVoice"], "speaker": ["Vivian"]}
        )
        qwen3_tts_server._estimate_prompt_len_async = AsyncMock(return_value=512)
        request = OpenAICreateSpeechRequest(input="hello")
        asyncio.run(qwen3_tts_server._prepare_speech_generation(request))
        qwen3_tts_server._build_tts_params.assert_called_once()
        qwen3_tts_server._estimate_prompt_len_async.assert_awaited_once()

    def test_shutdown_is_idempotent(self, mocker: MockerFixture):
        """Calling shutdown() twice should not raise."""
        mock_engine_client = mocker.MagicMock()
        mock_engine_client.errored = False
        mock_engine_client.stage_configs = []
        mock_engine_client.tts_max_instructions_length = None
        server = self._make_server(mocker, mock_engine_client)
        assert server._tts_executor is not None
        server.shutdown()
        assert server._tts_executor is None
        server.shutdown()  # Should not raise
        assert server._tts_executor is None

    def test_diffusion_instance_shutdown_safe(self):
        """Diffusion instances (created via for_diffusion) should have safe shutdown."""
        server = OmniOpenAIServingSpeech.for_diffusion(diffusion_engine=MagicMock(), model_name="test-model")
        assert server._tts_executor is None
        server.shutdown()  # Should not raise
        # shutdown() must not create an executor on an instance that had none.
        assert server._tts_executor is None
1 change: 1 addition & 0 deletions vllm_omni/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ async def omni_run_server_worker(listen_address, sock, args, client_config=None,
try:
await shutdown_task
finally:
app.state.openai_serving_speech.shutdown()
sock.close()


Expand Down
25 changes: 21 additions & 4 deletions vllm_omni/entrypoints/openai/serving_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import struct
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any

Expand All @@ -22,6 +23,7 @@
from vllm.logger import init_logger
from vllm.multimodal.media import MediaConnector
from vllm.utils import random_uuid
from vllm.utils.async_utils import make_async

from vllm_omni.entrypoints.openai.audio_utils_mixin import AudioMixin
from vllm_omni.entrypoints.openai.protocol.audio import (
Expand Down Expand Up @@ -153,6 +155,7 @@ def _validate_path_within_directory(file_path: Path, directory: Path) -> bool:

class OmniOpenAIServingSpeech(OpenAIServing, AudioMixin):
_diffusion_mode: bool = False
_tts_executor: ThreadPoolExecutor | None = None

@classmethod
def for_diffusion(
Expand Down Expand Up @@ -219,6 +222,14 @@ def __init__(self, *args, **kwargs):
# Load speech tokenizer codec parameters for prompt length estimation
self._codec_frame_rate: float | None = self._load_codec_frame_rate()

# Shared thread pool executor for blocking TTS preprocessing
# operations. max_workers=1 serializes tokenizer access to avoid
# Rust RefCell "Already borrowed" errors from concurrent use.
self._tts_executor = ThreadPoolExecutor(max_workers=1)
self._build_voxtral_prompt_async = make_async(self._build_voxtral_prompt, executor=self._tts_executor)
self._build_fish_speech_prompt_async = make_async(self._build_fish_speech_prompt, executor=self._tts_executor)
self._estimate_prompt_len_async = make_async(self._estimate_prompt_len, executor=self._tts_executor)

def _load_codec_frame_rate(self) -> float | None:
"""Load codec frame rate from speech tokenizer config for prompt length estimation."""
try:
Expand Down Expand Up @@ -252,6 +263,12 @@ def _load_codec_frame_rate(self) -> float | None:
pass
return None

def shutdown(self) -> None:
    """Release the TTS thread pool executor, if one exists.

    Safe to call more than once: after the executor has been released the
    attribute is ``None`` and later calls return immediately. Queued work
    is cancelled and in-flight work is not waited for.
    """
    pool = self._tts_executor
    if pool is None:
        return
    # wait=False: do not block the caller on tasks that are still running.
    pool.shutdown(wait=False, cancel_futures=True)
    self._tts_executor = None

def _find_tts_stage(self):
"""Find and return the TTS stage config, or None if not found."""
for stage in self.engine_client.stage_configs:
Expand Down Expand Up @@ -1149,7 +1166,7 @@ def _build_tts_params(self, request: OpenAICreateSpeechRequest) -> dict[str, Any

# ---- Voxtral TTS helpers ----

async def _build_voxtral_prompt(self, request: OpenAICreateSpeechRequest) -> dict[str, Any]:
def _build_voxtral_prompt(self, request: OpenAICreateSpeechRequest) -> dict[str, Any]:
"""Build Voxtral TTS engine prompt from shared TTS parameters."""
from mistral_common.protocol.speech.request import SpeechRequest

Expand Down Expand Up @@ -1289,7 +1306,7 @@ async def _prepare_speech_generation(
if request.ref_audio is not None:
wav_list, sr = await self._resolve_ref_audio(request.ref_audio)
ref_audio_data = (wav_list, sr)
prompt = self._build_fish_speech_prompt(request, ref_audio_data=ref_audio_data)
prompt = await self._build_fish_speech_prompt_async(request, ref_audio_data=ref_audio_data)
tts_params = {}
elif self._tts_model_type == "omnivoice":
tts_params = {}
Expand All @@ -1300,7 +1317,7 @@ async def _prepare_speech_generation(
raise ValueError(validation_error)

if self._tts_model_type == "voxtral_tts":
prompt = await self._build_voxtral_prompt(request)
prompt = await self._build_voxtral_prompt_async(request)
tts_params = {}
elif self._tts_model_type == "cosyvoice3":
prompt = await self._build_cosyvoice3_prompt(request)
Expand All @@ -1317,7 +1334,7 @@ async def _prepare_speech_generation(
wav_list, sr = await self._resolve_ref_audio(ref_audio_source)
tts_params["ref_audio"] = [[wav_list, sr]]

ph_len = self._estimate_prompt_len(tts_params)
ph_len = await self._estimate_prompt_len_async(tts_params)
prompt = {"prompt_token_ids": [1] * ph_len, "additional_information": tts_params}
else:
tts_params = {}
Expand Down
Loading