fix: handle uploaded voice as ref_audio in Voxtral TTS #2790
passionworkeer wants to merge 1 commit into
Conversation
When user provides an uploaded speaker voice, resolve to reference audio and pass as ref_audio to the Voxtral tokenizer instead of as voice name. Validates that the uploaded voice is audio-backed (not embedding-only) and raises clear ValueError if reference audio is missing. Closes #2547
Pull request overview
This PR updates the OpenAI Speech serving implementation to correctly handle uploaded/custom voices for Voxtral TTS by resolving them to stored reference audio and passing that audio via ref_audio rather than treating the uploaded voice name as a built-in preset.
Changes:
- Add Voxtral-specific prompt building that resolves uploaded voices to stored reference audio and uses `SpeechRequest(..., ref_audio=...)`.
- Expand TTS model/stage detection and validation logic to branch across Voxtral/Qwen/Fish paths.
- Add new streaming and batch-generation helper paths (plus additional request validation and uploaded-voice metadata handling).
```python
for stage in self.engine_client.stage_configs:
    if stage.engine_args.model_stage in _TTS_MODEL_STAGES:
```
_find_tts_stage() assumes self.engine_client.stage_configs always exists, but existing unit tests (e.g. tests/entrypoints/openai_api/test_serving_speech.py) mock only stage_list. This will raise AttributeError and break non-omni engine clients/mocks. Consider using getattr(self.engine_client, "stage_configs", None) with a fallback to the previous stage_list shape, or guard with hasattr before iterating.
Suggested change:

```diff
-for stage in self.engine_client.stage_configs:
-    if stage.engine_args.model_stage in _TTS_MODEL_STAGES:
+stage_configs = getattr(self.engine_client, "stage_configs", None)
+if stage_configs is None:
+    stage_configs = getattr(self.engine_client, "stage_list", ())
+for stage in stage_configs or ():
+    engine_args = getattr(stage, "engine_args", None)
+    model_stage = getattr(engine_args, "model_stage", None)
+    if model_stage in _TTS_MODEL_STAGES:
```
```python
# Validate speaker_embedding constraints
if request.speaker_embedding is not None:
    if task_type != "Base":
        return "'speaker_embedding' is only valid for Base task"
    if not request.speaker_embedding:
        return "'speaker_embedding' must be a non-empty list of floats"
    # speaker_embedding implies x_vector_only_mode — set it before
    # Base task validation so callers don't need to pass it explicitly.
    request.x_vector_only_mode = True
    emb_len = len(request.speaker_embedding)
```
OpenAICreateSpeechRequest (from vllm_omni/entrypoints/openai/protocol/audio.py) does not define speaker_embedding, but this code accesses request.speaker_embedding directly. That will raise AttributeError at runtime. Either add speaker_embedding to the request model (and any API schema) or switch all accesses to getattr(request, "speaker_embedding", None) and plumb the value in via an existing field.
```python
# Handle the case where request.voice is NOT None
pass
# voice is not None
```
The else branch for task_type == "Base" contains a stray pass and redundant comment lines before the actual logic. This is dead code and makes the control flow harder to follow; it should be removed.
Suggested change:

```diff
-# Handle the case where request.voice is NOT None
-pass
-# voice is not None
```
```python
# Assert that sample rate has been set from chunk metadata (not just default)
# This ensures the WAV header contains the correct sample rate
assert sr_raw is not None, (
    "First audio chunk must include sample rate metadata for WAV streaming"
)
```
_generate_audio_chunks() uses an assert sr_raw is not None to enforce that the first chunk includes sample-rate metadata for WAV streaming. assert can be disabled with Python optimizations (-O), which would silently produce a WAV header with the default/incorrect sample rate. Prefer an explicit runtime check that raises a regular exception (e.g., ValueError) so this invariant is enforced in production.
Suggested change:

```diff
-# Assert that sample rate has been set from chunk metadata (not just default)
-# This ensures the WAV header contains the correct sample rate
-assert sr_raw is not None, (
-    "First audio chunk must include sample rate metadata for WAV streaming"
-)
+# Require sample rate metadata from the first chunk so the
+# streamed WAV header is always generated with the correct value.
+if sr_raw is None:
+    raise ValueError(
+        "First audio chunk must include sample rate metadata for WAV streaming"
+    )
```
```python
async def _iter_pcm_audio_bytes(self, request: OpenAICreateSpeechRequest):
    """Yield raw PCM bytes for a speech request as soon as chunks are decoded."""
    request_id, generator, _ = await self._prepare_speech_generation(request)
    async for chunk in self._generate_pcm_chunks(generator, request_id):
```
_iter_pcm_audio_bytes() calls self._generate_pcm_chunks(...), but the streaming helper has been renamed to _generate_audio_chunks() and _generate_pcm_chunks no longer exists. This will raise at runtime if the method is used. Update this to call _generate_audio_chunks(..., response_format="pcm") or reintroduce the old helper name as a wrapper for backward compatibility.
Suggested change:

```diff
-async for chunk in self._generate_pcm_chunks(generator, request_id):
+async for chunk in self._generate_audio_chunks(
+    generator, request_id, response_format="pcm"
+):
```
```python
if request.speaker_embedding is not None:
    # Store as plain float list (not tensor) so it survives msgspec
    # serialization through the EngineCore IPC boundary. The talker's
    # _build_prompt_embeds converts it back to a tensor on the GPU.
    params["voice_clone_prompt"] = [
        {
            "ref_spk_embedding": list(request.speaker_embedding),
        }
    ]
    # speaker_embedding implies x_vector_only_mode
    params["x_vector_only_mode"] = [True]
```
_build_tts_params() also accesses request.speaker_embedding, but OpenAICreateSpeechRequest currently has no such field (see vllm_omni/entrypoints/openai/protocol/audio.py). This will crash when building params. Either extend the request schema to include speaker_embedding or remove/guard this branch.
```python
with tempfile.NamedTemporaryFile(prefix="fish_ref_", suffix=".npy", delete=False) as f:
    np.save(f, np.asarray(wav_samples, dtype=np.float32))
    ref_audio_path = f.name

# Structured clone metadata is consumed directly by
# FishSpeechSlowARForConditionalGeneration.preprocess(), so keep these
# values as scalars instead of the list-wrapped prompt-dict convention.
additional_information = {
    "text": normalized_text,
    "ref_text": normalized_ref_text,
    "ref_audio_path": ref_audio_path,
    "ref_audio_sr": int(sr),
    "fish_structured_voice_clone": True,
}
```
Fish voice-cloning writes each request’s reference audio to a NamedTemporaryFile(..., delete=False) and returns only the path, but there is no cleanup path after generation completes. This can leak files and eventually fill disk in production. Consider passing audio in-memory (if supported), deleting the temp file in a finally after the engine has consumed it, or adding a bounded cache/cleanup mechanism for these fish_ref_*.npy files.
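The finally-based option could be sketched as below. It assumes the engine has fully consumed the file once generation finishes; the function names are illustrative, and where the real code writes with `np.save`, this dependency-free sketch packs raw float32 samples instead.

```python
import os
import struct
import tempfile

def write_ref_audio(wav_samples):
    """Persist reference samples to a fish_ref_* temp file, return its path.
    (Stand-in for the np.save call in the PR.)"""
    with tempfile.NamedTemporaryFile(
        prefix="fish_ref_", suffix=".npy", delete=False
    ) as f:
        f.write(struct.pack(f"<{len(wav_samples)}f", *wav_samples))
        return f.name

def cleanup_ref_audio(path):
    """Best-effort removal, intended for a finally block around generation."""
    try:
        os.unlink(path)
    except FileNotFoundError:
        pass
```

The caller would wrap generation in `try: ... finally: cleanup_ref_audio(path)` so the file is removed even when the engine raises.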
```python
    Streaming is supported via stream=True with response_format='pcm' or 'wav'.
    Each Code2Wav chunk is yielded as raw audio bytes as soon as it is decoded.
    For WAV format, a header with placeholder size values is emitted first.
    """
    error_check_ret = await self._check_model(request)
    if error_check_ret is not None:
        logger.error("Error with model %s", error_check_ret)
        return error_check_ret

    if self.engine_client.errored:
        raise self.engine_client.dead_error

    request_id = f"speech-{random_uuid()}"

    try:
        if self._is_tts:
            # Validate TTS parameters
            validation_error = self._validate_tts_request(request)
            if validation_error:
                return self.create_error_response(validation_error)

            tts_params = self._build_tts_params(request)
            if request.ref_audio is not None:
                wav_list, sr = await self._resolve_ref_audio(request.ref_audio)
                tts_params["ref_audio"] = [[wav_list, sr]]

            # Prompt length must match model-side embeddings; values are placeholders.
            ph_len = self._estimate_prompt_len(tts_params)
            prompt = {"prompt_token_ids": [1] * ph_len, "additional_information": tts_params}
        else:
            tts_params = {}
            prompt = {"prompt": request.input}

        logger.info(
            "TTS speech request %s: text=%r, task_type=%s",
            request_id,
            request.input[:50] + "..." if len(request.input) > 50 else request.input,
            tts_params.get("task_type", ["unknown"])[0],
        )

        sampling_params_list = self.engine_client.default_sampling_params_list
        if request.stream:
            # Determine response format and media type for streaming
            response_format = (request.response_format or "wav").lower()

            # Only pcm and wav support streaming without post-processing
            if response_format not in ["pcm", "wav"]:
                return self.create_error_response(
                    f"Streaming is only supported for 'pcm' and 'wav' formats. "
                    f"Got '{response_format}'. For other formats, use stream=False."
                )
```
This method advertises streaming WAV support (and defaults response_format to wav), but the request schema (OpenAICreateSpeechRequest.validate_streaming_constraints in vllm_omni/entrypoints/openai/protocol/audio.py) currently enforces stream=true => response_format == "pcm". As a result, the WAV streaming branch is unreachable and the docstring/error messages here will be inconsistent with validation. Either update the request validator/schema to allow wav streaming, or keep streaming limited to pcm and adjust this logic/docs accordingly.
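If the first option is taken, the relaxed rule could look like this standalone sketch. It only mirrors the intended constraint; the real project implements it as a validator on the pydantic request model in protocol/audio.py, so the function form here is an assumption for illustration.

```python
# Formats that can be streamed without post-processing, per this PR.
_STREAMABLE_FORMATS = ("pcm", "wav")

def validate_streaming_constraints(stream, response_format):
    """Return an error message, or None when the combination is allowed.
    stream=True now permits both 'pcm' and 'wav' (previously pcm-only)."""
    if not stream:
        return None
    fmt = (response_format or "wav").lower()
    if fmt in _STREAMABLE_FORMATS:
        return None
    return (
        f"Streaming is only supported for 'pcm' and 'wav' formats. "
        f"Got '{fmt}'. For other formats, use stream=False."
    )
```

Keeping the schema and the serving-layer check expressed against the same tuple avoids the two layers drifting apart again.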
```python
BatchSpeechRequest,
BatchSpeechResponse,
CreateAudio,
OpenAICreateSpeechRequest,
SpeechBatchItem,
SpeechBatchItemResult,
```
BatchSpeechRequest, BatchSpeechResponse, SpeechBatchItem, and SpeechBatchItemResult are imported from vllm_omni.entrypoints.openai.protocol.audio, but those symbols do not exist in that module (a repo-wide search only finds them in this file). This will cause an ImportError when importing serving_speech.py. Either add these models to protocol/audio.py (and any API routes) or remove these imports/implementations until the protocol layer is updated.
Suggested change:

```diff
-BatchSpeechRequest,
-BatchSpeechResponse,
-CreateAudio,
-OpenAICreateSpeechRequest,
-SpeechBatchItem,
-SpeechBatchItemResult,
+CreateAudio,
+OpenAICreateSpeechRequest,
```
lishunyang12 left a comment
Review: fix: handle uploaded voice as ref_audio in Voxtral TTS
Summary
This is a large PR (802 additions, 137 deletions) that, despite the title suggesting a focused Voxtral fix, actually introduces substantial new functionality across multiple TTS backends. The core change (resolving uploaded voices to ref_audio for Voxtral) is sound, but the PR bundles in Fish Speech integration, batch speech generation, WAV streaming, speaker embedding upload, and significant refactoring of the shared create_speech flow — all in a single commit touching one file.
What looks good
- Voxtral uploaded voice resolution (the stated fix): The logic in `_build_voxtral_prompt` correctly checks `embedding_source` and `mime_type` to validate audio-backed voices, retrieves the stored audio via `_get_uploaded_audio_data()`, strips the data URI prefix, and passes it as `ref_audio` to `SpeechRequest`. This is well-structured and handles error cases clearly.
- Audio duration validation: Adding `_REF_AUDIO_MIN_DURATION`/`_REF_AUDIO_MAX_DURATION` checks in both `upload_voice` and `_resolve_ref_audio` is a good defensive measure.
- `_validate_ref_audio_format` extraction: De-duplicating the ref_audio format validation into a shared helper method is a clean improvement.
- Error messages are user-friendly and actionable throughout.
Issues and concerns
1. Scope creep — PR does far more than the title claims
The title says "fix: handle uploaded voice as ref_audio in Voxtral TTS" but this PR also adds:
- Full Fish Speech TTS integration (`_build_fish_speech_prompt`, `_validate_fish_tts_request`, `_estimate_fish_prompt_len`, `_estimate_fish_ref_code_len`)
- Batch speech generation (`create_speech_batch`, `_merge_batch_item`, `BatchSpeechRequest`/`BatchSpeechResponse` imports)
- WAV streaming support (`_create_wav_header`, WAV header emission in `_generate_audio_chunks`)
- Speaker embedding upload (`upload_voice_embedding`)
- `_generate_audio_bytes` non-streaming refactor
- `_prepare_speech_generation` orchestration refactor
This makes the PR very hard to review properly. Consider splitting into focused PRs.
2. assert used for runtime validation in _build_voxtral_prompt
```python
assert voice or ref_audio, "Either voice or ref_audio must be provided"
```

`assert` statements are stripped when Python runs with `-O` (optimized mode). This should be a `ValueError` or similar explicit check, especially since this is reachable from user input after validation; it is a defense-in-depth concern.
3. Bare pass before code in _validate_qwen_tts_request
```python
else:
    # Handle the case where request.voice is NOT None
    pass
    # voice is not None
    voice_lower = request.voice.lower()
```

The `pass` followed by executable code on the next line is confusing. The `pass` is a no-op here and the comment is redundant; the branch should just contain the code directly.
4. Temp file leak in _build_fish_speech_prompt
```python
with tempfile.NamedTemporaryFile(prefix="fish_ref_", suffix=".npy", delete=False) as f:
    np.save(f, np.asarray(wav_samples, dtype=np.float32))
    ref_audio_path = f.name
```

`delete=False` means these temp files persist indefinitely, and there is no visible cleanup mechanism. Under sustained voice-cloning load, this will leak disk space. Consider registering cleanup in a `finally` block or using a managed temp directory.
5. _tts_tokenizer attribute used but never initialized
In _build_voxtral_prompt, self._tts_tokenizer is checked and lazily initialized, but I don't see where it's initially set to None in __init__. If another code path accesses it before _build_voxtral_prompt runs, this will raise AttributeError. (This may be set elsewhere in the class not visible in the diff — worth confirming.)
6. Mixed torch/numpy in _generate_audio_bytes for cumulative mode
```python
if async_chunk:
    audio_tensor = torch.cat(non_empty_chunks, dim=-1) if non_empty_chunks else np.zeros(...)
else:
    audio_tensor = np.zeros(...)
    for candidate in reversed(audio_history):
        if candidate.numel() > 0:
            audio_tensor = candidate
            break
```

The fallback for `async_chunk=True` uses `np.zeros` when there are no chunks, but later the code calls `.float().detach().cpu().numpy()`, which would fail on an ndarray. The non-async path also calls `.numel()`, which is a torch method, not numpy. This suggests the code assumes torch tensors while the fallback creates numpy arrays: a latent type-mismatch bug.
7. _generate_audio_chunks — assert for sample rate
```python
assert sr_raw is not None, "First audio chunk must include sample rate metadata for WAV streaming"
```

Same concern as issue 2: this is runtime validation using `assert`. If the engine ever returns a chunk without `sr`, this would silently pass in optimized mode and produce a WAV header with a possibly wrong default sample rate (24000).
Nits
- The `import torch` at top level adds startup latency even when serving non-TTS models. Consider a lazy import.
- `_batch_max_items` defaults to `getattr(self.engine_client, "tts_batch_max_items", 32)`; this attribute doesn't appear to be documented or set anywhere in the visible code.
Verdict
The core Voxtral uploaded-voice fix is correct and well-implemented. However, the PR bundles too many unrelated features into a single 800-line diff, making thorough review difficult. I recommend splitting the Fish Speech, batch, WAV streaming, and embedding upload features into separate PRs. The issues noted above (assert usage, temp file leaks, torch/numpy type mismatch) should be addressed.
Summary
When a user provides an uploaded speaker voice to Voxtral TTS, resolve it to its reference audio and pass it as the `ref_audio` parameter instead of as a voice name. Also validate that the uploaded voice is audio-backed (not embedding-only) and raise clear errors if reference audio is missing.

Changes
- Resolve the uploaded voice from the `uploaded_speakers` dict
- Validate `embedding_source == "audio"` or `mime_type.startswith("audio/")`; raise `ValueError` if embedding-only
- Call `self._get_uploaded_audio_data()` to retrieve reference audio; raise `ValueError` if missing
- Build `SpeechRequest(input=text, ref_audio=ref_audio)` instead of `voice=...`

Closes #2547