-
Notifications
You must be signed in to change notification settings - Fork 1k
[Qwen3TTS][ServingSpeech] Bugfix/voice upload and add optional ref_text #2046
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
db0ef93
0accaed
0b89a26
6f3e23a
e4508b1
75638c3
b7140f3
eeb3449
746ba82
610809d
01e4134
da52be7
7473da6
5160819
1d38a1f
2fbdfae
9ef29d7
55a27c1
4347880
5cceff9
431ae0b
4c86c88
0d19a75
fa508d7
37e5cb8
d00efbb
83c33ea
8d38593
9a4552a
97695e8
f28645e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,6 @@ | ||
| import asyncio | ||
| import base64 | ||
| import io | ||
| import json | ||
| import math | ||
| import os | ||
|
|
@@ -11,6 +12,7 @@ | |
| from typing import Any | ||
|
|
||
| import numpy as np | ||
| import soundfile as sf | ||
| import torch | ||
| from fastapi import Request, UploadFile | ||
| from fastapi.responses import Response, StreamingResponse | ||
|
|
@@ -54,6 +56,8 @@ | |
| "Spanish", | ||
| "Italian", | ||
| } | ||
| _REF_AUDIO_MIN_DURATION = 1.0 # seconds | ||
| _REF_AUDIO_MAX_DURATION = 30.0 # seconds | ||
| _TTS_MAX_INSTRUCTIONS_LENGTH = 500 | ||
| _TTS_MAX_NEW_TOKENS_MIN = 1 | ||
| _TTS_MAX_NEW_TOKENS_MAX = 4096 | ||
|
|
@@ -437,8 +441,12 @@ def _get_uploaded_audio_data(self, voice_name: str) -> str | None: | |
| logger.error(f"Could not read audio file for voice {voice_name}: {e}") | ||
| return None | ||
|
|
||
| async def upload_voice(self, audio_file: UploadFile, consent: str, name: str) -> dict: | ||
| """Upload a new voice sample.""" | ||
| async def upload_voice( | ||
| self, audio_file: UploadFile, consent: str, name: str, *, ref_text: str | None = None | ||
| ) -> dict: | ||
| # Normalize ref_text: treat whitespace-only as absent | ||
| if ref_text is not None: | ||
| ref_text = ref_text.strip() or None | ||
| # Validate file size (max 10MB) | ||
| MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB | ||
| audio_file.file.seek(0, 2) # Seek to end | ||
|
|
@@ -512,10 +520,29 @@ async def upload_voice(self, audio_file: UploadFile, consent: str, name: str) -> | |
| if not _validate_path_within_directory(file_path, self.uploaded_speakers_dir): | ||
| raise ValueError("Invalid file path: potential path traversal attack detected") | ||
|
|
||
| # Read content and validate duration before saving | ||
| content = await audio_file.read() | ||
| try: | ||
| wav_np, sr = sf.read(io.BytesIO(content)) | ||
| duration = len(wav_np) / sr if sr > 0 else 0.0 | ||
| if duration < _REF_AUDIO_MIN_DURATION: | ||
| raise ValueError( | ||
| f"Reference audio too short ({duration:.1f}s). " | ||
| f"At least {_REF_AUDIO_MIN_DURATION:.0f}s of clear speech is required." | ||
| ) | ||
| if duration > _REF_AUDIO_MAX_DURATION: | ||
| raise ValueError( | ||
| f"Reference audio too long ({duration:.1f}s). " | ||
| f"Maximum {_REF_AUDIO_MAX_DURATION:.0f}s supported — use a shorter clip." | ||
| ) | ||
| except ValueError: | ||
| raise | ||
| except Exception as e: | ||
| logger.warning("Could not validate audio duration: %s", e) | ||
|
|
||
| # Save audio file | ||
| try: | ||
| with open(file_path, "wb") as f: | ||
| content = await audio_file.read() | ||
| f.write(content) | ||
| except Exception as e: | ||
| raise ValueError(f"Failed to save audio file: {e}") | ||
|
|
@@ -529,6 +556,7 @@ async def upload_voice(self, audio_file: UploadFile, consent: str, name: str) -> | |
| "mime_type": mime_type, | ||
| "original_filename": audio_file.filename, | ||
| "file_size": file_size, | ||
| "ref_text": ref_text, | ||
| "cache_status": "pending", # The initial cache state is pending. | ||
| "cache_file": None, # The initial cache file is empty. | ||
| "cache_generated_at": None, # The initial cache generation time is empty. | ||
|
|
@@ -552,13 +580,16 @@ async def upload_voice(self, audio_file: UploadFile, consent: str, name: str) -> | |
| logger.info(f"Uploaded new voice '{name}' with consent ID '{consent}'") | ||
|
|
||
| # Return voice information without exposing the server file path | ||
| return { | ||
| result = { | ||
| "name": name, | ||
| "consent": consent, | ||
| "created_at": timestamp, | ||
| "mime_type": mime_type, | ||
| "file_size": file_size, | ||
| } | ||
| if ref_text is not None: | ||
| result["ref_text"] = ref_text | ||
| return result | ||
|
|
||
| async def upload_voice_embedding(self, embedding_json: str, consent: str, name: str) -> dict: | ||
| """Upload a voice from a pre-computed speaker embedding. | ||
|
|
@@ -863,7 +894,19 @@ async def _resolve_ref_audio(self, ref_audio_str: str) -> tuple[list[float], int | |
| wav_np = np.asarray(wav_np, dtype=np.float32) | ||
| if wav_np.ndim > 1: | ||
| wav_np = np.mean(wav_np, axis=-1) | ||
| return wav_np.tolist(), int(sr) | ||
| sr = int(sr) | ||
| duration = len(wav_np) / sr if sr > 0 else 0.0 | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Duration validation only runs at generation time (
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perfect. I added a check there as well. |
||
| if duration < _REF_AUDIO_MIN_DURATION: | ||
| raise ValueError( | ||
| f"Reference audio too short ({duration:.1f}s). " | ||
| f"At least {_REF_AUDIO_MIN_DURATION:.0f}s of clear speech is required." | ||
| ) | ||
| if duration > _REF_AUDIO_MAX_DURATION: | ||
| raise ValueError( | ||
| f"Reference audio too long ({duration:.1f}s). " | ||
| f"Maximum {_REF_AUDIO_MAX_DURATION:.0f}s supported — use a shorter clip." | ||
| ) | ||
| return wav_np.tolist(), sr | ||
|
|
||
| async def _generate_audio_chunks(self, generator, request_id: str, response_format: str = "pcm"): | ||
| """Generate audio chunks for streaming response. | ||
|
|
@@ -987,15 +1030,22 @@ def _build_tts_params(self, request: OpenAICreateSpeechRequest) -> dict[str, Any | |
| if request.voice is not None: | ||
| params["speaker"] = [request.voice] | ||
|
|
||
| # If voice is an uploaded speaker and no ref_audio provided, auto-set it | ||
| # Uploaded voices use task_type="Base" (CustomVoice requires built-in spk_id). | ||
| # If ref_text was provided at upload time, use in-context cloning; otherwise x_vector only. | ||
| if request.voice.lower() in self.uploaded_speakers and request.ref_audio is None: | ||
| audio_data = self._get_uploaded_audio_data(request.voice) | ||
| if audio_data: | ||
| params["ref_audio"] = [audio_data] | ||
| params["x_vector_only_mode"] = [True] | ||
| logger.info(f"Auto-set ref_audio for uploaded voice: {request.voice}") | ||
| else: | ||
| if not audio_data: | ||
| raise ValueError(f"Audio file for uploaded voice '{request.voice}' is missing or corrupted") | ||
| speaker_info = self.uploaded_speakers[request.voice.lower()] | ||
| stored_ref_text = speaker_info.get("ref_text") | ||
| params["ref_audio"] = [audio_data] | ||
| params["task_type"] = ["Base"] | ||
| if stored_ref_text: | ||
| params["ref_text"] = [stored_ref_text] | ||
| params["x_vector_only_mode"] = [False] | ||
|
Comment on lines
+1043
to
+1045
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Useful? React with 👍 / 👎. |
||
| else: | ||
| params["x_vector_only_mode"] = [True] | ||
| logger.info("Auto-set ref_audio for uploaded voice: %s (icl=%s)", request.voice, bool(stored_ref_text)) | ||
|
|
||
| elif params["task_type"][0] == "CustomVoice": | ||
| params["speaker"] = ["Vivian"] # Default for CustomVoice | ||
|
|
@@ -1162,8 +1212,14 @@ async def _prepare_speech_generation( | |
| tts_params = {} | ||
| else: | ||
| tts_params = self._build_tts_params(request) | ||
| if request.ref_audio is not None: | ||
| wav_list, sr = await self._resolve_ref_audio(request.ref_audio) | ||
| # Resolve ref_audio (explicit or auto-set for uploaded voices) | ||
| # to [[wav_list, sr]] so the model doesn't re-decode base64. | ||
| ref_audio_source = request.ref_audio | ||
| if ref_audio_source is None and isinstance(tts_params.get("ref_audio"), list): | ||
| # Uploaded voice: ref_audio was auto-set as [base64_data_url] | ||
| ref_audio_source = tts_params["ref_audio"][0] | ||
| if ref_audio_source is not None and isinstance(ref_audio_source, str): | ||
| wav_list, sr = await self._resolve_ref_audio(ref_audio_source) | ||
| tts_params["ref_audio"] = [[wav_list, sr]] | ||
|
|
||
| ph_len = self._estimate_prompt_len(tts_params) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test doesn't assert that
ref_textwas actually stored. At minimum checkresult["voice"].get("ref_text") == "Hello world transcript"— otherwise the test passes even ifref_textis silently dropped.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added this