Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 142 additions & 0 deletions tests/entrypoints/openai_api/test_serving_speech_cosyvoice3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for CosyVoice3 online serving via /v1/audio/speech.

Covers the changes in PR #2121:
- model_stage rename (talker -> cosyvoice3_talker, code2wav -> cosyvoice3_code2wav)
- TTS model type detection for cosyvoice3
- CosyVoice3 prompt building in _prepare_speech_generation
"""

import asyncio
from unittest.mock import AsyncMock

import pytest
from pytest_mock import MockerFixture

from vllm_omni.entrypoints.openai.protocol.audio import OpenAICreateSpeechRequest
from vllm_omni.entrypoints.openai.serving_speech import (
_COSYVOICE3_TTS_MODEL_STAGES,
_TTS_MODEL_STAGES,
OmniOpenAIServingSpeech,
)

pytestmark = [pytest.mark.core_model, pytest.mark.cpu]


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture
def cosyvoice3_server(mocker: MockerFixture):
    """Return an ``OmniOpenAIServingSpeech`` backed entirely by mocks.

    The mocked engine client exposes a single stage config whose
    ``engine_args.model_stage`` is ``"cosyvoice3_talker"``, reports no
    error state, and has no instruction-length limit configured.
    """
    # The only stage: a CosyVoice3 talker with no extra TTS arguments.
    talker_stage = mocker.MagicMock()
    talker_stage.engine_args.model_stage = "cosyvoice3_talker"
    talker_stage.tts_args = {}

    engine = mocker.MagicMock()
    engine.errored = False
    engine.tts_max_instructions_length = None
    engine.stage_configs = [talker_stage]

    model_registry = mocker.MagicMock()
    model_registry.is_base_model.return_value = True

    return OmniOpenAIServingSpeech(
        engine_client=engine,
        models=model_registry,
        request_logger=mocker.MagicMock(),
    )


# ---------------------------------------------------------------------------
# Tests: model_stage constants
# ---------------------------------------------------------------------------


class TestCosyVoice3ModelStage:
    """Verify the model_stage rename is reflected in the TTS stage registries."""

    def test_cosyvoice3_talker_in_tts_stages(self):
        """The renamed talker stage is present in both the specific and aggregate sets."""
        assert "cosyvoice3_talker" in _COSYVOICE3_TTS_MODEL_STAGES
        assert "cosyvoice3_talker" in _TTS_MODEL_STAGES

    def test_old_stage_names_not_in_tts_stages(self):
        """Old generic names should not be registered in either stage set."""
        for legacy_name in ("talker", "code2wav"):
            assert legacy_name not in _COSYVOICE3_TTS_MODEL_STAGES
            # Also check the aggregate set: a stale generic name left there
            # would not be caught by the CosyVoice3-specific check alone.
            assert legacy_name not in _TTS_MODEL_STAGES


# ---------------------------------------------------------------------------
# Tests: TTS model type detection
# ---------------------------------------------------------------------------


class TestCosyVoice3Detection:
    """Check how the server classifies a ``cosyvoice3_talker`` stage."""

    def test_detect_cosyvoice3_model_type(self, cosyvoice3_server):
        """The server reports itself as a TTS server of type ``cosyvoice3``."""
        assert cosyvoice3_server._is_tts is True
        assert cosyvoice3_server._tts_model_type == "cosyvoice3"

    def test_is_not_fish_or_voxtral(self, cosyvoice3_server):
        """A CosyVoice3 stage must not be classified as Fish Speech or Voxtral."""
        assert cosyvoice3_server._is_fish_speech is False
        # Cover the Voxtral half of the test name as well.
        assert cosyvoice3_server._tts_model_type != "voxtral_tts"


# ---------------------------------------------------------------------------
# Tests: _prepare_speech_generation for CosyVoice3
# ---------------------------------------------------------------------------


class TestCosyVoice3PromptBuilding:
    """Exercise ``_prepare_speech_generation`` on a CosyVoice3-configured server."""

    def test_requires_ref_audio(self, cosyvoice3_server):
        """CosyVoice3 must reject requests without ref_audio."""
        req = OpenAICreateSpeechRequest(
            input="Hello world",
            ref_audio=None,
            ref_text="reference text",
        )
        with pytest.raises(ValueError, match="ref_audio"):
            asyncio.run(cosyvoice3_server._prepare_speech_generation(req))

    def test_requires_ref_text(self, cosyvoice3_server):
        """CosyVoice3 must reject requests without ref_text."""
        req = OpenAICreateSpeechRequest(
            input="Hello world",
            ref_audio="data:audio/wav;base64,UklGR...",
            ref_text="",
        )
        with pytest.raises(ValueError, match="ref_text"):
            asyncio.run(cosyvoice3_server._prepare_speech_generation(req))

    def test_requires_nonempty_input(self, cosyvoice3_server):
        """CosyVoice3 must reject empty input text."""
        req = OpenAICreateSpeechRequest(
            input="",
            ref_audio="data:audio/wav;base64,UklGR...",
            ref_text="reference",
        )
        with pytest.raises(ValueError, match="cannot be empty"):
            asyncio.run(cosyvoice3_server._prepare_speech_generation(req))

    def test_builds_correct_prompt(self, cosyvoice3_server, mocker):
        """Verify the happy path resolves the reference audio and returns a generator."""
        dummy_audio = [0.0] * 16000
        # Stub out audio decoding so no real audio payload is needed.
        resolve_mock = mocker.patch.object(
            cosyvoice3_server,
            "_resolve_ref_audio",
            new_callable=AsyncMock,
            return_value=(dummy_audio, 16000),
        )

        req = OpenAICreateSpeechRequest(
            input="Hello world",
            ref_audio="data:audio/wav;base64,UklGR...",
            ref_text="reference transcript",
        )

        # The third element (TTS params) is not inspected by this test.
        request_id, generator, _tts_params = asyncio.run(
            cosyvoice3_server._prepare_speech_generation(req)
        )

        # The reference audio from the request must have been handed to the resolver.
        resolve_mock.assert_awaited_with(req.ref_audio)
        assert request_id.startswith("speech-")
        # generator is an async generator from the engine; just check it exists
        assert generator is not None
34 changes: 32 additions & 2 deletions vllm_omni/entrypoints/openai/serving_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@
_VOXTRAL_TTS_MODEL_STAGES = {"audio_generation"}
_QWEN3_TTS_MODEL_STAGES = {"qwen3_tts"}
_FISH_TTS_MODEL_STAGES = {"fish_speech_slow_ar"}
_TTS_MODEL_STAGES: set[str] = _VOXTRAL_TTS_MODEL_STAGES | _QWEN3_TTS_MODEL_STAGES | _FISH_TTS_MODEL_STAGES
_COSYVOICE3_TTS_MODEL_STAGES = {"cosyvoice3_talker"}
_TTS_MODEL_STAGES: set[str] = (
_VOXTRAL_TTS_MODEL_STAGES | _QWEN3_TTS_MODEL_STAGES | _FISH_TTS_MODEL_STAGES | _COSYVOICE3_TTS_MODEL_STAGES
)
_TTS_LANGUAGES: set[str] = {
"Auto",
"Chinese",
Expand Down Expand Up @@ -221,6 +224,8 @@ def _detect_tts_model_type(self) -> str | None:
return "voxtral_tts"
if model_stage in _FISH_TTS_MODEL_STAGES:
return "fish_tts"
if model_stage in _COSYVOICE3_TTS_MODEL_STAGES:
return "cosyvoice3"
return None

def _compute_max_instructions_length(self) -> int:
Expand Down Expand Up @@ -966,7 +971,30 @@ async def _prepare_speech_generation(
if validation_error:
raise ValueError(validation_error)

if self._tts_model_type == "voxtral_tts":
if self._tts_model_type == "cosyvoice3":
if not request.input or not request.input.strip():
raise ValueError("Input text cannot be empty")
if request.ref_audio is None:
raise ValueError(
"CosyVoice3 requires a reference audio for voice cloning. "
"Please provide 'ref_audio' in the request."
)
if not request.ref_text or not request.ref_text.strip():
raise ValueError("CosyVoice3 requires 'ref_text' (transcript of the reference audio)")
wav_list, sr = await self._resolve_ref_audio(request.ref_audio)
audio_data = (np.array(wav_list, dtype=np.float32), sr)
prompt = {
"prompt": request.input,
"multi_modal_data": {
"audio": audio_data,
},
"mm_processor_kwargs": {
"prompt_text": request.ref_text,
"sample_rate": sr,
},
}
tts_params = {}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Respect max_new_tokens for CosyVoice3 requests

The new CosyVoice3 path drops all per-request generation controls by setting tts_params = {} and never mapping request.max_new_tokens into sampling params, so /v1/audio/speech callers who set max_new_tokens for latency/cost control will have that limit silently ignored. This is observable whenever max_new_tokens is provided with a CosyVoice3 model and can lead to much longer-than-requested decoding runs.

Useful? React with 👍 / 👎.

elif self._tts_model_type == "voxtral_tts":
prompt = await self._build_voxtral_prompt(request)
tts_params = {}
else:
Expand All @@ -984,6 +1012,8 @@ async def _prepare_speech_generation(
request_id = f"speech-{random_uuid()}"
if self._is_fish_speech:
model_type = "fish_speech"
elif self._tts_model_type == "cosyvoice3":
model_type = "cosyvoice3"
elif self._tts_model_type == "voxtral_tts":
model_type = "voxtral_tts"
elif self._is_tts:
Expand Down
20 changes: 10 additions & 10 deletions vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.model_stage = vllm_config.model_config.model_stage
self.model_dir = vllm_config.model_config.model
self.model = None
if self.model_stage == "talker":
if self.model_stage == "cosyvoice3_talker":
# Initialize talker stage (text to speech tokens)
from vllm_omni.model_executor.models.cosyvoice3.cosyvoice3_talker import CosyVoice3LM, VLLMQwen2Encoder

Expand All @@ -286,7 +286,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# KV cache is now managed externally by vLLM's PagedAttention
# No need for self.llm_cache
self.model = self.talker
elif self.model_stage == "code2wav":
elif self.model_stage == "cosyvoice3_code2wav":
# Initialize code2wav stage (flow matching + vocoder)
from vllm_omni.model_executor.models.cosyvoice3.cosyvoice3_code2wav import CosyVoice3Code2Wav

Expand Down Expand Up @@ -322,7 +322,7 @@ def _create_llm_vllm_config(self, parent_config: VllmConfig) -> VllmConfig:
def compute_logits(self, hidden_states: torch.Tensor | OmniOutput) -> torch.Tensor | None:
if isinstance(hidden_states, OmniOutput):
hidden_states = hidden_states.text_hidden_states
if self.model_stage == "talker":
if self.model_stage == "cosyvoice3_talker":
logits = self.model.llm_decoder(hidden_states)
vocab_size = self.config.vocab_size
pad_size = vocab_size - logits.size(-1)
Expand All @@ -337,7 +337,7 @@ def compute_logits(self, hidden_states: torch.Tensor | OmniOutput) -> torch.Tens
raise RuntimeError(f"compute_logits is only valid for {self.model_stage}.")

def embed_multimodal(self, **kwargs: object) -> torch.Tensor:
if self.model_stage == "talker":
if self.model_stage == "cosyvoice3_talker":
speech_token = kwargs["speech_token"]
speech_token_emb = self.model.speech_embedding(speech_token)
return speech_token_emb
Expand All @@ -350,7 +350,7 @@ def embed_input_ids(
multimodal_embeddings=None,
is_multimodal=None,
) -> torch.Tensor:
if self.model_stage == "talker":
if self.model_stage == "cosyvoice3_talker":
if is_multimodal is not None and any(is_multimodal):
embed_tokens = self.model.llm.model.embed_tokens(input_ids)
sos = self.model.speech_embedding.weight[self.model.sos].reshape(1, -1)
Expand All @@ -363,7 +363,7 @@ def embed_input_ids(
else:
embed_tokens = self.model.speech_embedding.weight[input_ids]
return embed_tokens
elif self.model_stage == "code2wav":
elif self.model_stage == "cosyvoice3_code2wav":
assert input_ids.dim() == 1
hidden = int(self.config.hidden_size)
return torch.zeros(
Expand All @@ -381,7 +381,7 @@ def forward(
additional_information: dict[str, object] | None = None,
**kwargs: object,
) -> OmniOutput:
if self.model_stage == "talker":
if self.model_stage == "cosyvoice3_talker":
if inputs_embeds is None:
inputs_embeds = self.embed_input_ids(input_ids)

Expand All @@ -399,7 +399,7 @@ def forward(
}

return OmniOutput(text_hidden_states=hidden_states, multimodal_outputs=multimodal_outputs)
elif self.model_stage == "code2wav":
elif self.model_stage == "cosyvoice3_code2wav":
runtime_info = kwargs.get("runtime_additional_information", [])
if not runtime_info:
length = 30 * 24000
Expand All @@ -426,7 +426,7 @@ def forward(
raise ValueError(f"Unsupported model_stage: {self.model_stage}")

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
if self.model_stage == "talker":
if self.model_stage == "cosyvoice3_talker":
# Load weights for text to speech LM stage using vLLM's weight loading
llm_weight_path = os.path.join(self.model_dir, "llm.pt")
device = next(self.parameters()).device
Expand Down Expand Up @@ -460,7 +460,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
self.model.llm_decoder.load_state_dict(llm_decoder_state)

self.model.to(device).eval()
elif self.model_stage == "code2wav":
elif self.model_stage == "cosyvoice3_code2wav":
# Load weights for code2wav stage (flow + hift)
device = next(self.parameters()).device
self.code2wav.load_weights(self.model_dir, device)
Expand Down
4 changes: 2 additions & 2 deletions vllm_omni/model_executor/stage_configs/cosyvoice3.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ stage_args:
runtime:
devices: 0
engine_args:
model_stage: talker
model_stage: cosyvoice3_talker
worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker
scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
model_arch: CosyVoice3Model
Expand All @@ -27,7 +27,7 @@ stage_args:
runtime:
devices: 0
engine_args:
model_stage: code2wav
model_stage: cosyvoice3_code2wav
model_arch: CosyVoice3Model
trust_remote_code: true
worker_cls: vllm_omni.worker.gpu_generation_worker.GPUGenerationWorker
Expand Down
Loading