Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 117 additions & 1 deletion tests/e2e/online_serving/test_omnivoice.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,16 @@
import httpx
import pytest

from tests.conftest import OmniServerParams
from tests.conftest import OmniServerParams, generate_synthetic_audio
from tests.utils import hardware_test

try:
from transformers import HiggsAudioV2TokenizerModel # noqa: F401

_HAS_VOICE_CLONE = True
except ImportError:
_HAS_VOICE_CLONE = False

MODEL = "k2-fsa/OmniVoice"

STAGE_CONFIG = str(
Expand All @@ -40,6 +47,16 @@
MIN_AUDIO_BYTES = 5000


def _get_ref_audio_b64() -> str:
    """Build a synthetic reference clip and wrap it as a WAV data URL.

    The clip is requested from ``generate_synthetic_audio`` with
    ``duration=2``, a single channel and a 24000 Hz sample rate.

    Returns:
        A string of the form ``data:audio/wav;base64,<payload>``, where the
        payload is the ``"base64"`` entry of the generated audio.
    """
    clip = generate_synthetic_audio(duration=2, num_channels=1, sample_rate=24000)
    encoded = clip["base64"]
    return f"data:audio/wav;base64,{encoded}"


def make_speech_request(
host: str,
port: int,
Expand Down Expand Up @@ -82,3 +99,102 @@ def test_speech_auto_voice(self, omni_server) -> None:
assert len(response.content) > MIN_AUDIO_BYTES, (
f"Audio too small ({len(response.content)} bytes), expected > {MIN_AUDIO_BYTES}"
)


def make_voice_clone_request(
    host: str,
    port: int,
    text: str,
    ref_audio_b64: str,
    ref_text: str | None = None,
    timeout: float = 180.0,
) -> httpx.Response:
    """POST a voice-cloning request to ``/v1/audio/speech``.

    Args:
        host: Server host.
        port: Server port.
        text: Text to synthesize; sent as the ``input`` field.
        ref_audio_b64: Reference audio, sent as-is in the ``ref_audio`` field
            (the tests pass a base64 data URL).
        ref_text: Optional transcript of the reference audio. It is only
            included in the payload when truthy, so ``None`` and ``""`` are
            both omitted.
        timeout: Timeout handed to the ``httpx.Client``, in seconds.

    Returns:
        The ``httpx.Response`` produced by the POST; the status code is not
        checked here, callers assert on it themselves.
    """
    endpoint = f"http://{host}:{port}/v1/audio/speech"

    body = {"input": text, "ref_audio": ref_audio_b64}
    if ref_text:
        body["ref_text"] = ref_text

    # The client is scoped to this single call and closed on exit.
    with httpx.Client(timeout=timeout) as session:
        reply = session.post(endpoint, json=body)
    return reply


@pytest.mark.skipif(not _HAS_VOICE_CLONE, reason="Voice cloning requires transformers>=5.3.0")
@pytest.mark.parametrize("omni_server", TEST_PARAMS, indirect=True)
class TestOmniVoiceVoiceCloning:
    """End-to-end checks of voice cloning through ``/v1/audio/speech``.

    The whole class is skipped when ``_HAS_VOICE_CLONE`` is false, and every
    test receives the ``omni_server`` fixture parametrized indirectly with
    ``TEST_PARAMS``.
    """

    @staticmethod
    def _assert_wav_response(response) -> None:
        """Assert that ``response`` is a 200 ``audio/wav`` reply with a real payload.

        The body must pass ``verify_wav_audio`` and be larger than
        ``MIN_AUDIO_BYTES``.
        """
        assert response.status_code == 200, f"Request failed: {response.text}"
        assert response.headers.get("content-type") == "audio/wav"
        assert verify_wav_audio(response.content), "Response is not valid WAV audio"
        assert len(response.content) > MIN_AUDIO_BYTES, (
            f"Audio too small ({len(response.content)} bytes), expected > {MIN_AUDIO_BYTES}"
        )

    @pytest.mark.core_model
    @pytest.mark.omni
    @hardware_test(res={"cuda": "L4"}, num_cards=1)
    def test_voice_clone_ref_audio_only(self, omni_server) -> None:
        """Clone a voice from reference audio alone (x_vector mode)."""
        response = make_voice_clone_request(
            host=omni_server.host,
            port=omni_server.port,
            text="Hello, this is a voice cloning test.",
            ref_audio_b64=_get_ref_audio_b64(),
        )

        self._assert_wav_response(response)

    @pytest.mark.core_model
    @pytest.mark.omni
    @hardware_test(res={"cuda": "L4"}, num_cards=1)
    def test_voice_clone_ref_audio_and_text(self, omni_server) -> None:
        """Clone a voice from reference audio plus its transcript (in-context mode)."""
        response = make_voice_clone_request(
            host=omni_server.host,
            port=omni_server.port,
            text="Hello, this is a voice cloning test with in-context learning.",
            ref_audio_b64=_get_ref_audio_b64(),
            ref_text="This is the reference transcript.",
        )

        self._assert_wav_response(response)

    @pytest.mark.core_model
    @pytest.mark.omni
    @hardware_test(res={"cuda": "L4"}, num_cards=1)
    def test_voice_clone_invalid_ref_audio_format(self, omni_server) -> None:
        """A ``ref_audio`` value that is not a valid URI must be rejected with 400 or 422."""
        response = make_voice_clone_request(
            host=omni_server.host,
            port=omni_server.port,
            text="This should fail with invalid ref_audio.",
            ref_audio_b64="not_a_valid_uri",
        )

        status = response.status_code
        assert status in (400, 422), (
            f"Expected 400/422 for invalid ref_audio format, got {status}"
        )
87 changes: 76 additions & 11 deletions vllm_omni/diffusion/models/omnivoice/pipeline_omnivoice.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from collections.abc import Iterable
from typing import ClassVar

import numpy as np
import torch
from tokenizers import Tokenizer as HFTokenizer
from torch import nn
Expand All @@ -30,6 +31,13 @@
from vllm_omni.model_executor.models.omnivoice.omnivoice_decoder import OmniVoiceDecoder
from vllm_omni.model_executor.models.omnivoice.omnivoice_generator import OmniVoiceGenerator

try:
from transformers import HiggsAudioV2TokenizerModel
except ImportError:
HiggsAudioV2TokenizerModel = None

import torchaudio

logger = init_logger(__name__)


Expand Down Expand Up @@ -79,6 +87,17 @@ def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""):
tokenizer_path = os.path.join(self.model_path, "tokenizer.json")
self.tokenizer = HFTokenizer.from_file(tokenizer_path)

# Audio tokenizer for voice cloning (requires transformers>=5.3)
if HiggsAudioV2TokenizerModel is not None:
audio_tokenizer_path = os.path.join(self.model_path, "audio_tokenizer")
self.audio_tokenizer = HiggsAudioV2TokenizerModel.from_pretrained(
audio_tokenizer_path, device_map=self.device
).eval()
logger.info("HiggsAudioV2 tokenizer loaded for voice cloning on %s", self.device)
else:
self.audio_tokenizer = None
logger.warning("Voice cloning disabled (requires transformers>=5.3.0).")

# Duration estimator
self.duration_estimator = RuleDurationEstimator()

Expand All @@ -91,20 +110,46 @@ def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""):
self.class_temperature = self.config.class_temperature
self.sample_rate = self.config.sample_rate

