From d7852a6947c56a69662bda2a2811a3ef403a1c6a Mon Sep 17 00:00:00 2001 From: Jianjun Wang <2089966424@qq.com> Date: Tue, 14 Apr 2026 18:05:46 +0800 Subject: [PATCH] fix: handle uploaded voice as ref_audio in Voxtral TTS When user provides an uploaded speaker voice, resolve to reference audio and pass as ref_audio to the Voxtral tokenizer instead of as voice name. Validates that the uploaded voice is audio-backed (not embedding-only) and raises clear ValueError if reference audio is missing. Closes #2547 --- .../entrypoints/openai/serving_speech.py | 939 +++++++++++++++--- 1 file changed, 802 insertions(+), 137 deletions(-) diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py index 7bcf75ace9d..c799587b194 100644 --- a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -1,17 +1,23 @@ import asyncio import base64 +import io import json import math import os import re +import struct +import tempfile import time from pathlib import Path from typing import Any import numpy as np +import soundfile as sf +import torch from fastapi import Request, UploadFile from fastapi.responses import Response, StreamingResponse from transformers.utils.hub import cached_file +from vllm.entrypoints.openai.engine.protocol import ErrorResponse from vllm.entrypoints.openai.engine.serving import OpenAIServing from vllm.logger import init_logger from vllm.multimodal.media import MediaConnector @@ -20,15 +26,28 @@ from vllm_omni.entrypoints.openai.audio_utils_mixin import AudioMixin from vllm_omni.entrypoints.openai.metadata_manager import MetadataManager from vllm_omni.entrypoints.openai.protocol.audio import ( + AudioResponse, + BatchSpeechRequest, + BatchSpeechResponse, CreateAudio, OpenAICreateSpeechRequest, + SpeechBatchItem, + SpeechBatchItemResult, +) +from vllm_omni.model_executor.models.fish_speech.prompt_utils import ( + build_fish_text_only_prompt_ids, + estimate_fish_voice_clone_prompt_len_from_normalized, + normalize_fish_voice_clone_texts, ) from vllm_omni.outputs import OmniRequestOutput logger = init_logger(__name__) -# TTS Configuration (currently supports Qwen3-TTS) -_TTS_MODEL_STAGES: set[str] = {"qwen3_tts"} +# TTS Configuration +_VOXTRAL_TTS_MODEL_STAGES = {"audio_generation"} +_QWEN3_TTS_MODEL_STAGES = {"qwen3_tts"} +_FISH_TTS_MODEL_STAGES = {"fish_speech_slow_ar"} +_TTS_MODEL_STAGES: set[str] = _VOXTRAL_TTS_MODEL_STAGES | _QWEN3_TTS_MODEL_STAGES | _FISH_TTS_MODEL_STAGES _TTS_LANGUAGES: set[str] = { "Auto", "Chinese", @@ -42,11 +61,54 @@ "Spanish", "Italian", } +_REF_AUDIO_MIN_DURATION = 1.0 # seconds +_REF_AUDIO_MAX_DURATION = 30.0 # seconds _TTS_MAX_INSTRUCTIONS_LENGTH = 500 _TTS_MAX_NEW_TOKENS_MIN = 1 _TTS_MAX_NEW_TOKENS_MAX = 4096 +def _create_wav_header(sample_rate: int, num_channels: int = 1, bits_per_sample: int = 16) -> bytes: + """Create a WAV header with placeholder size values for streaming. + + Uses 0xFFFFFFFF as placeholder for data size fields, which is accepted + by most audio clients and matches OpenAI's streaming WAV implementation. + + Args: + sample_rate: Audio sample rate in Hz + num_channels: Number of audio channels (1 for mono, 2 for stereo) + bits_per_sample: Bits per sample (typically 16) + + Returns: + 44-byte WAV header as bytes + """ + byte_rate = sample_rate * num_channels * bits_per_sample // 8 + block_align = num_channels * bits_per_sample // 8 + + # Use 0xFFFFFFFF as placeholder for unknown size (streaming) + placeholder_size = 0xFFFFFFFF + + # ref https://docs.fileformat.com/audio/wav/ + header = struct.pack( + "<4sI4s4sIHHIIHH4sI", + b"RIFF", # ChunkID + placeholder_size, # ChunkSize (placeholder) + b"WAVE", # Format + b"fmt ", # Subchunk1ID + 16, # Subchunk1Size (16 for PCM) + 1, # AudioFormat (1 for PCM) + num_channels, # NumChannels + sample_rate, # SampleRate + byte_rate, # ByteRate + block_align, # BlockAlign + bits_per_sample, # BitsPerSample + b"data", # Subchunk2ID + placeholder_size, # Subchunk2Size (placeholder) + ) + + return header + + def _sanitize_filename(filename: str) -> str: """Sanitize filename to prevent path traversal attacks. @@ -97,6 +159,14 @@ def __init__(self, *args, **kwargs): # Find and cache the TTS stage (if any) during initialization self._tts_stage = self._find_tts_stage() self._is_tts = self._tts_stage is not None + self._is_fish_speech = ( + self._tts_stage is not None + and getattr(getattr(self._tts_stage, "engine_args", None), "model_stage", None) == "fish_speech_slow_ar" + ) + self._fish_speech_tokenizer = None + + # Determine TTS model type or None + self._tts_model_type = self._detect_tts_model_type() # Cache TTS configuration values (computed once, reused per request) self._max_instructions_length = self._compute_max_instructions_length() @@ -113,6 +183,9 @@ def __init__(self, *args, **kwargs): logger.info(f"Loaded {len(self.supported_speakers)} supported speakers: {sorted(self.supported_speakers)}") logger.info(f"Loaded {len(self.uploaded_speakers)} uploaded speakers") + # Batch configuration + self._batch_max_items: int = getattr(self.engine_client, "tts_batch_max_items", 32) + # Load speech tokenizer codec parameters for prompt length estimation self._codec_frame_rate: float | None = self._load_codec_frame_rate() @@ -150,15 +223,25 @@ def _load_codec_frame_rate(self) -> float | None: return None def _find_tts_stage(self): - """Find and return the TTS stage from the stage list, or None if not found.""" - stage_list = getattr(self.engine_client, "stage_list", None) - if stage_list is None: - return None - for stage in stage_list: - if getattr(stage, "model_stage", None) in _TTS_MODEL_STAGES: + """Find and return the TTS stage config, or None if not found.""" + for stage in self.engine_client.stage_configs: + if stage.engine_args.model_stage in _TTS_MODEL_STAGES: return stage return None + def _detect_tts_model_type(self) -> str | None: + """Detect TTS model type from the stage's model_stage attribute.""" + if self._tts_stage is None: + return None + model_stage = getattr(self._tts_stage.engine_args, "model_stage", None) + if model_stage in _QWEN3_TTS_MODEL_STAGES: + return "qwen3_tts" + if model_stage in _VOXTRAL_TTS_MODEL_STAGES: + return "voxtral_tts" + if model_stage in _FISH_TTS_MODEL_STAGES: + return "fish_tts" + return None + def _compute_max_instructions_length(self) -> int: """Compute max instructions length with precedence: CLI > stage config > default. @@ -181,16 +264,22 @@ def _compute_max_instructions_length(self) -> int: def _load_supported_speakers(self) -> set[str]: """Load supported speakers (case-insensitive) from the model configuration.""" try: - talker_config = self.engine_client.model_config.hf_config.talker_config + if self._tts_model_type == "voxtral_tts": + config = self.engine_client.model_config.hf_config.audio_config + else: + # Default is qwen3_tts path + config = self.engine_client.model_config.hf_config.talker_config # Check for speakers in either spk_id or speaker_id for attr_name in ["spk_id", "speaker_id"]: - speakers_dict = getattr(talker_config, attr_name, None) + if isinstance(config, dict): + speakers_dict = config.get(attr_name) + else: + speakers_dict = getattr(config, attr_name, None) if speakers_dict and isinstance(speakers_dict, dict): - # Normalize to lowercase for case-insensitive matching return {speaker.lower() for speaker in speakers_dict.keys()} - logger.warning("No speakers found in talker_config (checked spk_id and speaker_id)") + logger.warning("No speakers found in config (checked spk_id and speaker_id)") except Exception as e: logger.warning(f"Could not load speakers from model config: {e}") @@ -262,6 +351,41 @@ def _estimate_prompt_len(self, tts_params: dict[str, Any]) -> int: logger.warning("Failed to estimate TTS prompt length, using fallback 2048: %s", e) return 2048 + def _estimate_fish_ref_code_len(self, ref_audio: object) -> int | None: + """Estimate Fish Speech semantic token length from raw reference audio.""" + from vllm_omni.model_executor.models.fish_speech.dac_utils import ( + DAC_HOP_LENGTH, + DAC_SAMPLE_RATE, + ) + + if not isinstance(ref_audio, (list, tuple)) or len(ref_audio) != 2: + return None + wav, sr = ref_audio + sr = int(sr) + n_samples = len(wav) + if sr <= 0 or n_samples <= 0: + return None + resampled_len = max(1, math.ceil(n_samples * DAC_SAMPLE_RATE / sr)) + return max(1, math.ceil(resampled_len / DAC_HOP_LENGTH)) + + def _estimate_fish_prompt_len(self, text: str, ref_text: str, ref_audio: object) -> int: + """Estimate Fish Speech clone prompt length without encoding reference audio.""" + try: + from transformers import AutoTokenizer + + if self._fish_speech_tokenizer is None: + model_name = self.engine_client.model_config.model + self._fish_speech_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + + tokenizer = self._fish_speech_tokenizer + semantic_len = self._estimate_fish_ref_code_len(ref_audio) + if semantic_len is None: + raise ValueError("Failed to estimate Fish Speech semantic token length") + return estimate_fish_voice_clone_prompt_len_from_normalized(tokenizer, text, ref_text, semantic_len) + except Exception as e: + logger.warning("Failed to estimate Fish Speech prompt length, using fallback 2048: %s", e) + return 2048 + def _get_uploaded_audio_data(self, voice_name: str) -> str | None: """Get base64 encoded audio data for uploaded voice.""" voice_name_lower = voice_name.lower() @@ -292,8 +416,12 @@ def _get_uploaded_audio_data(self, voice_name: str) -> str | None: logger.error(f"Could not read audio file for voice {voice_name}: {e}") return None - async def upload_voice(self, audio_file: UploadFile, consent: str, name: str) -> dict: - """Upload a new voice sample.""" + async def upload_voice( + self, audio_file: UploadFile, consent: str, name: str, *, ref_text: str | None = None + ) -> dict: + # Normalize ref_text: treat whitespace-only as absent + if ref_text is not None: + ref_text = ref_text.strip() or None # Validate file size (max 10MB) MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB audio_file.file.seek(0, 2) # Seek to end @@ -367,10 +495,29 @@ async def upload_voice(self, audio_file: UploadFile, consent: str, name: str) -> if not _validate_path_within_directory(file_path, self.uploaded_speakers_dir): raise ValueError("Invalid file path: potential path traversal attack detected") + # Read content and validate duration before saving + content = await audio_file.read() + try: + wav_np, sr = sf.read(io.BytesIO(content)) + duration = len(wav_np) / sr if sr > 0 else 0.0 + if duration < _REF_AUDIO_MIN_DURATION: + raise ValueError( + f"Reference audio too short ({duration:.1f}s). " + f"At least {_REF_AUDIO_MIN_DURATION:.0f}s of clear speech is required." + ) + if duration > _REF_AUDIO_MAX_DURATION: + raise ValueError( + f"Reference audio too long ({duration:.1f}s). " + f"Maximum {_REF_AUDIO_MAX_DURATION:.0f}s supported — use a shorter clip." + ) + except ValueError: + raise + except Exception as e: + logger.warning("Could not validate audio duration: %s", e) + # Save audio file try: with open(file_path, "wb") as f: - content = await audio_file.read() f.write(content) except Exception as e: raise ValueError(f"Failed to save audio file: {e}") @@ -384,9 +531,11 @@ async def upload_voice(self, audio_file: UploadFile, consent: str, name: str) -> "mime_type": mime_type, "original_filename": audio_file.filename, "file_size": file_size, + "ref_text": ref_text, "cache_status": "pending", # The initial cache state is pending. "cache_file": None, # The initial cache file is empty. "cache_generated_at": None, # The initial cache generation time is empty. + "embedding_source": "audio", } # Save metadata using metadata manager (concurrency safe) @@ -406,13 +555,111 @@ async def upload_voice(self, audio_file: UploadFile, consent: str, name: str) -> logger.info(f"Uploaded new voice '{name}' with consent ID '{consent}'") # Return voice information without exposing the server file path - return { + result = { "name": name, "consent": consent, "created_at": timestamp, "mime_type": mime_type, "file_size": file_size, } + if ref_text is not None: + result["ref_text"] = ref_text + return result + + async def upload_voice_embedding(self, embedding_json: str, consent: str, name: str) -> dict: + """Upload a voice from a pre-computed speaker embedding. + + Stores the embedding as a safetensors file and marks it immediately + ready (no audio processing needed). + + Args: + embedding_json: JSON-encoded list of floats (1024 or 2048 dim). + consent: Consent recording ID. + name: Name for the new voice. + + Returns: + dict with voice information. + """ + try: + embedding = json.loads(embedding_json) + except (json.JSONDecodeError, TypeError) as exc: + raise ValueError(f"'speaker_embedding' must be valid JSON: {exc}") from exc + + if not isinstance(embedding, list) or not embedding: + raise ValueError("'speaker_embedding' must be a non-empty list of numbers") + + if not all(isinstance(x, (int, float)) for x in embedding): + raise ValueError("'speaker_embedding' must contain only numeric values") + + if not all(math.isfinite(x) for x in embedding): + raise ValueError("'speaker_embedding' values must be finite (no NaN or Inf)") + + emb_dim = len(embedding) + if emb_dim not in {1024, 2048}: + logger.warning( + "speaker_embedding has %d dimensions; expected 1024 (0.6B) or 2048 (1.7B)", + emb_dim, + ) + + voice_name_lower = name.lower() + if voice_name_lower in self.uploaded_speakers: + raise ValueError(f"Voice '{name}' already exists") + + sanitized_name = _sanitize_filename(name) + sanitized_consent = _sanitize_filename(consent) + timestamp = int(time.time()) + + # Store as safetensors for efficient loading + try: + import torch + from safetensors.torch import save_file + + tensor = torch.tensor(embedding, dtype=torch.float32) + filename = f"{sanitized_name}_{sanitized_consent}_{timestamp}.safetensors" + file_path = self.uploaded_speakers_dir / filename + + if not _validate_path_within_directory(file_path, self.uploaded_speakers_dir): + raise ValueError("Invalid file path: potential path traversal attack detected") + + save_file({"speaker_embedding": tensor}, str(file_path)) + except ImportError: + raise ValueError("safetensors and torch are required for embedding upload") + + speaker_data = { + "name": name, + "consent": consent, + "file_path": str(file_path), + "created_at": timestamp, + "mime_type": "application/x-safetensors", + "original_filename": filename, + "file_size": file_path.stat().st_size, + "cache_status": "ready", + "cache_file": str(file_path), + "cache_generated_at": timestamp, + "embedding_source": "direct", + "embedding_dim": emb_dim, + } + + success = self.metadata_manager.create_speaker(voice_name_lower, speaker_data) + if not success: + try: + file_path.unlink() + except Exception: + pass + raise ValueError(f"Failed to create metadata for voice '{name}' (possibly already exists)") + + self.uploaded_speakers[voice_name_lower] = speaker_data + self.supported_speakers.add(voice_name_lower) + + logger.info(f"Uploaded voice '{name}' from speaker embedding ({emb_dim}-dim)") + + return { + "name": name, + "consent": consent, + "created_at": timestamp, + "embedding_source": "direct", + "embedding_dim": emb_dim, + } async def delete_voice(self, name: str) -> bool: """ @@ -449,16 +696,55 @@ async def delete_voice(self, name: str) -> bool: def _is_tts_model(self) -> bool: """Check if the current model is a supported TTS model.""" - stage_list = getattr(self.engine_client, "stage_list", None) - if stage_list: - for stage in stage_list: - model_stage = getattr(stage, "model_stage", None) - if model_stage in _TTS_MODEL_STAGES: - return True - return False + return any(stage.engine_args.model_stage in _TTS_MODEL_STAGES for stage in self.engine_client.stage_configs) def _validate_tts_request(self, request: OpenAICreateSpeechRequest) -> str | None: """Validate TTS request parameters. Returns error message or None.""" + if self._tts_model_type == "voxtral_tts": + return self._validate_voxtral_tts_request(request) + if self._tts_model_type == "fish_tts": + return self._validate_fish_tts_request(request) + return self._validate_qwen_tts_request(request) + + def _validate_ref_audio_format(self, ref_audio: str) -> str | None: + """Validate ref_audio is a supported URI format. Returns error or None.""" + if not ( + ref_audio.startswith(("http://", "https://")) + or ref_audio.startswith("data:") + or ref_audio.startswith("file://") + ): + return "ref_audio must be a URL (http/https), base64 data URL (data:...), or file URI (file://...)" + return None + + def _validate_voxtral_tts_request(self, request: OpenAICreateSpeechRequest) -> str | None: + """Validate Voxtral TTS request parameters. Returns error message or None.""" + if not request.input or not request.input.strip(): + return "Input text cannot be empty" + + # Voxtral TTS requires either a preset voice or ref_audio for voice cloning. + if request.voice is None and request.ref_audio is None: + return "Either 'voice' (preset speaker) or 'ref_audio' (voice cloning) must be provided" + + if request.ref_audio is not None: + fmt_err = self._validate_ref_audio_format(request.ref_audio) + if fmt_err: + return fmt_err + + if request.voice is not None: + request.voice = request.voice.lower() + if self.supported_speakers and request.voice not in self.supported_speakers: + return f"Invalid speaker '{request.voice}'. Supported: {', '.join(sorted(self.supported_speakers))}" + + if request.max_new_tokens is not None: + if request.max_new_tokens < _TTS_MAX_NEW_TOKENS_MIN: + return f"max_new_tokens must be at least {_TTS_MAX_NEW_TOKENS_MIN}" + if request.max_new_tokens > _TTS_MAX_NEW_TOKENS_MAX: + return f"max_new_tokens cannot exceed {_TTS_MAX_NEW_TOKENS_MAX}" + + return None + + def _validate_qwen_tts_request(self, request: OpenAICreateSpeechRequest) -> str | None: + """Validate Qwen TTS request parameters. Returns error message or None.""" # Infer Base task when ref_audio or ref_text is provided without explicit task_type. if request.task_type is None and (request.ref_audio is not None or request.ref_text is not None): request.task_type = "Base" @@ -485,29 +771,48 @@ def _validate_tts_request(self, request: OpenAICreateSpeechRequest) -> str | Non "or use a CustomVoice model." ) if request.voice is not None and request.voice not in self.supported_speakers: - return f"Invalid speaker '{request.voice}'. Supported: {', '.join(sorted(self.supported_speakers))}" - + return f"Invalid voice '{request.voice}'. Supported: {', '.join(sorted(self.supported_speakers))}" + + # Validate speaker_embedding constraints + if request.speaker_embedding is not None: + if task_type != "Base": + return "'speaker_embedding' is only valid for Base task" + if not request.speaker_embedding: + return "'speaker_embedding' must be a non-empty list of floats" + # speaker_embedding implies x_vector_only_mode — set it before + # Base task validation so callers don't need to pass it explicitly. + request.x_vector_only_mode = True + emb_len = len(request.speaker_embedding) + # ECAPA-TDNN produces 1024-dim (0.6B) or 2048-dim (1.7B) + expected_dims = {1024, 2048} + if emb_len not in expected_dims: + logger.warning( + "speaker_embedding has %d dimensions; expected 1024 " + "(0.6B model) or 2048 (1.7B model). Wrong dimensions " + "will likely result in errors or degraded quality.", + emb_len, + ) # Validate Base task requirements if task_type == "Base": if request.voice is None: - if request.ref_audio is None: - return "Base task requires 'ref_audio' for voice cloning" - # Validate ref_audio format (include file:// from upstream) - if not ( - request.ref_audio.startswith(("http://", "https://")) - or request.ref_audio.startswith("data:") - or request.ref_audio.startswith("file://") - ): - return "ref_audio must be a URL (http/https), base64 data URL (data:...), or file URI (file://...)" - # In-context voice cloning (default) requires non-empty ref_text. - # x_vector_only_mode skips in-context and only uses speaker embedding. - if not request.x_vector_only_mode: + # 1. Ensure a voice source is provided + if request.ref_audio is None and getattr(request, "speaker_embedding", None) is None: + return "Base task requires 'ref_audio' or 'speaker_embedding' for voice cloning" + # 2. Validate ref_audio format if it exists (using the helper from main) + if request.ref_audio is not None: + fmt_err = self._validate_ref_audio_format(request.ref_audio) + if fmt_err: + return fmt_err + # 3. Validate text requirements based on the mode + if not getattr(request, "x_vector_only_mode", False): if not request.ref_text or not request.ref_text.strip(): return ( "Base task requires non-empty 'ref_text' (transcript of " "the reference audio) unless 'x_vector_only_mode' is enabled" ) else: + # Handle the case where request.voice is NOT None + pass # voice is not None voice_lower = request.voice.lower() if voice_lower in self.uploaded_speakers: @@ -522,15 +827,9 @@ def _validate_tts_request(self, request: OpenAICreateSpeechRequest) -> str | Non return ( f"Base task with built-in speaker '{request.voice}' requires 'ref_audio' for voice cloning" ) - # Validate ref_audio format for built-in speaker - if not ( - request.ref_audio.startswith(("http://", "https://")) - or request.ref_audio.startswith("data:") - or request.ref_audio.startswith("file://") - ): - return ( - "ref_audio must be a URL (http/https), base64 data URL (data:...), or file URI (file://...)" - ) + fmt_err = self._validate_ref_audio_format(request.ref_audio) + if fmt_err: + return fmt_err # Validate cross-parameter dependencies if task_type != "Base": @@ -556,6 +855,26 @@ def _validate_tts_request(self, request: OpenAICreateSpeechRequest) -> str | Non return None + def _validate_fish_tts_request(self, request: OpenAICreateSpeechRequest) -> str | None: + """Validate Fish Speech request parameters. Returns error message or None.""" + if not request.input or not request.input.strip(): + return "Input text cannot be empty" + + if request.ref_audio is not None: + fmt_err = self._validate_ref_audio_format(request.ref_audio) + if fmt_err: + return fmt_err + if not request.ref_text or not request.ref_text.strip(): + return "Voice cloning requires 'ref_text' (transcript of the reference audio)" + + if request.max_new_tokens is not None: + if request.max_new_tokens < _TTS_MAX_NEW_TOKENS_MIN: + return f"max_new_tokens must be at least {_TTS_MAX_NEW_TOKENS_MIN}" + if request.max_new_tokens > _TTS_MAX_NEW_TOKENS_MAX: + return f"max_new_tokens cannot exceed {_TTS_MAX_NEW_TOKENS_MAX}" + + return None + async def _resolve_ref_audio(self, ref_audio_str: str) -> tuple[list[float], int]: """Resolve ref_audio to (wav_samples, sample_rate). @@ -572,10 +891,22 @@ async def _resolve_ref_audio(self, ref_audio_str: str) -> tuple[list[float], int wav_np = np.asarray(wav_np, dtype=np.float32) if wav_np.ndim > 1: wav_np = np.mean(wav_np, axis=-1) - return wav_np.tolist(), int(sr) + sr = int(sr) + duration = len(wav_np) / sr if sr > 0 else 0.0 + if duration < _REF_AUDIO_MIN_DURATION: + raise ValueError( + f"Reference audio too short ({duration:.1f}s). " + f"At least {_REF_AUDIO_MIN_DURATION:.0f}s of clear speech is required." + ) + if duration > _REF_AUDIO_MAX_DURATION: + raise ValueError( + f"Reference audio too long ({duration:.1f}s). " + f"Maximum {_REF_AUDIO_MAX_DURATION:.0f}s supported — use a shorter clip." + ) + return wav_np.tolist(), sr - async def _generate_pcm_chunks(self, generator, request_id: str): - """Generate PCM audio chunks for streaming response. + async def _generate_audio_chunks(self, generator, request_id: str, response_format: str = "pcm"): + """Generate audio chunks for streaming response. Handles two audio output modes from the engine: - Cumulative mode (list): Engine returns growing list of chunks; @@ -586,12 +917,15 @@ async def _generate_pcm_chunks(self, generator, request_id: str): Args: generator: Async generator from the engine request_id: Request identifier for logging + response_format: Audio format (pcm or wav) Yields: - Raw PCM bytes for each audio chunk + Raw audio bytes for each chunk (with WAV header for first chunk if wav format) """ prev_count = 0 sample_rate_val = 24000 + first_chunk = True + try: async for res in generator: audio_output, audio_key = self._extract_audio_output(res) @@ -622,6 +956,18 @@ async def _generate_pcm_chunks(self, generator, request_id: str): ) if chunk_np.ndim > 1: chunk_np = chunk_np.squeeze() + # For WAV format, emit header before first audio chunk + if response_format == "wav" and first_chunk: + # Assert that sample rate has been set from chunk metadata (not just default) + # This ensures the WAV header contains the correct sample rate + assert sr_raw is not None, ( + "First audio chunk must include sample rate metadata for WAV streaming" + ) + wav_header = _create_wav_header(sample_rate=sample_rate_val, num_channels=1, bits_per_sample=16) + yield wav_header + first_chunk = False + + # Convert audio to PCM bytes audio_obj = CreateAudio( audio_tensor=chunk_np, sample_rate=sample_rate_val, @@ -636,6 +982,7 @@ async def _generate_pcm_chunks(self, generator, request_id: str): raise except Exception as e: logger.exception("Streaming speech generation failed for %s: %s", request_id, e) + raise @staticmethod def _extract_audio_output(res) -> tuple[dict | None, str | None]: @@ -650,7 +997,7 @@ def _extract_audio_output(res) -> tuple[dict | None, str | None]: mm = getattr(ro, "multimodal_output", None) if ro else None if not mm: return None, None - key = "audio" if "audio" in mm else None + key = "audio" if "audio" in mm else ("model_outputs" if "model_outputs" in mm else None) return mm, key def _build_tts_params(self, request: OpenAICreateSpeechRequest) -> dict[str, Any]: @@ -680,15 +1027,22 @@ def _build_tts_params(self, request: OpenAICreateSpeechRequest) -> dict[str, Any if request.voice is not None: params["speaker"] = [request.voice] - # If voice is an uploaded speaker and no ref_audio provided, auto-set it + # Uploaded voices use task_type="Base" (CustomVoice requires built-in spk_id). + # If ref_text was provided at upload time, use in-context cloning; otherwise x_vector only. if request.voice.lower() in self.uploaded_speakers and request.ref_audio is None: audio_data = self._get_uploaded_audio_data(request.voice) - if audio_data: - params["ref_audio"] = [audio_data] - params["x_vector_only_mode"] = [True] - logger.info(f"Auto-set ref_audio for uploaded voice: {request.voice}") - else: + if not audio_data: raise ValueError(f"Audio file for uploaded voice '{request.voice}' is missing or corrupted") + speaker_info = self.uploaded_speakers[request.voice.lower()] + stored_ref_text = speaker_info.get("ref_text") + params["ref_audio"] = [audio_data] + params["task_type"] = ["Base"] + if stored_ref_text: + params["ref_text"] = [stored_ref_text] + params["x_vector_only_mode"] = [False] + else: + params["x_vector_only_mode"] = [True] + logger.info("Auto-set ref_audio for uploaded voice: %s (icl=%s)", request.voice, bool(stored_ref_text)) elif params["task_type"][0] == "CustomVoice": params["speaker"] = ["Vivian"] # Default for CustomVoice @@ -702,7 +1056,18 @@ def _build_tts_params(self, request: OpenAICreateSpeechRequest) -> dict[str, Any # Voice clone: ref_audio resolved in create_speech(), not here. if request.ref_text is not None: params["ref_text"] = [request.ref_text] - if request.x_vector_only_mode is not None: + if request.speaker_embedding is not None: + # Store as plain float list (not tensor) so it survives msgspec + # serialization through the EngineCore IPC boundary. The talker's + # _build_prompt_embeds converts it back to a tensor on the GPU. + params["voice_clone_prompt"] = [ + { + "ref_spk_embedding": list(request.speaker_embedding), + } + ] + # speaker_embedding implies x_vector_only_mode + params["x_vector_only_mode"] = [True] + elif request.x_vector_only_mode is not None: params["x_vector_only_mode"] = [request.x_vector_only_mode] # Generation parameters @@ -721,6 +1086,271 @@ def _build_tts_params(self, request: OpenAICreateSpeechRequest) -> dict[str, Any return params + # ---- Voxtral TTS helpers ---- + + async def _build_voxtral_prompt(self, request: OpenAICreateSpeechRequest) -> dict[str, Any]: + """Build Voxtral TTS engine prompt from shared TTS parameters.""" + from mistral_common.protocol.speech.request import SpeechRequest + + text = request.input + voice = request.voice + ref_audio = request.ref_audio + assert voice or ref_audio, "Either voice or ref_audio must be provided" + # Strip data URI prefix — mistral_common expects raw base64 + if ref_audio is not None and isinstance(ref_audio, str) and ref_audio.startswith("data:"): + _, _, ref_audio = ref_audio.partition(",") + if self._tts_tokenizer is None: + from vllm.tokenizers import cached_tokenizer_from_config + + mistral_tokenizer = cached_tokenizer_from_config(self.engine_client.model_config) + self._tts_tokenizer = mistral_tokenizer.instruct + + if voice is not None: + # Check if it's an uploaded voice with stored reference audio + voice_lower = voice.lower() + if voice_lower in self.uploaded_speakers: + speaker_info = self.uploaded_speakers[voice_lower] + mime_type = speaker_info.get("mime_type", "audio/wav") + embedding_source = speaker_info.get("embedding_source") + is_audio_backed = embedding_source == "audio" or mime_type.startswith("audio/") + + if not is_audio_backed: + raise ValueError( + f"Uploaded voice '{voice}' is embedding-only and cannot be used as Voxtral " + "reference audio. Please provide an audio-backed uploaded voice or pass ref_audio." + ) + + # Get reference audio from stored file + ref_audio_data = self._get_uploaded_audio_data(voice) + if ref_audio_data is None: + raise ValueError( + f"Reference audio for uploaded voice '{voice}' is missing or unreadable. " + "Please re-upload the voice sample and try again." + ) + # Strip data URI prefix + _, _, ref_audio = ref_audio_data.partition(",") + else: + # Built-in voice name — pass directly to tokenizer + tokens = self._tts_tokenizer.encode_speech_request(SpeechRequest(input=text, voice=voice)).tokens + return { + "prompt_token_ids": tokens, + "additional_information": {"voice": [voice]}, + } + + # Use ref_audio (either from request or resolved from uploaded voice) + tokenized = self._tts_tokenizer.encode_speech_request(SpeechRequest(input=text, ref_audio=ref_audio)) + audio = tokenized.audios[0] + return { + "prompt_token_ids": tokenized.tokens, + "multi_modal_data": {"audio": [(audio.audio_array, audio.sampling_rate)]}, + } + + # ---- Fish Speech helpers ---- + + def _build_fish_speech_prompt( + self, + request: OpenAICreateSpeechRequest, + ref_audio_data: tuple[list[float], int] | None = None, + ) -> dict[str, Any]: + """Build prompt for Fish Speech S2 Pro. + + Without voice cloning: + <|im_start|>system\\nconvert the provided text to speech<|im_end|> + <|im_start|>user\\n{text}<|im_end|>\\n<|im_start|>assistant\\n<|voice|> + + With voice cloning (ref_audio + ref_text): + <|im_start|>system\\nconvert the provided text to speech reference to the following... + <|im_end|>\\n<|im_start|>user\\n{text}<|im_end|>\\n<|im_start|>assistant\\n<|voice|> + """ + from transformers import AutoTokenizer + + if self._fish_speech_tokenizer is None: + model_name = self.engine_client.model_config.model + self._fish_speech_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + + tokenizer = self._fish_speech_tokenizer + + if ref_audio_data is None or not request.ref_text: + prompt_ids, normalized_text = build_fish_text_only_prompt_ids(tokenizer, request.input) + + # Keep the prompt-dict metadata shape aligned with the existing text-only + # TTS entrypoints: scalar values are wrapped in single-item lists before + # EngineCore serialization. Structured clone below is different because + # model-side preprocess consumes concrete per-request scalar fields. + additional_information: dict[str, Any] = { + "text": [normalized_text], + } + if request.max_new_tokens is not None: + additional_information["max_new_tokens"] = [request.max_new_tokens] + return { + "prompt_token_ids": prompt_ids, + "additional_information": additional_information, + } + + wav_samples, sr = ref_audio_data + normalized_text, normalized_ref_text = normalize_fish_voice_clone_texts(request.input, request.ref_text) + ph_len = self._estimate_fish_prompt_len(normalized_text, normalized_ref_text, ref_audio_data) + with tempfile.NamedTemporaryFile(prefix="fish_ref_", suffix=".npy", delete=False) as f: + np.save(f, np.asarray(wav_samples, dtype=np.float32)) + ref_audio_path = f.name + + # Structured clone metadata is consumed directly by + # FishSpeechSlowARForConditionalGeneration.preprocess(), so keep these + # values as scalars instead of the list-wrapped prompt-dict convention. + additional_information = { + "text": normalized_text, + "ref_text": normalized_ref_text, + "ref_audio_path": ref_audio_path, + "ref_audio_sr": int(sr), + "fish_structured_voice_clone": True, + } + if request.max_new_tokens is not None: + additional_information["max_new_tokens"] = request.max_new_tokens + return { + "prompt_token_ids": [1] * ph_len, + "additional_information": additional_information, + } + + # ---- Common speech generation helpers ---- + + async def _prepare_speech_generation( + self, + request: OpenAICreateSpeechRequest, + ) -> tuple[str, Any, dict[str, Any]]: + if self.engine_client.errored: + raise self.engine_client.dead_error + + if self._is_fish_speech: + validation_error = self._validate_fish_tts_request(request) + if validation_error: + raise ValueError(validation_error) + ref_audio_data = None + if request.ref_audio is not None: + wav_list, sr = await self._resolve_ref_audio(request.ref_audio) + ref_audio_data = (wav_list, sr) + prompt = self._build_fish_speech_prompt(request, ref_audio_data=ref_audio_data) + tts_params = {} + elif self._is_tts: + validation_error = self._validate_tts_request(request) + if validation_error: + raise ValueError(validation_error) + + if self._tts_model_type == "voxtral_tts": + prompt = await self._build_voxtral_prompt(request) + tts_params = {} + else: + tts_params = self._build_tts_params(request) + # Resolve ref_audio (explicit or auto-set for uploaded voices) + # to [[wav_list, sr]] so the model doesn't re-decode base64. + ref_audio_source = request.ref_audio + if ref_audio_source is None and isinstance(tts_params.get("ref_audio"), list): + # Uploaded voice: ref_audio was auto-set as [base64_data_url] + ref_audio_source = tts_params["ref_audio"][0] + if ref_audio_source is not None and isinstance(ref_audio_source, str): + wav_list, sr = await self._resolve_ref_audio(ref_audio_source) + tts_params["ref_audio"] = [[wav_list, sr]] + + ph_len = self._estimate_prompt_len(tts_params) + prompt = {"prompt_token_ids": [1] * ph_len, "additional_information": tts_params} + else: + tts_params = {} + prompt = {"prompt": request.input} + + request_id = f"speech-{random_uuid()}" + if self._is_fish_speech: + model_type = "fish_speech" + elif self._tts_model_type == "voxtral_tts": + model_type = "voxtral_tts" + elif self._is_tts: + model_type = tts_params.get("task_type", ["unknown"])[0] + else: + model_type = "generic" + logger.info( + "TTS speech request %s: text=%r, model=%s", + request_id, + request.input[:50] + "..." if len(request.input) > 50 else request.input, + model_type, + ) + + sampling_params_list = self.engine_client.default_sampling_params_list + + # Fish defaults come from stage_configs YAML. Only override when the caller + # explicitly requests a different generation length. + if self._is_fish_speech and request.max_new_tokens is not None and sampling_params_list: + import copy + + sampling_params_list = copy.deepcopy(sampling_params_list) + sampling_params_list[0].max_tokens = request.max_new_tokens + + generator = self.engine_client.generate( + prompt=prompt, + request_id=request_id, + sampling_params_list=sampling_params_list, + output_modalities=["audio"], + ) + return request_id, generator, tts_params + + async def _iter_pcm_audio_bytes(self, request: OpenAICreateSpeechRequest): + """Yield raw PCM bytes for a speech request as soon as chunks are decoded.""" + request_id, generator, _ = await self._prepare_speech_generation(request) + async for chunk in self._generate_pcm_chunks(generator, request_id): + yield chunk + + async def _generate_audio_bytes( + self, + request: OpenAICreateSpeechRequest, + base64_encode: bool = False, + ) -> tuple[bytes | str, str]: + request_id, generator, _ = await self._prepare_speech_generation(request) + + final_output: OmniRequestOutput | None = None + async for res in generator: + final_output = res + + if final_output is None: + raise ValueError("No output generated from the model.") + + audio_output, audio_key = self._extract_audio_output(final_output) + if audio_key is None: + raise ValueError("TTS model did not produce audio output.") + + audio_tensor = audio_output[audio_key] + sr_raw = audio_output.get("sr", 24000) + sr_val = sr_raw[-1] if isinstance(sr_raw, list) and sr_raw else sr_raw + sample_rate = sr_val.item() if hasattr(sr_val, "item") else int(sr_val) + + if isinstance(audio_tensor, list): + async_chunk = bool(getattr(self.engine_client.model_config, "async_chunk", False)) + if async_chunk: + non_empty_chunks = [candidate for candidate in audio_tensor if candidate.numel() > 0] + audio_tensor = ( + torch.cat(non_empty_chunks, dim=-1) if non_empty_chunks else np.zeros((0,), dtype=np.float32) + ) + else: + audio_history = audio_tensor + audio_tensor = np.zeros((0,), dtype=np.float32) + # Non-async Qwen3-TTS returns cumulative history snapshots, so keep the latest non-empty tensor. + for candidate in reversed(audio_history): + if candidate.numel() > 0: + audio_tensor = candidate + break + if hasattr(audio_tensor, "float"): + audio_tensor = audio_tensor.float().detach().cpu().numpy() + + if audio_tensor.ndim > 1: + audio_tensor = audio_tensor.squeeze() + + audio_obj = CreateAudio( + audio_tensor=audio_tensor, + sample_rate=sample_rate, + response_format=request.response_format or "wav", + speed=request.speed or 1.0, + stream_format=request.stream_format, + base64_encode=base64_encode, + ) + audio_response: AudioResponse = self.create_audio(audio_obj) + return audio_response.audio_data, audio_response.media_type + async def create_speech( self, request: OpenAICreateSpeechRequest, @@ -742,97 +1372,43 @@ async def create_speech( - ref_text: Transcript of reference audio (Base task) - x_vector_only_mode: Use speaker embedding only (Base task) - Streaming is supported via stream=True with response_format='pcm'. - Each Code2Wav chunk is yielded as raw PCM bytes as soon as it is decoded. + Streaming is supported via stream=True with response_format='pcm' or 'wav'. + Each Code2Wav chunk is yielded as raw audio bytes as soon as it is decoded. + For WAV format, a header with placeholder size values is emitted first. """ error_check_ret = await self._check_model(request) if error_check_ret is not None: logger.error("Error with model %s", error_check_ret) return error_check_ret - if self.engine_client.errored: - raise self.engine_client.dead_error - - request_id = f"speech-{random_uuid()}" - try: - if self._is_tts: - # Validate TTS parameters - validation_error = self._validate_tts_request(request) - if validation_error: - return self.create_error_response(validation_error) - - tts_params = self._build_tts_params(request) - if request.ref_audio is not None: - wav_list, sr = await self._resolve_ref_audio(request.ref_audio) - tts_params["ref_audio"] = [[wav_list, sr]] - - # Prompt length must match model-side embeddings; values are placeholders. - ph_len = self._estimate_prompt_len(tts_params) - prompt = {"prompt_token_ids": [1] * ph_len, "additional_information": tts_params} - else: - tts_params = {} - prompt = {"prompt": request.input} - - logger.info( - "TTS speech request %s: text=%r, task_type=%s", - request_id, - request.input[:50] + "..." if len(request.input) > 50 else request.input, - tts_params.get("task_type", ["unknown"])[0], - ) - - sampling_params_list = self.engine_client.default_sampling_params_list + if request.stream: + # Determine response format and media type for streaming + response_format = (request.response_format or "wav").lower() + + # Only pcm and wav support streaming without post-processing + if response_format not in ["pcm", "wav"]: + return self.create_error_response( + f"Streaming is only supported for 'pcm' and 'wav' formats. " + f"Got '{response_format}'. For other formats, use stream=False." + ) - generator = self.engine_client.generate( - prompt=prompt, - request_id=request_id, - sampling_params_list=sampling_params_list, - output_modalities=["audio"], - ) + # Check if speed adjustment is requested (not compatible with streaming) + if request.speed is not None and request.speed != 1.0: + return self.create_error_response( + "Streaming is not supported with speed adjustment. " + "Use stream=False or remove the speed parameter." + ) - if request.stream: + media_type = "audio/wav" if response_format == "wav" else "audio/pcm" + request_id, generator, _ = await self._prepare_speech_generation(request) return StreamingResponse( - self._generate_pcm_chunks(generator, request_id), - media_type="audio/pcm", + self._generate_audio_chunks(generator, request_id, response_format), + media_type=media_type, ) - # Non-streaming: collect final output - final_output: OmniRequestOutput | None = None - async for res in generator: - final_output = res - - if final_output is None: - return self.create_error_response("No output generated from the model.") - - audio_output, audio_key = self._extract_audio_output(final_output) - if audio_key is None: - return self.create_error_response("TTS model did not produce audio output.") - - audio_tensor = audio_output[audio_key] - sr_raw = audio_output.get("sr", 24000) - sr_val = sr_raw[-1] if isinstance(sr_raw, list) and sr_raw else sr_raw - sample_rate = sr_val.item() if hasattr(sr_val, "item") else int(sr_val) - - # async_chunk mode accumulates chunks as a list; concat first. - if isinstance(audio_tensor, list): - import torch - - audio_tensor = torch.cat(audio_tensor, dim=-1) - if hasattr(audio_tensor, "float"): - audio_tensor = audio_tensor.float().detach().cpu().numpy() - if audio_tensor.ndim > 1: - audio_tensor = audio_tensor.squeeze() - - audio_obj = CreateAudio( - audio_tensor=audio_tensor, - sample_rate=sample_rate, - response_format=request.response_format or "wav", - speed=request.speed or 1.0, - stream_format=request.stream_format, - base64_encode=False, - ) - audio_response = self.create_audio(audio_obj) - return Response(content=audio_response.audio_data, media_type=audio_response.media_type) + audio_bytes, media_type = await self._generate_audio_bytes(request) + return Response(content=audio_bytes, media_type=media_type) except asyncio.CancelledError: return self.create_error_response("Client disconnected") @@ -841,3 +1417,92 @@ async def create_speech( except Exception as e: logger.exception("Speech generation failed: %s", e) return self.create_error_response(f"Speech generation failed: {e}") + + @staticmethod + def _merge_batch_item( + batch: BatchSpeechRequest, + item: SpeechBatchItem, + ) -> OpenAICreateSpeechRequest: + """Merge batch-level defaults with per-item overrides into a full request.""" + + def _pick(field: str): + """Return item-level value if set, else batch-level value.""" + item_val = getattr(item, field, None) + return item_val if item_val is not None else getattr(batch, field, None) + + picked_speed = _pick("speed") + return OpenAICreateSpeechRequest( + input=item.input, + model=batch.model, + voice=_pick("voice"), + instructions=_pick("instructions"), + response_format=_pick("response_format") or "wav", + speed=picked_speed if picked_speed is not None else 1.0, + stream=False, + task_type=_pick("task_type"), + language=_pick("language"), + ref_audio=_pick("ref_audio"), + ref_text=_pick("ref_text"), + x_vector_only_mode=_pick("x_vector_only_mode"), + max_new_tokens=_pick("max_new_tokens"), + initial_codec_chunk_frames=_pick("initial_codec_chunk_frames"), + ) + + async def create_speech_batch( + self, + batch_request: BatchSpeechRequest, + ) -> BatchSpeechResponse | ErrorResponse: + """Generate speech for multiple items concurrently.""" + if len(batch_request.items) > self._batch_max_items: + raise ValueError( + f"Batch contains {len(batch_request.items)} items, exceeding the maximum of {self._batch_max_items}." + ) + + error_check_ret = await self._check_model(batch_request) + if error_check_ret is not None: + return error_check_ret + + if self.engine_client.errored: + raise self.engine_client.dead_error + + batch_id = f"speech-batch-{random_uuid()}" + + merged_requests = [self._merge_batch_item(batch_request, item) for item in batch_request.items] + + async def _run_item(idx: int, req: OpenAICreateSpeechRequest) -> SpeechBatchItemResult: + validation_error = self._validate_tts_request(req) + if validation_error is not None: + return SpeechBatchItemResult(index=idx, status="error", error=validation_error) + try: + audio_data, media_type = await self._generate_audio_bytes(req, base64_encode=True) + except Exception as e: + logger.exception("Batch item %d failed: %s", idx, e) + return SpeechBatchItemResult(index=idx, status="error", error=str(e)) + return SpeechBatchItemResult( + index=idx, + status="success", + audio_data=audio_data, + media_type=media_type, + ) + + results = await asyncio.gather( + *[_run_item(i, req) for i, req in enumerate(merged_requests)], + return_exceptions=True, + ) + + final_results: list[SpeechBatchItemResult] = [] + for i, r in enumerate(results): + if isinstance(r, BaseException): + logger.exception("Batch item %d raised unexpected exception: %s", i, r) + final_results.append(SpeechBatchItemResult(index=i, status="error", error=str(r))) + else: + final_results.append(r) + + succeeded = sum(1 for r in final_results if r.status == "success") + return BatchSpeechResponse( + id=batch_id, + results=final_results, + total=len(final_results), + succeeded=succeeded, + failed=len(final_results) - succeeded, + )