diff --git a/tests/entrypoints/openai_api/test_voice_cache.py b/tests/entrypoints/openai_api/test_voice_cache.py new file mode 100644 index 00000000000..575f4beb008 --- /dev/null +++ b/tests/entrypoints/openai_api/test_voice_cache.py @@ -0,0 +1,202 @@ +"""Tests for uploaded voice cache warmup endpoint.""" + +from __future__ import annotations + +import time +from typing import Any +from unittest.mock import AsyncMock + +import numpy as np +import pytest + +from vllm_omni.entrypoints.openai.protocol.audio import OpenAICreateSpeechRequest +from vllm_omni.entrypoints.openai.serving_speech import ( + OmniOpenAIServingSpeech, + SpeakerCacheUnsupportedError, + SpeakerNotFoundError, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +class FakeSpeakerCache: + def __init__(self): + self.entries: dict[tuple[str, str, int], dict[str, Any]] = {} + + @staticmethod + def make_cache_key(speaker_name: str, model_type: str, created_at: int = 0) -> tuple[str, str, int]: + return (model_type, speaker_name.lower(), int(created_at)) + + def get(self, key: tuple[str, str, int]) -> dict[str, Any] | None: + return self.entries.get(key) + + def put(self, key: tuple[str, str, int], artifacts: dict[str, Any]) -> None: + self.entries[key] = artifacts + + def clear(self, speaker_name: str | None = None) -> int: + if speaker_name is None: + removed = len(self.entries) + self.entries.clear() + return removed + before = len(self.entries) + normalized = speaker_name.lower() + self.entries = {key: value for key, value in self.entries.items() if key[1] != normalized} + return before - len(self.entries) + + +def _make_server(mocker, *, model_stage: str = "qwen3_tts") -> OmniOpenAIServingSpeech: + engine_client = mocker.MagicMock() + engine_client.errored = False + engine_client.tts_max_instructions_length = None + engine_client.default_sampling_params_list = [{}] + + stage = mocker.MagicMock() + stage.engine_args.model_stage = model_stage + stage.stage_id = 0 + stage.tts_args = {} + 
engine_client.stage_configs = [stage] + engine_client.collective_rpc = AsyncMock() + + models = mocker.MagicMock() + models.is_base_model.return_value = True + server = OmniOpenAIServingSpeech( + engine_client=engine_client, + models=models, + request_logger=mocker.MagicMock(), + ) + server._speaker_cache = FakeSpeakerCache() + return server + + +def _audio_speaker_info(*, ref_text: str | None = None) -> dict[str, Any]: + info: dict[str, Any] = { + "name": "voice_a", + "voice_name_lower": "voice_a", + "consent": "consent", + "file_path": "/tmp/voice_a.safetensors", + "created_at": int(time.time()), + "mime_type": "audio/wav", + "sample_rate": 16000, + "embedding_source": "audio", + } + if ref_text is not None: + info["ref_text"] = ref_text + return info + + +def _direct_speaker_info() -> dict[str, Any]: + return { + "name": "voice_a", + "voice_name_lower": "voice_a", + "consent": "consent", + "file_path": "/tmp/voice_a.safetensors", + "created_at": int(time.time()), + "mime_type": "application/x-safetensors", + "embedding_source": "direct", + "embedding_dim": 1024, + } + + +class TestVoiceCacheWarmup: + @pytest.fixture + def server(self, mocker): + server = _make_server(mocker) + yield server + server.shutdown() + + @pytest.mark.asyncio + async def test_missing_voice_returns_not_found(self, server): + with pytest.raises(SpeakerNotFoundError): + await server.create_voice_cache("missing") + + @pytest.mark.asyncio + async def test_direct_embedding_voice_is_rejected(self, server): + server.uploaded_speakers = {"voice_a": _direct_speaker_info()} + with pytest.raises(SpeakerCacheUnsupportedError): + await server.create_voice_cache("voice_a") + + @pytest.mark.asyncio + async def test_non_qwen3_model_is_rejected(self, mocker): + server = _make_server(mocker, model_stage="audio_generation") + server.uploaded_speakers = {"voice_a": _audio_speaker_info()} + with pytest.raises(SpeakerCacheUnsupportedError): + await server.create_voice_cache("voice_a") + + @pytest.mark.asyncio 
+ async def test_existing_cache_returns_idempotent_ready(self, server): + speaker_info = _audio_speaker_info(ref_text="hello") + server.uploaded_speakers = {"voice_a": speaker_info} + key = server._qwen3_speaker_cache_key("voice_a", speaker_info) + server._speaker_cache.put(key, {"ref_code": None, "ref_spk_embedding": object(), "icl_mode": True}) + + result = await server.create_voice_cache("voice_a") + + assert result["cache_status"] == "ready" + assert "already exists" in result["message"] + server.engine_client.collective_rpc.assert_not_called() + + @pytest.mark.asyncio + async def test_success_warms_shared_speaker_cache(self, server, mocker): + speaker_info = _audio_speaker_info(ref_text="hello") + server.uploaded_speakers = {"voice_a": speaker_info} + mocker.patch.object(server, "_load_uploaded_audio", return_value=(np.zeros(16000, dtype=np.float32), 16000)) + server.engine_client.collective_rpc.return_value = [ + [ + { + "ref_spk_embedding": [0.1] * 1024, + "ref_code": [[1, 2], [3, 4]], + "x_vector_only_mode": False, + "icl_mode": True, + "ref_text": "hello", + } + ] + ] + + result = await server.create_voice_cache("voice_a") + + assert result == {"voice": "voice_a", "cache_status": "ready"} + key = server._qwen3_speaker_cache_key("voice_a", speaker_info) + cached = server._speaker_cache.get(key) + assert cached is not None + assert cached["icl_mode"] is True + assert cached["ref_code"].shape == (2, 2) + server.engine_client.collective_rpc.assert_awaited_once() + + def test_build_tts_params_uses_warmed_cache_without_ref_audio(self, server, mocker): + speaker_info = _audio_speaker_info(ref_text="hello") + server.uploaded_speakers = {"voice_a": speaker_info} + key = server._qwen3_speaker_cache_key("voice_a", speaker_info) + server._speaker_cache.put(key, {"ref_code": None, "ref_spk_embedding": object(), "icl_mode": True}) + get_audio_data = mocker.patch.object(server, "_get_uploaded_audio_data") + + params = server._build_tts_params( + 
OpenAICreateSpeechRequest( + input="test", + voice="voice_a", + response_format="wav", + ) + ) + + assert params["task_type"] == ["Base"] + assert params["speaker"] == ["voice_a"] + assert params["ref_text"] == ["hello"] + assert params["x_vector_only_mode"] == [False] + assert "ref_audio" not in params + get_audio_data.assert_not_called() + + def test_build_tts_params_falls_back_to_raw_audio_without_cache(self, server, mocker): + server.uploaded_speakers = {"voice_a": _audio_speaker_info(ref_text="hello")} + get_audio_data = mocker.patch.object( + server, "_get_uploaded_audio_data", return_value="data:audio/wav;base64,AA==" + ) + + params = server._build_tts_params( + OpenAICreateSpeechRequest( + input="test", + voice="voice_a", + response_format="wav", + ) + ) + + assert params["ref_audio"] == ["data:audio/wav;base64,AA=="] + get_audio_data.assert_called_once_with("voice_a") diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 06fb0a7f4cb..1987cf0d538 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -118,7 +118,11 @@ from vllm_omni.entrypoints.openai.realtime_connection import RealtimeConnection from vllm_omni.entrypoints.openai.serving_audio_generate import OmniOpenAIServingAudioGenerate from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat -from vllm_omni.entrypoints.openai.serving_speech import OmniOpenAIServingSpeech +from vllm_omni.entrypoints.openai.serving_speech import ( + OmniOpenAIServingSpeech, + SpeakerCacheUnsupportedError, + SpeakerNotFoundError, +) from vllm_omni.entrypoints.openai.serving_speech_stream import OmniStreamingSpeechHandler from vllm_omni.entrypoints.openai.serving_video import OmniOpenAIServingVideo, ReferenceImage from vllm_omni.entrypoints.openai.serving_video_stream import OmniStreamingVideoHandler @@ -1294,6 +1298,60 @@ async def upload_voice( return 
@router.post(
    "/v1/audio/voices/{name}/cache",
    responses={
        status.value: {"model": dict}
        for status in (
            HTTPStatus.OK,
            HTTPStatus.BAD_REQUEST,
            HTTPStatus.NOT_FOUND,
            HTTPStatus.INTERNAL_SERVER_ERROR,
        )
    },
)
async def create_voice_cache(
    name: str,
    raw_request: Request,
    force: bool = Query(
        False,
        description=("Force rebuild even if the in-memory speaker cache already exists."),
    ),
):
    """Warm the in-memory speaker cache for an uploaded voice.

    Delegates to the speech handler's ``create_voice_cache`` and translates
    its outcome into an HTTP response:

    * success -> 200 with the handler's result dict as JSON;
    * ``SpeakerNotFoundError`` -> 404;
    * ``SpeakerCacheUnsupportedError`` or any other ``ValueError`` -> 400;
    * any other exception -> 500 (logged with traceback).

    Error bodies have the shape ``{"success": False, "error": <message>}``.
    When no speech handler is configured, the generic error response of the
    base serving object is returned instead.

    Args:
        name: Voice name taken from the URL path.
        raw_request: Incoming request, used to locate the serving handlers.
        force: Rebuild even if a cache entry already exists.
    """

    def _failure(status: HTTPStatus, detail: str) -> JSONResponse:
        # Single place that defines the error envelope used by every branch.
        return JSONResponse(
            content={"success": False, "error": detail},
            status_code=status.value,
        )

    speech = Omnispeech(raw_request)
    if speech is None:
        return base(raw_request).create_error_response(message="The model does not support Speech API")

    try:
        # The response is built inside the ``try`` on purpose: a result that
        # cannot be serialised is reported through the 500 branch below.
        return JSONResponse(content=await speech.create_voice_cache(name, force=force))
    except SpeakerNotFoundError as exc:
        # Must precede the ValueError branch: it is a ValueError subclass.
        return _failure(HTTPStatus.NOT_FOUND, str(exc))
    except (SpeakerCacheUnsupportedError, ValueError) as exc:
        return _failure(HTTPStatus.BAD_REQUEST, str(exc))
    except Exception as exc:
        logger.exception("Failed to create voice cache for '%s': %s", name, exc)
        return _failure(HTTPStatus.INTERNAL_SERVER_ERROR, f"Internal error: {str(exc)}")
b/vllm_omni/entrypoints/openai/serving_speech.py index 256b9a56be8..c18189bc99d 100644 --- a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -58,6 +58,15 @@ logger = init_logger(__name__) + +class SpeakerNotFoundError(ValueError): + """Raised when the requested speaker does not exist in uploaded_speakers.""" + + +class SpeakerCacheUnsupportedError(ValueError): + """Raised when speaker cache generation is not supported.""" + + # TTS Configuration _VOXTRAL_TTS_MODEL_STAGES = {"audio_generation"} _QWEN3_TTS_MODEL_STAGES = {"qwen3_tts"} @@ -207,6 +216,7 @@ def _init_speaker_storage(self) -> None: self._speaker_cache = get_speaker_cache() self._last_upload_ts = 0 self._upload_lock = asyncio.Lock() + self._speaker_cache_build_locks: dict[str, asyncio.Lock] = {} self._restore_uploaded_speakers() logger.info( "Speaker storage: dir=%s, max_speakers=%d, restored=%d", @@ -1120,10 +1130,114 @@ async def delete_voice(self, name: str) -> bool: logger.warning("Failed to delete audio file for '%s': %s", name, e) self._speaker_cache.clear(voice_name_lower) + self._speaker_cache_build_locks.pop(voice_name_lower, None) logger.info("Deleted voice '%s'", name) return True + def _get_speaker_cache_build_lock(self, speaker_key: str) -> asyncio.Lock: + return self._speaker_cache_build_locks.setdefault(speaker_key, asyncio.Lock()) + + def _qwen3_speaker_cache_key(self, speaker_key: str, speaker_info: dict[str, Any]) -> tuple[str, str, int]: + ref_text = speaker_info.get("ref_text") + has_ref_text = isinstance(ref_text, str) and ref_text.strip() != "" + mode = "icl" if has_ref_text else "xvec" + return self._speaker_cache.make_cache_key( + speaker_key, + model_type=f"qwen3_tts_{mode}", + created_at=int(speaker_info.get("created_at", 0)), + ) + + def _has_qwen3_speaker_cache(self, speaker_key: str, speaker_info: dict[str, Any]) -> bool: + return self._speaker_cache.get(self._qwen3_speaker_cache_key(speaker_key, speaker_info)) is not None 
async def create_voice_cache(self, voice_name: str, force: bool = False) -> dict[str, Any]:
    """Warm the in-memory Qwen3-TTS speaker cache for an uploaded voice.

    Loads the uploaded reference audio, asks the TTS stage (via
    ``collective_rpc``) for the speaker embedding and optional reference
    codes, converts them to tensors and stores them in ``self._speaker_cache``.
    Builds for the same voice are serialised by a per-voice lock.

    Args:
        voice_name: Name of an uploaded voice; the lookup is case-insensitive.
        force: Rebuild the artifacts even when a cache entry already exists.

    Returns:
        ``{"voice": voice_name, "cache_status": "ready"}``; a ``"message"``
        key is added when an existing entry was reused.

    Raises:
        SpeakerCacheUnsupportedError: The model is not Qwen3-TTS, the engine
            has no ``collective_rpc``, or the voice was not uploaded as audio.
        SpeakerNotFoundError: The voice does not exist, or was deleted while
            the cache was being built.
        ValueError: The stored audio cannot be loaded, the RPC payload is
            invalid, or the voice was replaced while the cache was being built.
    """
    if self._tts_model_type != "qwen3_tts":
        raise SpeakerCacheUnsupportedError("Voice cache generation is only supported for Qwen3-TTS models")
    if self._tts_stage is None or not hasattr(self.engine_client, "collective_rpc"):
        raise SpeakerCacheUnsupportedError("Voice cache generation requires multi-stage engine support")

    speaker_key = voice_name.lower()
    # Checked before the lock is requested so that unknown names never
    # create an entry in the per-voice lock table.
    if speaker_key not in self.uploaded_speakers:
        raise SpeakerNotFoundError(f"Voice '{voice_name}' not found")

    async with self._get_speaker_cache_build_lock(speaker_key):
        # Read the record only after the lock is held: while this coroutine
        # waited, the voice may have been deleted or uploaded again.
        speaker_info = self.uploaded_speakers.get(speaker_key)
        if speaker_info is None:
            raise SpeakerNotFoundError(f"Voice '{voice_name}' not found")

        embedding_source = speaker_info.get("embedding_source", "audio")
        if embedding_source == "direct":
            raise SpeakerCacheUnsupportedError(
                f"Voice '{voice_name}' uses a pre-computed speaker embedding. "
                "Cache generation only supports audio-uploaded voices."
            )
        if embedding_source != "audio":
            raise SpeakerCacheUnsupportedError(
                f"Voice '{voice_name}' has unsupported embedding_source={speaker_info.get('embedding_source')!r}"
            )

        cache_key = self._qwen3_speaker_cache_key(speaker_key, speaker_info)
        if not force and self._speaker_cache.get(cache_key) is not None:
            return {
                "voice": voice_name,
                "cache_status": "ready",
                "message": "Cache already exists and is valid",
            }

        # NOTE(review): this load is synchronous; if it reads from disk it
        # blocks the event loop for the duration of the read.
        audio_data = self._load_uploaded_audio(voice_name)
        if audio_data is None:
            raise ValueError(
                f"Audio file for uploaded voice '{voice_name}' is missing or corrupted. "
                f"Delete this voice via DELETE /v1/audio/voices/{voice_name} and re-upload."
            )
        wav_np, sr = audio_data
        payload = await self._build_speaker_cache_payload(wav_np, int(sr), speaker_info.get("ref_text"))

        # The RPC above yields control, so the voice may have been deleted or
        # replaced in the meantime. Storing the artifacts then would leave an
        # entry under a stale key and report "ready" for a voice that no
        # longer matches, so the record is validated again before the write.
        latest_info = self.uploaded_speakers.get(speaker_key)
        if latest_info is None:
            raise SpeakerNotFoundError(f"Voice '{voice_name}' was deleted while its cache was being built")
        if self._qwen3_speaker_cache_key(speaker_key, latest_info) != cache_key:
            raise ValueError(f"Voice '{voice_name}' changed while its cache was being built; retry the request.")

        ref_code = payload.get("ref_code")
        artifacts = {
            "ref_code": torch.tensor(ref_code, dtype=torch.long) if ref_code is not None else None,
            "ref_spk_embedding": torch.tensor(payload["ref_spk_embedding"], dtype=torch.float32),
            "icl_mode": bool(payload.get("icl_mode")),
        }
        self._speaker_cache.put(cache_key, artifacts)

    return {"voice": voice_name, "cache_status": "ready"}


async def _build_speaker_cache_payload(
    self,
    wav_np: np.ndarray,
    sample_rate: int,
    ref_text: str | None,
) -> dict[str, Any]:
    """Run ``create_voice_clone_prompt`` on the TTS stage and return its payload.

    Args:
        wav_np: Reference audio samples; arrays with more than one dimension
            are averaged over the last axis before being sent.
        sample_rate: Sample rate of ``wav_np``.
        ref_text: Optional transcript forwarded unchanged to the stage.

    Returns:
        The payload dict selected by ``_extract_rpc_payload``.

    Raises:
        ValueError: Propagated from ``_extract_rpc_payload``.
    """
    wav_np = np.asarray(wav_np, dtype=np.float32)
    if wav_np.ndim > 1:
        # Assumes channels are the last axis -- TODO confirm against the
        # loader that produces ``wav_np``.
        wav_np = np.mean(wav_np, axis=-1)

    # msgspec IPC requires plain Python types; numpy arrays do not survive
    # this RPC boundary. This can be expensive for long reference audio.
    results = await self.engine_client.collective_rpc(
        method="create_voice_clone_prompt",
        args=(wav_np.tolist(), int(sample_rate), ref_text),
        stage_ids=[self._tts_stage.stage_id],
    )
    return self._extract_rpc_payload(results)


@staticmethod
def _extract_rpc_payload(results: list[Any]) -> dict[str, Any]:
    """Pick the voice-clone payload out of a ``collective_rpc`` result.

    Only the first stage entry is inspected. It may be the payload dict
    itself or a list of per-worker payloads, in which case the first worker's
    payload is used.

    Raises:
        ValueError: The result is empty, the stage reported
            ``supported is False`` or a truthy ``todo`` marker, the worker
            list is empty, or the payload lacks ``"ref_spk_embedding"``.
    """
    if not results:
        raise ValueError("Empty RPC response")
    stage_result = results[0]
    if isinstance(stage_result, dict) and stage_result.get("supported") is False:
        raise ValueError(f"Stage RPC failed: {stage_result.get('error', 'unknown')}")
    if isinstance(stage_result, dict) and stage_result.get("todo"):
        raise ValueError(f"Stage RPC not supported: {stage_result.get('reason', 'unknown')}")
    if isinstance(stage_result, list):
        if not stage_result:
            raise ValueError("Stage RPC returned empty worker results")
        payload = stage_result[0]
    else:
        payload = stage_result
    if not isinstance(payload, dict) or "ref_spk_embedding" not in payload:
        raise ValueError(f"Invalid RPC payload: {type(payload)}")
    return payload
"""Check if the current model is a supported TTS model.""" return any(stage.engine_args.model_stage in _TTS_MODEL_STAGES for stage in self.engine_client.stage_configs) @@ -1347,7 +1461,13 @@ def _validate_qwen_tts_request(self, request: OpenAICreateSpeechRequest) -> str speaker_info = self.uploaded_speakers[voice_lower] file_path = Path(speaker_info["file_path"]) if not file_path.exists(): - return f"Data file for uploaded speaker '{request.voice}' not found on disk" + has_warm_cache = ( + self._tts_model_type == "qwen3_tts" + and speaker_info.get("embedding_source") == "audio" + and self._has_qwen3_speaker_cache(voice_lower, speaker_info) + ) + if not has_warm_cache: + return f"Data file for uploaded speaker '{request.voice}' not found on disk" else: # need ref_audio for built-in speaker if request.ref_audio is None: @@ -1680,13 +1800,24 @@ def _build_tts_params(self, request: OpenAICreateSpeechRequest) -> dict[str, Any # Uploaded voices use task_type="Base" (CustomVoice requires built-in spk_id). # If ref_text was provided at upload time, use in-context cloning; otherwise x_vector only. if request.voice.lower() in self.uploaded_speakers and request.ref_audio is None: - speaker_info = self.uploaded_speakers[request.voice.lower()] + speaker_key = request.voice.lower() + speaker_info = self.uploaded_speakers[speaker_key] # Check if this voice was uploaded with a pre-computed embedding. # Populate request.speaker_embedding so the existing code path # (below) handles voice_clone_prompt and x_vector_only_mode. 
- embedding = self._get_uploaded_speaker_embedding(request.voice) - if embedding is not None: + if speaker_info.get("embedding_source") == "audio" and self._has_qwen3_speaker_cache( + speaker_key, speaker_info + ): + stored_ref_text = speaker_info.get("ref_text") + params["task_type"] = ["Base"] + if stored_ref_text: + params["ref_text"] = [stored_ref_text] + params["x_vector_only_mode"] = [False] + else: + params["x_vector_only_mode"] = [True] + logger.info("Using warmed speaker cache for uploaded voice: %s", request.voice) + elif (embedding := self._get_uploaded_speaker_embedding(request.voice)) is not None: request.speaker_embedding = embedding params["task_type"] = ["Base"] logger.info("Auto-set speaker_embedding for uploaded voice: %s", request.voice) diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py index 5703377daa7..2eac4dcdda4 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py @@ -1357,6 +1357,15 @@ def _normalize_voice_clone_prompt(raw: object) -> dict[str, object] | None: return raw[0] return None + def _normalize_ref_code_payload(raw: object) -> object: + if isinstance(raw, list) and len(raw) == 1 and isinstance(raw[0], (torch.Tensor, np.ndarray)): + # Preserve compatibility with older internal payloads that + # wrapped tensor/ndarray ref_code in a singleton list. Do not + # unwrap Python list payloads: a one-frame ref_code is also a + # valid list with len == 1. + return raw[0] + return raw + if task_type == "Base": # Base supports voice clone prompt with in-context mode. 
xvec_only = bool((info_dict.get("x_vector_only_mode") or [False])[0]) @@ -1402,12 +1411,17 @@ def _normalize_voice_clone_prompt(raw: object) -> dict[str, object] | None: xvec_only = not in_context_mode ref_code = None if voice_clone_prompt is not None: - ref_code = _as_singleton(voice_clone_prompt.get("ref_code")) + # Keep the full ref_code payload. For cached prompts this may be + # a 2D Python list (frames x quantizers), and unwrapping it as a + # singleton would silently drop all but the first frame. + ref_code = _normalize_ref_code_payload(voice_clone_prompt.get("ref_code")) ref_code_t = None if isinstance(ref_code, torch.Tensor): ref_code_t = ref_code elif isinstance(ref_code, np.ndarray): ref_code_t = torch.from_numpy(ref_code) + elif isinstance(ref_code, list) and ref_code: + ref_code_t = torch.tensor(ref_code, dtype=torch.long) if isinstance(ref_code_t, torch.Tensor): if ref_code_t.ndim == 3: ref_code_t = ref_code_t[0] diff --git a/vllm_omni/worker/base.py b/vllm_omni/worker/base.py index 67b6f77b3fe..57dff85b18b 100644 --- a/vllm_omni/worker/base.py +++ b/vllm_omni/worker/base.py @@ -7,6 +7,7 @@ import time from contextlib import AbstractContextManager, nullcontext +import numpy as np import torch from vllm.logger import init_logger from vllm.utils.mem_utils import format_gib, memory_profiling @@ -168,6 +169,45 @@ def determine_available_memory(self) -> int: return int(self.available_kv_cache_memory_bytes) + @torch.inference_mode() + def create_voice_clone_prompt( + self, + wav_samples: list[float], + sample_rate: int, + ref_text: str | None = None, + ) -> dict: + """RPC handler: extract speaker embedding + ref_code on GPU. + + Only effective for Qwen3-TTS talker models with speaker_encoder. + Both inputs and outputs must stay as plain Python types to survive + msgspec IPC (e.g. no numpy arrays in wav_samples). 
+ """ + model = self.model_runner.model + + if not hasattr(model, "_extract_speaker_embedding"): + raise NotImplementedError(f"{type(model).__name__} does not support speaker embedding extraction") + + wav_np = np.array(wav_samples, dtype=np.float32) + icl_mode = ref_text is not None and ref_text.strip() != "" + + if icl_mode and not hasattr(model, "_encode_ref_audio_to_code"): + raise NotImplementedError(f"{type(model).__name__} does not support ref audio codec encoding") + + spk = model._extract_speaker_embedding(wav_np, sample_rate) + + ref_code_list = None + if icl_mode: + ref_code_t = model._encode_ref_audio_to_code(wav_np, sample_rate) + ref_code_list = ref_code_t.cpu().tolist() + + return { + "ref_spk_embedding": spk.cpu().tolist(), + "ref_code": ref_code_list, + "x_vector_only_mode": not icl_mode, + "icl_mode": icl_mode, + "ref_text": ref_text.strip() if icl_mode else None, + } + # Provide memory pool context def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager: v1_config_enabled = False