diff --git a/tests/entrypoints/openai_api/test_serving_speech_cosyvoice3.py b/tests/entrypoints/openai_api/test_serving_speech_cosyvoice3.py new file mode 100644 index 00000000000..5f17b9382a2 --- /dev/null +++ b/tests/entrypoints/openai_api/test_serving_speech_cosyvoice3.py @@ -0,0 +1,142 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for CosyVoice3 online serving via /v1/audio/speech. + +Covers the changes in PR #2121: + - model_stage rename (talker -> cosyvoice3_talker, code2wav -> cosyvoice3_code2wav) + - TTS model type detection for cosyvoice3 + - CosyVoice3 prompt building in _prepare_speech_generation +""" + +import asyncio +from unittest.mock import AsyncMock + +import pytest +from pytest_mock import MockerFixture + +from vllm_omni.entrypoints.openai.protocol.audio import OpenAICreateSpeechRequest +from vllm_omni.entrypoints.openai.serving_speech import ( + _COSYVOICE3_TTS_MODEL_STAGES, + _TTS_MODEL_STAGES, + OmniOpenAIServingSpeech, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def cosyvoice3_server(mocker: MockerFixture): + """Create a speech server configured with a CosyVoice3 talker stage.""" + mock_engine_client = mocker.MagicMock() + mock_engine_client.errored = False + mock_engine_client.tts_max_instructions_length = None + + mock_stage = mocker.MagicMock() + mock_stage.engine_args.model_stage = "cosyvoice3_talker" + mock_stage.tts_args = {} + mock_engine_client.stage_configs = [mock_stage] + + mock_models = mocker.MagicMock() + mock_models.is_base_model.return_value = True + + return OmniOpenAIServingSpeech( + engine_client=mock_engine_client, + models=mock_models, + request_logger=mocker.MagicMock(), + ) + + +# --------------------------------------------------------------------------- +# Tests: model_stage constants +# --------------------------------------------------------------------------- + + +class TestCosyVoice3ModelStage: + """Verify model_stage rename is consistent.""" + + def test_cosyvoice3_talker_in_tts_stages(self): + assert "cosyvoice3_talker" in _COSYVOICE3_TTS_MODEL_STAGES + assert "cosyvoice3_talker" in _TTS_MODEL_STAGES + + def test_old_stage_names_not_in_tts_stages(self): + """Old generic names should not be registered.""" + assert "talker" not in _COSYVOICE3_TTS_MODEL_STAGES + assert "code2wav" not in _COSYVOICE3_TTS_MODEL_STAGES + + +# --------------------------------------------------------------------------- +# Tests: TTS model type detection +# --------------------------------------------------------------------------- + + +class TestCosyVoice3Detection: + def test_detect_cosyvoice3_model_type(self, cosyvoice3_server): + assert cosyvoice3_server._is_tts is True + assert cosyvoice3_server._tts_model_type == "cosyvoice3" + + def test_is_not_fish_or_voxtral(self, cosyvoice3_server): + assert cosyvoice3_server._is_fish_speech is False + + +# --------------------------------------------------------------------------- +# Tests: _prepare_speech_generation for CosyVoice3 +# --------------------------------------------------------------------------- + + +class TestCosyVoice3PromptBuilding: + def test_requires_ref_audio(self, cosyvoice3_server): + """CosyVoice3 must reject requests without ref_audio.""" + req = OpenAICreateSpeechRequest( + input="Hello world", + ref_audio=None, + ref_text="reference text", + ) + with pytest.raises(ValueError, match="ref_audio"): + asyncio.run(cosyvoice3_server._prepare_speech_generation(req)) + + def test_requires_ref_text(self, cosyvoice3_server): + """CosyVoice3 must reject requests without ref_text.""" + req = OpenAICreateSpeechRequest( + input="Hello world", + ref_audio="data:audio/wav;base64,UklGR...", + ref_text="", + ) + with pytest.raises(ValueError, match="ref_text"): + asyncio.run(cosyvoice3_server._prepare_speech_generation(req)) + + def test_requires_nonempty_input(self, cosyvoice3_server): + """CosyVoice3 must reject empty input text.""" + req = OpenAICreateSpeechRequest( + input="", + ref_audio="data:audio/wav;base64,UklGR...", + ref_text="reference", + ) + with pytest.raises(ValueError, match="cannot be empty"): + asyncio.run(cosyvoice3_server._prepare_speech_generation(req)) + + def test_builds_correct_prompt(self, cosyvoice3_server, mocker): + """Verify prompt structure when all inputs are valid.""" + dummy_audio = [0.0] * 16000 + mocker.patch.object( + cosyvoice3_server, + "_resolve_ref_audio", + new_callable=AsyncMock, + return_value=(dummy_audio, 16000), + ) + + req = OpenAICreateSpeechRequest( + input="Hello world", + ref_audio="data:audio/wav;base64,UklGR...", + ref_text="reference transcript", + ) + + request_id, generator, tts_params = asyncio.run(cosyvoice3_server._prepare_speech_generation(req)) + + assert request_id.startswith("speech-") + # generator is an async generator from the engine; just check it exists + assert generator is not None diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py index 3c8278f14a6..f18431d6dff 100644 --- a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -33,7 +33,10 @@ _VOXTRAL_TTS_MODEL_STAGES = {"audio_generation"} _QWEN3_TTS_MODEL_STAGES = {"qwen3_tts"} _FISH_TTS_MODEL_STAGES = {"fish_speech_slow_ar"} -_TTS_MODEL_STAGES: set[str] = _VOXTRAL_TTS_MODEL_STAGES | _QWEN3_TTS_MODEL_STAGES | _FISH_TTS_MODEL_STAGES +_COSYVOICE3_TTS_MODEL_STAGES = {"cosyvoice3_talker"} +_TTS_MODEL_STAGES: set[str] = ( + _VOXTRAL_TTS_MODEL_STAGES | _QWEN3_TTS_MODEL_STAGES | _FISH_TTS_MODEL_STAGES | _COSYVOICE3_TTS_MODEL_STAGES +) _TTS_LANGUAGES: set[str] = { "Auto", "Chinese", @@ -221,6 +224,8 @@ def _detect_tts_model_type(self) -> str | None: return "voxtral_tts" if model_stage in _FISH_TTS_MODEL_STAGES: return "fish_tts" + if model_stage in _COSYVOICE3_TTS_MODEL_STAGES: + return "cosyvoice3" return None def _compute_max_instructions_length(self) -> int: @@ -966,7 +971,30 @@ async def _prepare_speech_generation( if validation_error: raise ValueError(validation_error) - if self._tts_model_type == "voxtral_tts": + if self._tts_model_type == "cosyvoice3": + if not request.input or not request.input.strip(): + raise ValueError("Input text cannot be empty") + if request.ref_audio is None: + raise ValueError( + "CosyVoice3 requires a reference audio for voice cloning. " + "Please provide 'ref_audio' in the request." + ) + if not request.ref_text or not request.ref_text.strip(): + raise ValueError("CosyVoice3 requires 'ref_text' (transcript of the reference audio)") + wav_list, sr = await self._resolve_ref_audio(request.ref_audio) + audio_data = (np.array(wav_list, dtype=np.float32), sr) + prompt = { + "prompt": request.input, + "multi_modal_data": { + "audio": audio_data, + }, + "mm_processor_kwargs": { + "prompt_text": request.ref_text, + "sample_rate": sr, + }, + } + tts_params = {} + elif self._tts_model_type == "voxtral_tts": prompt = await self._build_voxtral_prompt(request) tts_params = {} else: @@ -984,6 +1012,8 @@ async def _prepare_speech_generation( request_id = f"speech-{random_uuid()}" if self._is_fish_speech: model_type = "fish_speech" + elif self._tts_model_type == "cosyvoice3": + model_type = "cosyvoice3" elif self._tts_model_type == "voxtral_tts": model_type = "voxtral_tts" elif self._is_tts: diff --git a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py index 87c5f323a45..3c3f8bcb9f1 100644 --- a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py +++ b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py @@ -268,7 +268,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model_stage = vllm_config.model_config.model_stage self.model_dir = vllm_config.model_config.model self.model = None - if self.model_stage == "talker": + if self.model_stage == "cosyvoice3_talker": # Initialize talker stage (text to speech tokens) from vllm_omni.model_executor.models.cosyvoice3.cosyvoice3_talker import CosyVoice3LM, VLLMQwen2Encoder @@ -286,7 +286,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # KV cache is now managed externally by vLLM's PagedAttention # No need for self.llm_cache self.model = self.talker - elif self.model_stage == "code2wav": + elif self.model_stage == "cosyvoice3_code2wav": # Initialize code2wav stage (flow matching + vocoder) from vllm_omni.model_executor.models.cosyvoice3.cosyvoice3_code2wav import CosyVoice3Code2Wav @@ -322,7 +322,7 @@ def _create_llm_vllm_config(self, parent_config: VllmConfig) -> VllmConfig: def compute_logits(self, hidden_states: torch.Tensor | OmniOutput) -> torch.Tensor | None: if isinstance(hidden_states, OmniOutput): hidden_states = hidden_states.text_hidden_states - if self.model_stage == "talker": + if self.model_stage == "cosyvoice3_talker": logits = self.model.llm_decoder(hidden_states) vocab_size = self.config.vocab_size pad_size = vocab_size - logits.size(-1) @@ -337,7 +337,7 @@ def compute_logits(self, hidden_states: torch.Tensor | OmniOutput) -> torch.Tens raise RuntimeError(f"compute_logits is only valid for {self.model_stage}.") def embed_multimodal(self, **kwargs: object) -> torch.Tensor: - if self.model_stage == "talker": + if self.model_stage == "cosyvoice3_talker": speech_token = kwargs["speech_token"] speech_token_emb = self.model.speech_embedding(speech_token) return speech_token_emb @@ -350,7 +350,7 @@ def embed_input_ids( multimodal_embeddings=None, is_multimodal=None, ) -> torch.Tensor: - if self.model_stage == "talker": + if self.model_stage == "cosyvoice3_talker": if is_multimodal is not None and any(is_multimodal): embed_tokens = self.model.llm.model.embed_tokens(input_ids) sos = self.model.speech_embedding.weight[self.model.sos].reshape(1, -1) @@ -363,7 +363,7 @@ def embed_input_ids( else: embed_tokens = self.model.speech_embedding.weight[input_ids] return embed_tokens - elif self.model_stage == "code2wav": + elif self.model_stage == "cosyvoice3_code2wav": assert input_ids.dim() == 1 hidden = int(self.config.hidden_size) return torch.zeros( @@ -381,7 +381,7 @@ def forward( additional_information: dict[str, object] | None = None, **kwargs: object, ) -> OmniOutput: - if self.model_stage == "talker": + if self.model_stage == "cosyvoice3_talker": if inputs_embeds is None: inputs_embeds = self.embed_input_ids(input_ids) @@ -399,7 +399,7 @@ def forward( } return OmniOutput(text_hidden_states=hidden_states, multimodal_outputs=multimodal_outputs) - elif self.model_stage == "code2wav": + elif self.model_stage == "cosyvoice3_code2wav": runtime_info = kwargs.get("runtime_additional_information", []) if not runtime_info: length = 30 * 24000 @@ -426,7 +426,7 @@ def forward( raise ValueError(f"Unsupported model_stage: {self.model_stage}") def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - if self.model_stage == "talker": + if self.model_stage == "cosyvoice3_talker": # Load weights for text to speech LM stage using vLLM's weight loading llm_weight_path = os.path.join(self.model_dir, "llm.pt") device = next(self.parameters()).device @@ -460,7 +460,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: self.model.llm_decoder.load_state_dict(llm_decoder_state) self.model.to(device).eval() - elif self.model_stage == "code2wav": + elif self.model_stage == "cosyvoice3_code2wav": # Load weights for code2wav stage (flow + hift) device = next(self.parameters()).device self.code2wav.load_weights(self.model_dir, device) diff --git a/vllm_omni/model_executor/stage_configs/cosyvoice3.yaml b/vllm_omni/model_executor/stage_configs/cosyvoice3.yaml index 13b6ff55bd6..3a800ccb8c7 100644 --- a/vllm_omni/model_executor/stage_configs/cosyvoice3.yaml +++ b/vllm_omni/model_executor/stage_configs/cosyvoice3.yaml @@ -9,7 +9,7 @@ stage_args: runtime: devices: 0 engine_args: - model_stage: talker + model_stage: cosyvoice3_talker worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler model_arch: CosyVoice3Model @@ -27,7 +27,7 @@ stage_args: runtime: devices: 0 engine_args: - model_stage: code2wav + model_stage: cosyvoice3_code2wav model_arch: CosyVoice3Model trust_remote_code: true worker_cls: vllm_omni.worker.gpu_generation_worker.GPUGenerationWorker