def _encode_ref_audio(self, audio_signal: torch.Tensor, sr: int) -> torch.Tensor:
"""Encode reference audio to 8-codebook tokens for voice cloning."""
if self.audio_tokenizer is None:
raise RuntimeError("Audio tokenizer not available for voice cloning")
if audio_signal.dim() == 1:
audio_signal = audio_signal.unsqueeze(0)
# Resample to tokenizer's expected sample rate
target_sr = self.audio_tokenizer.config.sample_rate
if sr != target_sr:
audio_signal = torchaudio.functional.resample(audio_signal, sr, target_sr)
# Ensure mono [B, 1, samples]
if audio_signal.dim() == 2:
audio_signal = audio_signal.unsqueeze(1)
with torch.inference_mode():
tokens = self.audio_tokenizer.encode(
audio_signal.to(self.audio_tokenizer.device), return_dict=False
) # [B, 8, T_ref]
tokens = tokens.squeeze(0) # [8, T_ref]
return tokens

@torch.inference_mode()
def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput:
"""Generate speech audio from text.

Args:
req: Diffusion request containing text prompt(s).
"""Generate speech audio from text, optionally with voice cloning.

Returns:
DiffusionOutput with audio tensor in .output
Accepts either a plain text prompt or a structured dict:
{"text": "...", "ref_audio": (samples, sr), "ref_text": "...",
"lang": "...", "instruct": "..."}
"""
# Extract text from request
prompt = req.prompts[0] if req.prompts else ""
ref_audio = None
ref_text = None
lang = "None"
instruct = "None"

if isinstance(prompt, dict):
text = prompt.get("input", prompt.get("text", str(prompt)))
ref_audio = prompt.get("ref_audio")
ref_text = prompt.get("ref_text")
lang = prompt.get("lang") or "None"
instruct = prompt.get("instruct") or "None"
else:
text = str(prompt)

Expand All @@ -119,17 +164,37 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput:
target_len = self.duration_estimator.estimate_duration(text, "Nice to meet you.", 25)
target_len = max(1, int(target_len))

# Tokenize with control tokens
style = "<|denoise|><|lang_start|>None<|lang_end|><|instruct_start|>None<|instruct_end|>"
full_prompt = f"{style}<|text_start|>{text}<|text_end|>"
# Build text prompt with control tokens
style = f"<|denoise|><|lang_start|>{lang}<|lang_end|><|instruct_start|>{instruct}<|instruct_end|>"
if ref_text:
full_text = f"{ref_text} {text}"
else:
full_text = text
full_prompt = f"{style}<|text_start|>{full_text}<|text_end|>"
encoding = self.tokenizer.encode(full_prompt)
text_tokens = torch.tensor(encoding.ids, dtype=torch.long, device=device)
text_len = text_tokens.shape[0]

# Encode reference audio tokens if provided
ref_audio_tokens = None
if ref_audio is not None:
if self.audio_tokenizer is None:
raise RuntimeError(
"Voice cloning requires transformers>=5.3.0. Try: uv pip install 'transformers>=5.3.0'"
)
audio_signal, sr = ref_audio
if isinstance(audio_signal, np.ndarray):
audio_signal = torch.from_numpy(audio_signal).float()
ref_audio_tokens = self._encode_ref_audio(audio_signal, int(sr)).to(device)

# Build conditional + unconditional batches [2, 8, max_len]
text_ids = text_tokens.unsqueeze(0).repeat(num_cb, 1)
target_ids = torch.full((num_cb, target_len), mask_id, dtype=torch.long, device=device)
cond_ids = torch.cat([text_ids, target_ids], dim=1)

if ref_audio_tokens is not None:
cond_ids = torch.cat([text_ids, ref_audio_tokens, target_ids], dim=1)
else:
cond_ids = torch.cat([text_ids, target_ids], dim=1)
cond_len = cond_ids.shape[1]

uncond_ids = target_ids.clone()
Expand Down
60 changes: 51 additions & 9 deletions vllm_omni/entrypoints/openai/serving_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,11 +1024,15 @@ async def _resolve_ref_audio(self, ref_audio_str: str) -> tuple[list[float], int
URLs, ``data:`` base64 URIs, and ``file:`` local paths (the latter
gated by ``--allowed-local-media-path``).
"""
model_config = self.model_config
connector = MediaConnector(
allowed_local_media_path=model_config.allowed_local_media_path,
allowed_media_domains=model_config.allowed_media_domains,
)
# In diffusion mode, model_config may not be available
if self._diffusion_mode:
connector = MediaConnector()
else:
model_config = self.model_config
connector = MediaConnector(
allowed_local_media_path=model_config.allowed_local_media_path,
allowed_media_domains=model_config.allowed_media_domains,
)
wav_np, sr = await connector.fetch_audio_async(ref_audio_str)
wav_np = np.asarray(wav_np, dtype=np.float32)
if wav_np.ndim > 1:
Expand Down Expand Up @@ -1399,8 +1403,33 @@ async def _prepare_speech_generation(
prompt = await self._build_fish_speech_prompt_async(request, ref_audio_data=ref_audio_data)
tts_params = {}
elif self._tts_model_type == "omnivoice":
if not request.input or not request.input.strip():
raise ValueError("Input text cannot be empty")
tts_params = {}
prompt = request.input # Diffusion engine takes raw text
prompt: dict[str, Any] = {"input": request.input}
# Resolve ref_audio: explicit request param or uploaded voice
ref_src = request.ref_audio
if not ref_src and request.voice:
vl = request.voice.lower()
if vl in self.uploaded_speakers:
sp = self.uploaded_speakers[vl]
if sp.get("embedding_source") == "audio":
ref_src = self._get_uploaded_audio_data(request.voice)
if not ref_src:
raise ValueError(f"Audio for voice '{request.voice}' missing")
prompt["ref_text"] = sp.get("ref_text")
if ref_src:
fmt_err = self._validate_ref_audio_format(ref_src)
if fmt_err:
raise ValueError(fmt_err)
wav, sr = await self._resolve_ref_audio(ref_src)
prompt["ref_audio"] = (np.asarray(wav, dtype=np.float32), sr)
if request.ref_text:
prompt["ref_text"] = request.ref_text
if request.language:
prompt["lang"] = request.language
if request.instructions:
prompt["instruct"] = request.instructions
elif self._is_tts:
validation_error = self._validate_tts_request(request)
if validation_error:
Expand Down Expand Up @@ -1567,13 +1596,26 @@ async def _create_diffusion_speech(
from vllm_omni.outputs import OmniRequestOutput

try:
if not request.input or not request.input.strip():
raise ValueError("Input text cannot be empty")

request_id = f"speech-{random_uuid()}"
prompt = request.input
prompt: dict[str, Any] = {"input": request.input}
if request.ref_audio:
wav, sr = await self._resolve_ref_audio(request.ref_audio)
prompt["ref_audio"] = (np.asarray(wav, dtype=np.float32), sr)
if request.ref_text:
prompt["ref_text"] = request.ref_text
if request.language:
prompt["lang"] = request.language
if request.instructions:
prompt["instruct"] = request.instructions

logger.info(
"Diffusion TTS speech request %s: text=%r",
"Diffusion TTS speech request %s: text=%r, voice_clone=%s",
request_id,
prompt[:50] + "..." if len(prompt) > 50 else prompt,
request.input[:50] + "..." if len(request.input) > 50 else request.input,
"ref_audio" in prompt,
)

generator = self._diffusion_engine.generate(
Expand Down
Loading
Loading