diff --git a/tests/e2e/online_serving/test_omnivoice.py b/tests/e2e/online_serving/test_omnivoice.py index ec1981aab22..4a0069f4022 100644 --- a/tests/e2e/online_serving/test_omnivoice.py +++ b/tests/e2e/online_serving/test_omnivoice.py @@ -17,9 +17,16 @@ import httpx import pytest -from tests.conftest import OmniServerParams +from tests.conftest import OmniServerParams, generate_synthetic_audio from tests.utils import hardware_test +try: + from transformers import HiggsAudioV2TokenizerModel # noqa: F401 + + _HAS_VOICE_CLONE = True +except ImportError: + _HAS_VOICE_CLONE = False + MODEL = "k2-fsa/OmniVoice" STAGE_CONFIG = str( @@ -40,6 +47,16 @@ MIN_AUDIO_BYTES = 5000 +def _get_ref_audio_b64() -> str: + """Generate synthetic speech for reference audio. + + Returns: + Base64 data URL string (data:audio/wav;base64,...) + """ + audio_data = generate_synthetic_audio(duration=2, num_channels=1, sample_rate=24000) + return f"data:audio/wav;base64,{audio_data['base64']}" + + def make_speech_request( host: str, port: int, @@ -82,3 +99,102 @@ def test_speech_auto_voice(self, omni_server) -> None: assert len(response.content) > MIN_AUDIO_BYTES, ( f"Audio too small ({len(response.content)} bytes), expected > {MIN_AUDIO_BYTES}" ) + + +def make_voice_clone_request( + host: str, + port: int, + text: str, + ref_audio_b64: str, + ref_text: str | None = None, + timeout: float = 180.0, +) -> httpx.Response: + """Make a voice cloning request to the /v1/audio/speech endpoint. + + Args: + host: Server host + port: Server port + text: Text to synthesize + ref_audio_b64: Base64-encoded reference audio data URL + ref_text: Optional transcript of reference audio + timeout: Request timeout in seconds + + Returns: + httpx.Response object + """ + url = f"http://{host}:{port}/v1/audio/speech" + payload = { + "input": text, + "ref_audio": ref_audio_b64, + } + if ref_text: + payload["ref_text"] = ref_text + + with httpx.Client(timeout=timeout) as client: + return client.post(url, json=payload) + + +@pytest.mark.skipif(not _HAS_VOICE_CLONE, reason="Voice cloning requires transformers>=5.3.0") +@pytest.mark.parametrize("omni_server", TEST_PARAMS, indirect=True) +class TestOmniVoiceVoiceCloning: + """E2E tests for OmniVoice voice cloning functionality.""" + + @pytest.mark.core_model + @pytest.mark.omni + @hardware_test(res={"cuda": "L4"}, num_cards=1) + def test_voice_clone_ref_audio_only(self, omni_server) -> None: + """Test voice cloning with ref_audio only (x_vector mode).""" + ref_audio_b64 = _get_ref_audio_b64() + + response = make_voice_clone_request( + host=omni_server.host, + port=omni_server.port, + text="Hello, this is a voice cloning test.", + ref_audio_b64=ref_audio_b64, + ) + + assert response.status_code == 200, f"Request failed: {response.text}" + assert response.headers.get("content-type") == "audio/wav" + assert verify_wav_audio(response.content), "Response is not valid WAV audio" + assert len(response.content) > MIN_AUDIO_BYTES, ( + f"Audio too small ({len(response.content)} bytes), expected > {MIN_AUDIO_BYTES}" + ) + + @pytest.mark.core_model + @pytest.mark.omni + @hardware_test(res={"cuda": "L4"}, num_cards=1) + def test_voice_clone_ref_audio_and_text(self, omni_server) -> None: + """Test voice cloning with ref_audio and ref_text (in-context mode).""" + ref_audio_b64 = _get_ref_audio_b64() + ref_text = "This is the reference transcript." + + response = make_voice_clone_request( + host=omni_server.host, + port=omni_server.port, + text="Hello, this is a voice cloning test with in-context learning.", + ref_audio_b64=ref_audio_b64, + ref_text=ref_text, + ) + + assert response.status_code == 200, f"Request failed: {response.text}" + assert response.headers.get("content-type") == "audio/wav" + assert verify_wav_audio(response.content), "Response is not valid WAV audio" + assert len(response.content) > MIN_AUDIO_BYTES, ( + f"Audio too small ({len(response.content)} bytes), expected > {MIN_AUDIO_BYTES}" + ) + + @pytest.mark.core_model + @pytest.mark.omni + @hardware_test(res={"cuda": "L4"}, num_cards=1) + def test_voice_clone_invalid_ref_audio_format(self, omni_server) -> None: + """Test that invalid ref_audio format returns a clear error.""" + response = make_voice_clone_request( + host=omni_server.host, + port=omni_server.port, + text="This should fail with invalid ref_audio.", + ref_audio_b64="not_a_valid_uri", + ) + + assert response.status_code in (400, 422), ( + f"Expected 400/422 for invalid ref_audio format, got {response.status_code}" + ) diff --git a/vllm_omni/diffusion/models/omnivoice/pipeline_omnivoice.py b/vllm_omni/diffusion/models/omnivoice/pipeline_omnivoice.py index 568e2f51640..c330e91de8d 100644 --- a/vllm_omni/diffusion/models/omnivoice/pipeline_omnivoice.py +++ b/vllm_omni/diffusion/models/omnivoice/pipeline_omnivoice.py @@ -16,6 +16,7 @@ from collections.abc import Iterable from typing import ClassVar +import numpy as np import torch from tokenizers import Tokenizer as HFTokenizer from torch import nn @@ -30,6 +31,13 @@ from vllm_omni.model_executor.models.omnivoice.omnivoice_decoder import OmniVoiceDecoder from vllm_omni.model_executor.models.omnivoice.omnivoice_generator import OmniVoiceGenerator +try: + from transformers import HiggsAudioV2TokenizerModel +except ImportError: + HiggsAudioV2TokenizerModel = None + +import torchaudio + logger = init_logger(__name__) @@ -79,6 +87,17 @@ def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): tokenizer_path = os.path.join(self.model_path, "tokenizer.json") self.tokenizer = HFTokenizer.from_file(tokenizer_path) + # Audio tokenizer for voice cloning (requires transformers>=5.3) + if HiggsAudioV2TokenizerModel is not None: + audio_tokenizer_path = os.path.join(self.model_path, "audio_tokenizer") + self.audio_tokenizer = HiggsAudioV2TokenizerModel.from_pretrained( + audio_tokenizer_path, device_map=self.device + ).eval() + logger.info("HiggsAudioV2 tokenizer loaded for voice cloning on %s", self.device) + else: + self.audio_tokenizer = None + logger.warning("Voice cloning disabled (requires transformers>=5.3.0).") + # Duration estimator self.duration_estimator = RuleDurationEstimator() @@ -91,20 +110,46 @@ def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): self.class_temperature = self.config.class_temperature self.sample_rate = self.config.sample_rate + def _encode_ref_audio(self, audio_signal: torch.Tensor, sr: int) -> torch.Tensor: + """Encode reference audio to 8-codebook tokens for voice cloning.""" + if self.audio_tokenizer is None: + raise RuntimeError("Audio tokenizer not available for voice cloning") + if audio_signal.dim() == 1: + audio_signal = audio_signal.unsqueeze(0) + # Resample to tokenizer's expected sample rate + target_sr = self.audio_tokenizer.config.sample_rate + if sr != target_sr: + audio_signal = torchaudio.functional.resample(audio_signal, sr, target_sr) + # Ensure mono [B, 1, samples] + if audio_signal.dim() == 2: + audio_signal = audio_signal.unsqueeze(1) + with torch.inference_mode(): + tokens = self.audio_tokenizer.encode( + audio_signal.to(self.audio_tokenizer.device), return_dict=False + ) # [B, 8, T_ref] + tokens = tokens.squeeze(0) # [8, T_ref] + return tokens + @torch.inference_mode() def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: - """Generate speech audio from text. - - Args: - req: Diffusion request containing text prompt(s). + """Generate speech audio from text, optionally with voice cloning. - Returns: - DiffusionOutput with audio tensor in .output + Accepts either a plain text prompt or a structured dict: + {"text": "...", "ref_audio": (samples, sr), "ref_text": "...", + "lang": "...", "instruct": "..."} """ - # Extract text from request prompt = req.prompts[0] if req.prompts else "" + ref_audio = None + ref_text = None + lang = "None" + instruct = "None" + if isinstance(prompt, dict): text = prompt.get("input", prompt.get("text", str(prompt))) + ref_audio = prompt.get("ref_audio") + ref_text = prompt.get("ref_text") + lang = prompt.get("lang") or "None" + instruct = prompt.get("instruct") or "None" else: text = str(prompt) @@ -119,17 +164,37 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: target_len = self.duration_estimator.estimate_duration(text, "Nice to meet you.", 25) target_len = max(1, int(target_len)) - # Tokenize with control tokens - style = "<|denoise|><|lang_start|>None<|lang_end|><|instruct_start|>None<|instruct_end|>" - full_prompt = f"{style}<|text_start|>{text}<|text_end|>" + # Build text prompt with control tokens + style = f"<|denoise|><|lang_start|>{lang}<|lang_end|><|instruct_start|>{instruct}<|instruct_end|>" + if ref_text: + full_text = f"{ref_text} {text}" + else: + full_text = text + full_prompt = f"{style}<|text_start|>{full_text}<|text_end|>" encoding = self.tokenizer.encode(full_prompt) text_tokens = torch.tensor(encoding.ids, dtype=torch.long, device=device) text_len = text_tokens.shape[0] + # Encode reference audio tokens if provided + ref_audio_tokens = None + if ref_audio is not None: + if self.audio_tokenizer is None: + raise RuntimeError( + "Voice cloning requires transformers>=5.3.0. Try: uv pip install 'transformers>=5.3.0'" + ) + audio_signal, sr = ref_audio + if isinstance(audio_signal, np.ndarray): + audio_signal = torch.from_numpy(audio_signal).float() + ref_audio_tokens = self._encode_ref_audio(audio_signal, int(sr)).to(device) + # Build conditional + unconditional batches [2, 8, max_len] text_ids = text_tokens.unsqueeze(0).repeat(num_cb, 1) target_ids = torch.full((num_cb, target_len), mask_id, dtype=torch.long, device=device) - cond_ids = torch.cat([text_ids, target_ids], dim=1) + + if ref_audio_tokens is not None: + cond_ids = torch.cat([text_ids, ref_audio_tokens, target_ids], dim=1) + else: + cond_ids = torch.cat([text_ids, target_ids], dim=1) cond_len = cond_ids.shape[1] uncond_ids = target_ids.clone() diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py index 52944d50824..a95fa695156 100644 --- a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -1024,11 +1024,15 @@ async def _resolve_ref_audio(self, ref_audio_str: str) -> tuple[list[float], int URLs, ``data:`` base64 URIs, and ``file:`` local paths (the latter gated by ``--allowed-local-media-path``). """ - model_config = self.model_config - connector = MediaConnector( - allowed_local_media_path=model_config.allowed_local_media_path, - allowed_media_domains=model_config.allowed_media_domains, - ) + # In diffusion mode, model_config may not be available + if self._diffusion_mode: + connector = MediaConnector() + else: + model_config = self.model_config + connector = MediaConnector( + allowed_local_media_path=model_config.allowed_local_media_path, + allowed_media_domains=model_config.allowed_media_domains, + ) wav_np, sr = await connector.fetch_audio_async(ref_audio_str) wav_np = np.asarray(wav_np, dtype=np.float32) if wav_np.ndim > 1: @@ -1399,8 +1403,33 @@ async def _prepare_speech_generation( prompt = await self._build_fish_speech_prompt_async(request, ref_audio_data=ref_audio_data) tts_params = {} elif self._tts_model_type == "omnivoice": + if not request.input or not request.input.strip(): + raise ValueError("Input text cannot be empty") tts_params = {} - prompt = request.input # Diffusion engine takes raw text + prompt: dict[str, Any] = {"input": request.input} + # Resolve ref_audio: explicit request param or uploaded voice + ref_src = request.ref_audio + if not ref_src and request.voice: + vl = request.voice.lower() + if vl in self.uploaded_speakers: + sp = self.uploaded_speakers[vl] + if sp.get("embedding_source") == "audio": + ref_src = self._get_uploaded_audio_data(request.voice) + if not ref_src: + raise ValueError(f"Audio for voice '{request.voice}' missing") + prompt["ref_text"] = sp.get("ref_text") + if ref_src: + fmt_err = self._validate_ref_audio_format(ref_src) + if fmt_err: + raise ValueError(fmt_err) + wav, sr = await self._resolve_ref_audio(ref_src) + prompt["ref_audio"] = (np.asarray(wav, dtype=np.float32), sr) + if request.ref_text: + prompt["ref_text"] = request.ref_text + if request.language: + prompt["lang"] = request.language + if request.instructions: + prompt["instruct"] = request.instructions elif self._is_tts: validation_error = self._validate_tts_request(request) if validation_error: @@ -1567,13 +1596,26 @@ async def _create_diffusion_speech( from vllm_omni.outputs import OmniRequestOutput try: + if not request.input or not request.input.strip(): + raise ValueError("Input text cannot be empty") + request_id = f"speech-{random_uuid()}" - prompt = request.input + prompt: dict[str, Any] = {"input": request.input} + if request.ref_audio: + wav, sr = await self._resolve_ref_audio(request.ref_audio) + prompt["ref_audio"] = (np.asarray(wav, dtype=np.float32), sr) + if request.ref_text: + prompt["ref_text"] = request.ref_text + if request.language: + prompt["lang"] = request.language + if request.instructions: + prompt["instruct"] = request.instructions logger.info( - "Diffusion TTS speech request %s: text=%r", + "Diffusion TTS speech request %s: text=%r, voice_clone=%s", request_id, - prompt[:50] + "..." if len(prompt) > 50 else prompt, + request.input[:50] + "..." if len(request.input) > 50 else request.input, + "ref_audio" in prompt, ) generator = self._diffusion_engine.generate( diff --git a/vllm_omni/model_executor/models/omnivoice/omnivoice.py b/vllm_omni/model_executor/models/omnivoice/omnivoice.py index a3603a3c398..7fde8f16faa 100644 --- a/vllm_omni/model_executor/models/omnivoice/omnivoice.py +++ b/vllm_omni/model_executor/models/omnivoice/omnivoice.py @@ -15,6 +15,7 @@ import numpy as np import torch import torch.nn as nn +import torchaudio from transformers.feature_extraction_utils import BatchFeature from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions @@ -77,31 +78,21 @@ def _ensure_cached_runtime_components(self, model_dir: str, config: OmniVoiceCon self.text_tokenizer = AutoTokenizer.from_pretrained(model_dir) - # Audio tokenizer for encoding reference audio + # Audio tokenizer for encoding reference audio (requires transformers>=5.3) audio_tokenizer_path = os.path.join(model_dir, "audio_tokenizer") - if os.path.isdir(audio_tokenizer_path): - try: - from transformers import ( - AutoFeatureExtractor, - HiggsAudioV2TokenizerModel, - ) - except ImportError as e: - raise ImportError( - "OmniVoice voice cloning requires transformers with " - "HiggsAudioV2TokenizerModel. Upgrade transformers or " - "use text-only mode (no reference audio)." - ) from e + try: + from transformers import ( + AutoFeatureExtractor, + HiggsAudioV2TokenizerModel, + ) self.audio_tokenizer = HiggsAudioV2TokenizerModel.from_pretrained(audio_tokenizer_path, device_map="cpu") self.feature_extractor = AutoFeatureExtractor.from_pretrained(audio_tokenizer_path) self.audio_tokenizer.eval() - else: + except ImportError: self.audio_tokenizer = None self.feature_extractor = None - logger.warning( - "audio_tokenizer not found at %s, voice cloning disabled", - audio_tokenizer_path, - ) + logger.warning("Voice cloning disabled (requires transformers>=5.3.0).") self._cached_model_dir = model_dir @@ -166,20 +157,16 @@ def _call_hf_processor( if self.feature_extractor is not None: target_sr = self.feature_extractor.sampling_rate if sr != target_sr: - import torchaudio - audio_signal = torchaudio.functional.resample(audio_signal, sr, target_sr) # Encode reference audio to 8-codebook tokens - if self.audio_tokenizer is not None: - with torch.inference_mode(): - ref_audio_tokens = self.audio_tokenizer.encode(audio_signal) # [8, T_ref] - if ref_audio_tokens.dim() == 3: - ref_audio_tokens = ref_audio_tokens.squeeze(0) # [8, T_ref] - else: - raise RuntimeError( - "Audio tokenizer not available for voice cloning. Ensure audio_tokenizer/ exists in model directory." - ) + if self.audio_tokenizer is None: + raise RuntimeError("Voice cloning requires transformers>=5.3.0. Try: uv pip install 'transformers>=5.3.0'") + + with torch.inference_mode(): + ref_audio_tokens = self.audio_tokenizer.encode(audio_signal) # [8, T_ref] + if ref_audio_tokens.dim() == 3: + ref_audio_tokens = ref_audio_tokens.squeeze(0) # [8, T_ref] ft = BatchFeature( {