diff --git a/docs/serving/speech_api.md b/docs/serving/speech_api.md index a305981f3b3..ca5f7a2ac3a 100644 --- a/docs/serving/speech_api.md +++ b/docs/serving/speech_api.md @@ -125,6 +125,81 @@ Lists available voices for the loaded model. "voices": ["aiden", "dylan", "eric", "ono_anna", "ryan", "serena", "sohee", "uncle_fu", "vivian"] } ``` +``` +POST /v1/audio/voices +Content-Type: multipart/form-data +``` + +Upload a new voice sample for voice cloning in Base task TTS requests. + +**Form Parameters:** + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `audio_sample` | file | Yes | Audio file (max 10MB, supported formats: wav, mp3, flac, ogg, aac, webm, mp4) | +| `consent` | string | Yes | Consent recording ID | +| `name` | string | Yes | Name for the new voice | + +**Response Example:** + +```json +{ + "success": true, + "voice": { + "name": "custom_voice_1", + "consent": "user_consent_id", + "created_at": 1738660000, + "mime_type": "audio/wav", + "file_size": 1024000 + } +} +``` + +**Usage Example:** + +```bash +curl -X POST http://localhost:8091/v1/audio/voices \ + -F "audio_sample=@/path/to/voice_sample.wav" \ + -F "consent=user_consent_id" \ + -F "name=custom_voice_1" +``` + + +```bash +DELETE /v1/audio/voices/{name} +``` + +Delete an uploaded voice sample. + +**Path Parameters:** + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `name` | string | Yes | Name of the voice to delete | + +**Response Example:** + +```json +{ + "success": true, + "message": "Voice 'custom_voice_1' deleted successfully" +} +``` + +**Error Response (404 Not Found):** + +```json +{ + "success": false, + "error": "Voice 'unknown_voice' not found" +} +``` + +**Usage Example:** + +```bash +curl -X DELETE http://localhost:8091/v1/audio/voices/custom_voice_1 +``` ## Examples @@ -185,6 +260,25 @@ curl -X POST http://localhost:8091/v1/audio/speech \ }' --output cloned.wav ``` +upload voice +```bash +curl -X POST http://localhost:8091/v1/audio/voices \ + -F "audio_sample=@/path/to/voice_sample.wav" \ + -F "consent=user_consent_id" \ + -F "name=custom_voice_1" +``` + +use upload voice +```bash +curl -X POST http://localhost:8091/v1/audio/speech \ + -H "Content-Type: application/json" \ + -d '{ + "input": "Hello, this is a cloned voice", + "task_type": "Base", + "voice": "custom_voice_1" + }' --output cloned.wav +``` + ## Supported Models | Model | Task Type | Description | diff --git a/examples/online_serving/qwen3_tts/README.md b/examples/online_serving/qwen3_tts/README.md index 2046c650866..5088a4ca6fd 100644 --- a/examples/online_serving/qwen3_tts/README.md +++ b/examples/online_serving/qwen3_tts/README.md @@ -184,29 +184,68 @@ sudo apt install ffmpeg ## API Reference -### Endpoint - -``` -POST /v1/audio/speech -Content-Type: application/json -``` +### Voices Endpoint -This endpoint follows the [OpenAI Audio Speech API](https://platform.openai.com/docs/api-reference/audio/createSpeech) format with additional Qwen3-TTS parameters. +#### GET /v1/audio/voices -### Voices Endpoint +List all available voices/speakers from the loaded model, including both built-in model voices and uploaded custom voices. +**Response Example:** +```json +{ + "voices": ["vivian", "ryan", "custom_voice_1"], + "uploaded_voices": [ + { + "name": "custom_voice_1", + "consent": "user_consent_id", + "created_at": 1738660000, + "file_size": 1024000, + "mime_type": "audio/wav" + } + ] +} ``` -GET /v1/audio/voices -``` -Lists available voices for the loaded model: +#### POST /v1/audio/voices + +Upload a new voice sample for voice cloning in Base task TTS requests. +**Form Parameters:** +- `audio_sample` (required): Audio file (max 10MB, supported formats: wav, mp3, flac, ogg, aac, webm, mp4) +- `consent` (required): Consent recording ID +- `name` (required): Name for the new voice + +**Response Example:** ```json { - "voices": ["aiden", "dylan", "eric", "one_anna", "ryan", "serena", "sohee", "uncle_fu", "vivian"] + "success": true, + "voice": { + "name": "custom_voice_1", + "consent": "user_consent_id", + "created_at": 1738660000, + "mime_type": "audio/wav", + "file_size": 1024000 + } } ``` +**Usage Example:** +```bash +curl -X POST http://localhost:8000/v1/audio/voices \ + -F "audio_sample=@/path/to/voice_sample.wav" \ + -F "consent=user_consent_id" \ + -F "name=custom_voice_1" +``` + +### Endpoint + +``` +POST /v1/audio/speech +Content-Type: application/json +``` + +This endpoint follows the [OpenAI Audio Speech API](https://platform.openai.com/docs/api-reference/audio/createSpeech) format with additional Qwen3-TTS parameters. + ### Request Body ```json diff --git a/tests/entrypoints/openai_api/test_serving_speech.py b/tests/entrypoints/openai_api/test_serving_speech.py index 10dbd3a3b9d..8b717736583 100644 --- a/tests/entrypoints/openai_api/test_serving_speech.py +++ b/tests/entrypoints/openai_api/test_serving_speech.py @@ -1,11 +1,15 @@ # tests/entrypoints/openai/test_serving_speech.py import logging +import os from inspect import Signature, signature +from pathlib import Path +from unittest.mock import MagicMock, patch import numpy as np import pytest import torch -from fastapi import FastAPI +from fastapi import FastAPI, HTTPException, UploadFile +from fastapi.params import File, Form from fastapi.testclient import TestClient from pydantic import ValidationError from pytest_mock import MockerFixture @@ -206,10 +210,52 @@ async def awaitable_patched_create_speech(*args, **kwargs): # Add list_voices endpoint async def list_voices(): speakers = sorted(speech_server.supported_speakers) if speech_server.supported_speakers else [] - return {"voices": speakers} + uploaded_voices = [] + if hasattr(speech_server, "uploaded_speakers"): + for voice_name, info in speech_server.uploaded_speakers.items(): + uploaded_voices.append( + { + "name": info.get("name", voice_name), + "consent": info.get("consent", ""), + "created_at": info.get("created_at", 0), + "file_size": info.get("file_size", 0), + "mime_type": info.get("mime_type", ""), + } + ) + return {"voices": speakers, "uploaded_voices": uploaded_voices} app.add_api_route("/v1/audio/voices", list_voices, methods=["GET"]) + # Add upload_voice endpoint + async def upload_voice(audio_sample: UploadFile = File(...), consent: str = Form(...), name: str = Form(...)): + try: + result = await speech_server.upload_voice(audio_sample, consent, name) + return {"success": True, "voice": result} + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.exception(f"Failed to upload voice: {e}") + raise HTTPException(status_code=500, detail=f"Failed to upload voice: {str(e)}") + + app.add_api_route("/v1/audio/voices", upload_voice, methods=["POST"]) + + # Add delete_voice endpoint + async def delete_voice(name: str): + try: + success = await speech_server.delete_voice(name) + if not success: + raise HTTPException(status_code=404, detail=f"Voice '{name}' not found") + return {"success": True, "message": f"Voice '{name}' deleted successfully"} + except HTTPException: + raise + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.exception(f"Failed to delete voice '{name}': {e}") + raise HTTPException(status_code=500, detail=f"Failed to delete voice: {str(e)}") + + app.add_api_route("/v1/audio/voices/{name}", delete_voice, methods=["DELETE"]) + return app @@ -284,6 +330,149 @@ def test_list_voices_endpoint(self, client): assert response.status_code == 200 assert "voices" in response.json() + def test_upload_voice_success(self, client, tmp_path): + """Test successful voice upload.""" + # Create a mock audio file + audio_content = b"fake audio content" * 1000 # ~17KB + files = { + "audio_sample": ("test.wav", audio_content, "audio/wav"), + } + data = { + "consent": "user_consent_123", + "name": "test_voice", + } + + response = client.post("/v1/audio/voices", files=files, data=data) + assert response.status_code == 200 + result = response.json() + assert result["success"] is True + assert "voice" in result + voice_info = result["voice"] + assert voice_info["name"] == "test_voice" + assert voice_info["consent"] == "user_consent_123" + assert "created_at" in voice_info + assert voice_info["mime_type"] == "audio/wav" + assert voice_info["file_size"] == len(audio_content) + response = client.delete("/v1/audio/voices/test_voice") + + def test_upload_voice_file_too_large(self, client): + """Test voice upload with file exceeding size limit.""" + # Create a file larger than 10MB + audio_content = b"x" * (11 * 1024 * 1024) # 11MB + files = { + "audio_sample": ("test.wav", audio_content, "audio/wav"), + } + data = { + "consent": "user_consent_123", + "name": "test_voice", + } + + response = client.post("/v1/audio/voices", files=files, data=data) + assert response.status_code == 400 + result = response.json() + assert "detail" in result + assert "10MB" in result["detail"] + + def test_upload_voice_invalid_mime_type(self, client): + """Test voice upload with invalid MIME type.""" + audio_content = b"fake audio content" + files = { + "audio_sample": ("test.txt", audio_content, "text/plain"), + } + data = { + "consent": "user_consent_123", + "name": "test_voice", + } + + response = client.post("/v1/audio/voices", files=files, data=data) + assert response.status_code == 400 + result = response.json() + assert "detail" in result + assert "MIME type" in result["detail"] + + def test_upload_voice_name_collision(self, client): + """Test voice upload with duplicate name.""" + # First upload + audio_content = b"fake audio content" + files = { + "audio_sample": ("test.wav", audio_content, "audio/wav"), + } + data = { + "consent": "user_consent_123", + "name": "test_voice", + } + + response = client.post("/v1/audio/voices", files=files, data=data) + assert response.status_code == 200 + + # Second upload with same name + response = client.post("/v1/audio/voices", files=files, data=data) + assert response.status_code == 400 + result = response.json() + assert "detail" in result + assert "already exists" in result["detail"] + response = client.delete("/v1/audio/voices/test_voice") + + def test_upload_voice_missing_parameters(self, client): + """Test voice upload with missing required parameters.""" + audio_content = b"fake audio content" + files = { + "audio_sample": ("test.wav", audio_content, "audio/wav"), + } + + # Missing consent + data = {"name": "test_voice5"} + response = client.post("/v1/audio/voices", files=files, data=data) + assert response.status_code == 422 # Validation error + + # Missing name + data = {"consent": "user_consent_123"} + response = client.post("/v1/audio/voices", files=files, data=data) + assert response.status_code == 422 # Validation error + + # Missing file + data = { + "consent": "user_consent_123", + "name": "test_voice6", + } + response = client.post("/v1/audio/voices", data=data) + assert response.status_code == 422 # Validation error + + def test_delete_voice_success(self, client): + """Test successful voice deletion.""" + # First upload a voice + audio_content = b"fake audio content" + files = { + "audio_sample": ("test.wav", audio_content, "audio/wav"), + } + data = { + "consent": "user_consent_123", + "name": "test_voice7", + } + + response = client.post("/v1/audio/voices", files=files, data=data) + assert response.status_code == 200 + + # Then delete it + response = client.delete("/v1/audio/voices/test_voice7") + assert response.status_code == 200 + result = response.json() + assert result["success"] is True + assert "deleted successfully" in result["message"] + + # Verify it's gone by trying to delete again + response = client.delete("/v1/audio/voices/test_voice7") + assert response.status_code == 404 + result = response.json() + assert "not found" in result["detail"] + + def test_delete_voice_not_found(self, client): + """Test deleting a non-existent voice.""" + response = client.delete("/v1/audio/voices/nonexistent") + assert response.status_code == 404 + result = response.json() + assert "not found" in result["detail"] + class TestTTSMethods: """Unit tests for TTS validation and parameter building.""" @@ -438,6 +627,121 @@ def test_load_supported_speakers(self, mocker: MockerFixture): # Verify speakers are normalized to lowercase assert server.supported_speakers == {"ryan", "vivian", "aiden"} + def test_build_tts_params_with_uploaded_voice(self, speech_server): + """Test _build_tts_params auto-sets ref_audio for uploaded voices.""" + # Mock an uploaded speaker + speech_server.uploaded_speakers = { + "custom_voice": { + "name": "custom_voice", + "file_path": "/tmp/voice_samples/custom_voice_consent_123.wav", + "mime_type": "audio/wav", + } + } + speech_server.supported_speakers = {"ryan", "vivian", "custom_voice"} + + # Mock _get_uploaded_audio_data to return base64 data + with patch.object(speech_server, "_get_uploaded_audio_data") as mock_get_audio: + mock_get_audio.return_value = "data:audio/wav;base64,ZmFrZWF1ZGlv" + + req = OpenAICreateSpeechRequest(input="Hello", voice="custom_voice", task_type="Base") + + params = speech_server._build_tts_params(req) + + # Verify ref_audio was auto-set + assert "ref_audio" in params + assert params["ref_audio"] == ["data:audio/wav;base64,ZmFrZWF1ZGlv"] + assert "x_vector_only_mode" in params + assert params["x_vector_only_mode"] == [True] + mock_get_audio.assert_called_once_with("custom_voice") + + def test_build_tts_params_without_uploaded_voice(self, speech_server): + """Test _build_tts_params does not auto-set ref_audio for non-uploaded voices.""" + # No uploaded speakers + speech_server.uploaded_speakers = {} + speech_server.supported_speakers = {"ryan", "vivian"} + + req = OpenAICreateSpeechRequest(input="Hello", voice="ryan", task_type="Base") + + params = speech_server._build_tts_params(req) + + # Verify ref_audio was NOT auto-set + assert "ref_audio" not in params + assert "x_vector_only_mode" not in params + + def test_build_tts_params_with_explicit_ref_audio(self, speech_server): + """Test _build_tts_params uses explicit ref_audio even for uploaded voices.""" + # Mock an uploaded speaker + speech_server.uploaded_speakers = { + "custom_voice": { + "name": "custom_voice", + "file_path": "/tmp/voice_samples/custom_voice_consent_123.wav", + "mime_type": "audio/wav", + } + } + speech_server.supported_speakers = {"ryan", "vivian", "custom_voice"} + + req = OpenAICreateSpeechRequest( + input="Hello", voice="custom_voice", task_type="Base", ref_audio="data:audio/wav;base64,ZXhwbGljaXQ=" + ) + + params = speech_server._build_tts_params(req) + + # _build_tts_params should NOT auto-set ref_audio when explicit ref_audio + # is provided (request.ref_audio is not None skips the auto-set branch). + # The explicit ref_audio is resolved later in create_speech() via + # _resolve_ref_audio(), not in _build_tts_params(). + assert "ref_audio" not in params + # x_vector_only_mode should not be set when explicit ref_audio is provided + assert "x_vector_only_mode" not in params + + def test_get_uploaded_audio_data(self, speech_server): + """Test _get_uploaded_audio_data function.""" + # Mock file operations + with ( + patch("builtins.open", create=True) as mock_open, + patch("base64.b64encode") as mock_b64encode, + patch("pathlib.Path.exists") as mock_exists, + ): + mock_exists.return_value = True + mock_b64encode.return_value = b"ZmFrZWF1ZGlv" + + # Setup mock file + mock_file = MagicMock() + mock_file.read.return_value = b"fakeaudio" + mock_open.return_value.__enter__.return_value = mock_file + + # Setup uploaded speaker + speech_server.uploaded_speakers = { + "test_voice": {"name": "test_voice", "file_path": "/tmp/test.wav", "mime_type": "audio/wav"} + } + result = speech_server._get_uploaded_audio_data("test_voice") + + assert result == "data:audio/wav;base64,ZmFrZWF1ZGlv" + mock_open.assert_called_once_with(Path("/tmp/test.wav"), "rb") + mock_b64encode.assert_called_once_with(b"fakeaudio") + + def test_get_uploaded_audio_data_missing_file(self, speech_server): + """Test _get_uploaded_audio_data when file is missing.""" + with patch("pathlib.Path.exists") as mock_exists: + mock_exists.return_value = False + + # Setup uploaded speaker + speech_server.uploaded_speakers = { + "test_voice": {"name": "test_voice", "file_path": "/tmp/test.wav", "mime_type": "audio/wav"} + } + + result = speech_server._get_uploaded_audio_data("test_voice") + + assert result is None + + def test_get_uploaded_audio_data_voice_not_found(self, speech_server): + """Test _get_uploaded_audio_data when voice is not in uploaded_speakers.""" + speech_server.uploaded_speakers = {} + + result = speech_server._get_uploaded_audio_data("nonexistent") + + assert result is None + def test_max_instructions_length_default(self, speech_server): """Test default max instructions length (500) when no config provided.""" # Fixture creates server with no CLI override and no TTS stage @@ -544,6 +848,70 @@ def test_validate_instructions_length_uses_cached_value(self, mocker: MockerFixt assert "max 10 characters" in error +class TestFileValidationFunctions: + """Unit tests for file validation helper functions.""" + + def test_sanitize_filename(self): + """Test _sanitize_filename function.""" + from vllm_omni.entrypoints.openai.serving_speech import _sanitize_filename + + # Test normal filenames + assert _sanitize_filename("test.wav") == "test.wav" + assert _sanitize_filename("test-file.mp3") == "test-file.mp3" + assert _sanitize_filename("test_file.flac") == "test_file.flac" + + # Test path traversal attempts + assert _sanitize_filename("../../../etc/passwd") == "passwd" + assert _sanitize_filename("/absolute/path/file.wav") == "file.wav" + + # Test special characters + assert _sanitize_filename("file with spaces.wav") == "file_with_spaces.wav" + assert _sanitize_filename("file&with&special&chars.wav") == "file_with_special_chars.wav" + assert _sanitize_filename("file@with#special$chars%.wav") == "file_with_special_chars_.wav" + + # Test empty filename + assert _sanitize_filename("") == "file" + + # Test very long filename + long_name = "a" * 300 + sanitized = _sanitize_filename(long_name) + assert len(sanitized) == 255 + assert sanitized.startswith("a") + + def test_validate_path_within_directory(self, tmp_path): + """Test _validate_path_within_directory function.""" + from vllm_omni.entrypoints.openai.serving_speech import _validate_path_within_directory + + # Create test directory structure + base_dir = tmp_path / "uploads" + base_dir.mkdir() + + # Valid paths within directory + valid_file = base_dir / "test.wav" + valid_subdir_file = base_dir / "subdir" / "test.wav" + valid_subdir_file.parent.mkdir() + + assert _validate_path_within_directory(valid_file, base_dir) is True + assert _validate_path_within_directory(valid_subdir_file, base_dir) is True + + # Invalid paths outside directory + outside_file = tmp_path / "outside.wav" + assert _validate_path_within_directory(outside_file, base_dir) is False + + # Test with symlink (should fail) + if hasattr(os, "symlink"): + link_target = tmp_path / "target.wav" + link_target.touch() + symlink = base_dir / "link.wav" + os.symlink(link_target, symlink) + # Symlinks to outside should be rejected + assert _validate_path_within_directory(symlink, base_dir) is False + + # Test with non-existent file (should still validate path) + non_existent = base_dir / "nonexistent.wav" + assert _validate_path_within_directory(non_existent, base_dir) is True + + class TestStreamingProtocolValidation: """Unit tests for the stream field validators in OpenAICreateSpeechRequest.""" diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 5b1f2582444..2134411df43 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -881,8 +881,113 @@ async def list_voices(raw_request: Request): if handler is None: return base(raw_request).create_error_response(message="The model does not support Speech API") + # Get all speakers (both model built-in and uploaded) speakers = sorted(handler.supported_speakers) if handler.supported_speakers else [] - return JSONResponse(content={"voices": speakers}) + + # Get uploaded speakers details + uploaded_speakers = [] + if hasattr(handler, "uploaded_speakers"): + for voice_name, info in handler.uploaded_speakers.items(): + uploaded_speakers.append( + { + "name": info.get("name", voice_name), + "consent": info.get("consent", ""), + "created_at": info.get("created_at", 0), + "file_size": info.get("file_size", 0), + "mime_type": info.get("mime_type", ""), + } + ) + + return JSONResponse(content={"voices": speakers, "uploaded_voices": uploaded_speakers}) + + +@router.post( + "/v1/audio/voices", + responses={ + HTTPStatus.OK.value: {"model": dict}, + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) +async def upload_voice( + raw_request: Request, + audio_sample: UploadFile = File(...), + consent: str = Form(...), + name: str = Form(...), +): + """Upload a new voice sample for voice cloning. + + Uploads an audio file that can be used as a reference for voice cloning + in Base task TTS requests. The voice can then be referenced by name + in subsequent TTS requests. + + Args: + audio_sample: Audio file (max 10MB) + consent: Consent recording ID + name: Name for the new voice + raw_request: Raw FastAPI request + + Returns: + JSON response with voice information + """ + handler = Omnispeech(raw_request) + if handler is None: + return base(raw_request).create_error_response(message="The model does not support Speech API") + + try: + # Upload the voice + result = await handler.upload_voice(audio_sample, consent, name) + + return JSONResponse(content={"success": True, "voice": result}) + + except ValueError as e: + return base(raw_request).create_error_response(message=str(e)) + except Exception as e: + logger.exception(f"Failed to upload voice: {e}") + return base(raw_request).create_error_response(message=f"Failed to upload voice: {str(e)}") + + +@router.delete( + "/v1/audio/voices/{name}", + responses={ + HTTPStatus.OK.value: {"model": dict}, + HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) +async def delete_voice(name: str, raw_request: Request): + """Delete an uploaded voice. + + Deletes the voice sample and associated metadata. Also removes any + cached voice clone prompts for this voice. + + Args: + name: Name of the voice to delete + raw_request: Raw FastAPI request + + Returns: + JSON response indicating success or failure + """ + handler = Omnispeech(raw_request) + if handler is None: + return base(raw_request).create_error_response(message="The model does not support Speech API") + + try: + # Delete the voice + success = await handler.delete_voice(name) + if not success: + return JSONResponse( + content={"success": False, "error": f"Voice '{name}' not found"}, + status_code=HTTPStatus.NOT_FOUND.value, + ) + + return JSONResponse(content={"success": True, "message": f"Voice '{name}' deleted successfully"}) + + except ValueError as e: + return base(raw_request).create_error_response(message=str(e)) + except Exception as e: + logger.exception(f"Failed to delete voice '{name}': {e}") + return base(raw_request).create_error_response(message=f"Failed to delete voice: {str(e)}") # Health and Model endpoints for diffusion mode diff --git a/vllm_omni/entrypoints/openai/metadata_manager.py b/vllm_omni/entrypoints/openai/metadata_manager.py new file mode 100644 index 00000000000..4077aa23bcd --- /dev/null +++ b/vllm_omni/entrypoints/openai/metadata_manager.py @@ -0,0 +1,243 @@ +""" +Metadata manager for voice samples and cache information. + +Provides a unified interface for managing metadata.json with +concurrency safety and data consistency across multiple processes. +""" + +import fcntl +import json +import logging +import os +import threading +import time +from collections.abc import Callable +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + + +class MetadataManager: + """ + Manages metadata for uploaded speakers and cache information. + + Features: + 1. Single source of truth for metadata + 2. Concurrency safety with threading locks + 3. Atomic read-modify-write operations + 4. Merge updates to preserve fields from different components + """ + + def __init__(self, metadata_file: Path): + """ + Initialize the metadata manager. + + Args: + metadata_file: Path to metadata.json file + """ + self.metadata_file = metadata_file + self._lock = threading.Lock() # For intra-process concurrency + self._metadata = self._load_from_disk() + + # Create lock file for cross-process synchronization + self.lock_file = metadata_file.with_suffix(".lock") + self.lock_file.parent.mkdir(parents=True, exist_ok=True) + + def _load_from_disk(self) -> dict[str, Any]: + """Load metadata from disk.""" + if not self.metadata_file.exists(): + return {"uploaded_speakers": {}} + + try: + with open(self.metadata_file) as f: + return json.load(f) + except Exception as e: + logger.error(f"Failed to load metadata from {self.metadata_file}: {e}") + return {"uploaded_speakers": {}} + + def _save_to_disk(self, metadata: dict[str, Any]) -> bool: + """Save metadata to disk.""" + try: + self.metadata_file.parent.mkdir(parents=True, exist_ok=True) + tmp = self.metadata_file.with_suffix(".tmp") + with open(tmp, "w") as f: + json.dump(metadata, f, indent=2) + tmp.replace(self.metadata_file) + return True + except Exception as e: + logger.error(f"Failed to save metadata to {self.metadata_file}: {e}") + return False + + # ================================ + # Core fix: single flock overwrites RMW + # ================================ + def _update_with_file_lock( + self, update_fn: Callable[[dict[str, Any]], dict[str, Any] | None] + ) -> dict[str, Any] | None: + lock_fd = os.open(self.lock_file, os.O_CREAT | os.O_RDWR) + try: + fcntl.flock(lock_fd, fcntl.LOCK_EX) + + metadata = self._load_from_disk() + result = update_fn(metadata) + if result is None: + return None + + if not self._save_to_disk(metadata): + return None + + self._metadata = metadata + return result + finally: + fcntl.flock(lock_fd, fcntl.LOCK_UN) + os.close(lock_fd) + + def get_uploaded_speakers(self) -> dict[str, dict[str, Any]]: + """Get all uploaded speakers.""" + # Read directly from disk to ensure getting the latest data + metadata = self._load_from_disk() + return metadata.get("uploaded_speakers", {}).copy() + + def get_speaker(self, speaker_key: str) -> dict[str, Any] | None: + """Get specific speaker information.""" + # Read directly from disk to ensure getting the latest data + metadata = self._load_from_disk() + speakers = metadata.get("uploaded_speakers", {}) + return speakers.get(speaker_key, {}).copy() if speaker_key in speakers else None + + def update_speaker(self, speaker_key: str, updates: dict[str, Any]) -> bool: + """ + Update speaker information with merge semantics. + + Uses file locking for cross-process atomic operations. + """ + with self._lock: + + def _update(metadata: dict[str, Any]): + speakers = metadata.setdefault("uploaded_speakers", {}) + entry = speakers.get(speaker_key, {}) + entry.update(updates) + speakers[speaker_key] = entry + return True + + return self._update_with_file_lock(_update) is not None + + def create_speaker(self, speaker_key: str, speaker_data: dict[str, Any]) -> bool: + """ + Create a new speaker entry. + + Uses file locking for cross-process atomic operations. + """ + with self._lock: + + def _create(metadata: dict[str, Any]): + speakers = metadata.setdefault("uploaded_speakers", {}) + if speaker_key in speakers: + logger.warning(f"Speaker {speaker_key} already exists") + return None + speakers[speaker_key] = speaker_data + return True + + return self._update_with_file_lock(_create) is not None + + def update_cache_info(self, speaker_key: str, cache_file_path: Path, status: str = "ready") -> bool: + """ + Update cache information for a speaker. + """ + updates = { + "cache_status": status, + "cache_file": str(cache_file_path), + "cache_generated_at": time.time(), + } + return self.update_speaker(speaker_key, updates) + + def delete_speaker(self, speaker_key: str) -> dict[str, Any] | None: + """ + Delete a speaker from metadata and clean up associated files. + + Uses file locking for cross-process atomic operations. + + Args: + speaker_key: Speaker name (lowercase) + base_dir: Base directory for file validation (optional) + + Returns: + dict: Deleted speaker information if successful, None if speaker doesn't exist or error + """ + with self._lock: + + def _delete(metadata: dict[str, Any]): + speakers = metadata.get("uploaded_speakers", {}) + if speaker_key not in speakers: + logger.warning(f"Speaker {speaker_key} not found in metadata") + return None + + speaker_info = speakers.pop(speaker_key) + + # Clean up associated files + deleted_files = self._cleanup_speaker_files(speaker_info) + if deleted_files: + logger.info(f"Deleted {len(deleted_files)} files for speaker {speaker_key}: {deleted_files}") + + return speaker_info + + return self._update_with_file_lock(_delete) + + def _cleanup_speaker_files(self, speaker_info: dict[str, Any]) -> list[str]: + """ + Clean up files associated with a speaker. + + Args: + speaker_info: Speaker information dictionary + base_dir: Base directory for file validation (optional) + + Returns: + list: List of successfully deleted file paths + """ + deleted_files = [] + + # Helper function to safely delete a file + def safe_delete(file_path_str: str, description: str) -> bool: + if not file_path_str: + return False + + try: + file_path = Path(file_path_str) + + # Check if file exists + if not file_path.exists(): + logger.debug(f"{description} not found: {file_path}") + return False + + # Delete the file + file_path.unlink() + logger.info(f"Deleted {description}: {file_path}") + deleted_files.append(str(file_path)) + return True + + except Exception as e: + logger.error(f"Failed to delete {description} {file_path_str}: {e}") + return False + + # Delete audio file + audio_file = speaker_info.get("file_path") + if audio_file: + safe_delete(audio_file, "audio file") + + # Delete cache file + cache_file = speaker_info.get("cache_file") + if cache_file: + safe_delete(cache_file, "cache file") + + return deleted_files + + def reload_from_disk(self) -> bool: + """Force reload metadata from disk (useful for external changes).""" + with self._lock: + try: + self._metadata = self._load_from_disk() + return True + except Exception as e: + logger.error(f"Failed to reload metadata from disk: {e}") + return False diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py index d53e4176cfc..de9c559ddfb 100644 --- a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -1,11 +1,15 @@ import asyncio +import base64 import json import math import os +import re +import time +from pathlib import Path from typing import Any import numpy as np -from fastapi import Request +from fastapi import Request, UploadFile from fastapi.responses import Response, StreamingResponse from vllm.entrypoints.openai.engine.serving import OpenAIServing from vllm.logger import init_logger @@ -13,6 +17,7 @@ from vllm.utils import random_uuid from vllm_omni.entrypoints.openai.audio_utils_mixin import AudioMixin +from vllm_omni.entrypoints.openai.metadata_manager import MetadataManager from vllm_omni.entrypoints.openai.protocol.audio import ( CreateAudio, OpenAICreateSpeechRequest, @@ -41,9 +46,52 @@ _TTS_MAX_NEW_TOKENS_MAX = 4096 +def _sanitize_filename(filename: str) -> str: + """Sanitize filename to prevent path traversal attacks. + + Only allows alphanumeric characters, underscores, hyphens, and dots. + Replaces any other characters with underscores. + """ + # Remove any path components + filename = os.path.basename(filename) + # Replace any non-alphanumeric, underscore, hyphen, or dot with underscore + sanitized = re.sub(r"[^a-zA-Z0-9_.\-]", "_", filename) + # Ensure filename is not empty + if not sanitized: + sanitized = "file" + # Limit length to prevent potential issues + if len(sanitized) > 255: + sanitized = sanitized[:255] + return sanitized + + +def _validate_path_within_directory(file_path: Path, directory: Path) -> bool: + """Validate that file_path is within the specified directory. + + Prevents path traversal attacks by ensuring the resolved path + is within the target directory. + """ + try: + # Resolve both paths to absolute paths + file_path_resolved = file_path.resolve() + directory_resolved = directory.resolve() + # Check if file_path is within directory + return directory_resolved in file_path_resolved.parents or directory_resolved == file_path_resolved + except Exception: + return False + + class OmniOpenAIServingSpeech(OpenAIServing, AudioMixin): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + # Initialize uploaded speakers storage + speech_voice_samples_dir = os.environ.get("SPEECH_VOICE_SAMPLES", "/tmp/voice_samples") + self.uploaded_speakers_dir = Path(speech_voice_samples_dir) + self.uploaded_speakers_dir.mkdir(parents=True, exist_ok=True) + self.metadata_file = self.uploaded_speakers_dir / "metadata.json" + + # Initialize metadata manager + self.metadata_manager = MetadataManager(self.metadata_file) # Find and cache the TTS stage (if any) during initialization self._tts_stage = self._find_tts_stage() @@ -54,9 +102,16 @@ def __init__(self, *args, **kwargs): # Load supported speakers self.supported_speakers = self._load_supported_speakers() - logger.info(f"Loaded {len(self.supported_speakers)} supported speakers: {sorted(self.supported_speakers)}") + # Load uploaded speakers + self.uploaded_speakers = self.metadata_manager.get_uploaded_speakers() + + # Merge supported speakers with uploaded speakers + self.supported_speakers.update(self.uploaded_speakers.keys()) self._tts_tokenizer = None + logger.info(f"Loaded {len(self.supported_speakers)} supported speakers: {sorted(self.supported_speakers)}") + logger.info(f"Loaded {len(self.uploaded_speakers)} uploaded speakers") + # Load speech tokenizer codec parameters for prompt length estimation self._codec_frame_rate: float | None = self._load_codec_frame_rate() @@ -204,6 +259,191 @@ def _estimate_prompt_len(self, tts_params: dict[str, Any]) -> int: logger.warning("Failed to estimate TTS prompt length, using fallback 2048: %s", e) return 2048 + def _get_uploaded_audio_data(self, voice_name: str) -> str | None: + """Get base64 encoded audio data for uploaded voice.""" + voice_name_lower = voice_name.lower() + if voice_name_lower not in self.uploaded_speakers: + return None + + speaker_info = self.uploaded_speakers[voice_name_lower] + file_path = Path(speaker_info["file_path"]) + + if not file_path.exists(): + logger.warning(f"Audio file not found for voice {voice_name}: {file_path}") + return None + + try: + # Read audio file + with open(file_path, "rb") as f: + audio_bytes = f.read() + + # Encode to base64 + audio_b64 = base64.b64encode(audio_bytes).decode("utf-8") + + # Get MIME type from file extension + mime_type = speaker_info.get("mime_type", "audio/wav") + + # Return as data URL + return f"data:{mime_type};base64,{audio_b64}" + except Exception as e: + logger.error(f"Could not read audio file for voice {voice_name}: {e}") + return None + + async def upload_voice(self, audio_file: UploadFile, consent: str, name: str) -> dict: + """Upload a new voice sample.""" + # Validate file size (max 10MB) + MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB + audio_file.file.seek(0, 2) # Seek to end + file_size = audio_file.file.tell() + audio_file.file.seek(0) # Reset to beginning + + if file_size > MAX_FILE_SIZE: + raise ValueError(f"File size exceeds maximum limit of 10MB. Got {file_size} bytes.") + + # Detect MIME type from filename if content_type is generic + mime_type = audio_file.content_type + if mime_type == "application/octet-stream": + # Simple MIME type detection based on file extension + filename_lower = audio_file.filename.lower() + if filename_lower.endswith(".wav"): + mime_type = "audio/wav" + elif filename_lower.endswith((".mp3", ".mpeg")): + mime_type = "audio/mpeg" + elif filename_lower.endswith(".flac"): + mime_type = "audio/flac" + elif filename_lower.endswith(".ogg"): + mime_type = "audio/ogg" + elif filename_lower.endswith(".aac"): + mime_type = "audio/aac" + elif filename_lower.endswith(".webm"): + mime_type = "audio/webm" + elif filename_lower.endswith(".mp4"): + mime_type = "audio/mp4" + else: + mime_type = "audio/wav" # Default + + # Validate MIME type + allowed_mime_types = { + "audio/mpeg", + "audio/wav", + "audio/x-wav", + "audio/ogg", + "audio/aac", + "audio/flac", + "audio/webm", + "audio/mp4", + } + + if mime_type not in allowed_mime_types: + raise ValueError(f"Unsupported MIME type: {mime_type}. Allowed: {allowed_mime_types}") + + # Normalize voice name + voice_name_lower = name.lower() + + # Check if voice already exists + if voice_name_lower in self.uploaded_speakers: + raise ValueError(f"Voice '{name}' already exists") + + # Sanitize name and consent to prevent path traversal + sanitized_name = _sanitize_filename(name) + sanitized_consent = _sanitize_filename(consent) + + # Generate filename with sanitized inputs + timestamp = int(time.time()) + file_suffix = Path(audio_file.filename).suffix + file_ext = file_suffix[1:] if file_suffix and len(file_suffix) > 1 else "wav" + # Sanitize file extension as well + sanitized_ext = _sanitize_filename(file_ext) + if not sanitized_ext or sanitized_ext == "file": + sanitized_ext = "wav" + + filename = f"{sanitized_name}_{sanitized_consent}_{timestamp}.{sanitized_ext}" + file_path = self.uploaded_speakers_dir / filename + + # Double-check that the path is within the upload directory + if not _validate_path_within_directory(file_path, self.uploaded_speakers_dir): + raise ValueError("Invalid file path: potential path traversal attack detected") + + # Save audio file + try: + with open(file_path, "wb") as f: + content = await audio_file.read() + f.write(content) + except Exception as e: + raise ValueError(f"Failed to save audio file: {e}") + + # Create speaker data + speaker_data = { + "name": name, + "consent": consent, + "file_path": str(file_path), + "created_at": timestamp, + "mime_type": mime_type, + "original_filename": audio_file.filename, + "file_size": file_size, + "cache_status": "pending", # The initial cache state is pending. + "cache_file": None, # The initial cache file is empty. + "cache_generated_at": None, # The initial cache generation time is empty. + } + + # Save metadata using metadata manager (concurrency safe) + success = self.metadata_manager.create_speaker(voice_name_lower, speaker_data) + if not success: + # Clean up the saved file if metadata creation failed + try: + file_path.unlink() + except Exception: + pass + raise ValueError(f"Failed to create metadata for voice '{name}' (possibly already exists)") + + # Update in-memory cache + self.uploaded_speakers[voice_name_lower] = speaker_data + self.supported_speakers.add(voice_name_lower) + + logger.info(f"Uploaded new voice '{name}' with consent ID '{consent}'") + + # Return voice information without exposing the server file path + return { + "name": name, + "consent": consent, + "created_at": timestamp, + "mime_type": mime_type, + "file_size": file_size, + } + + async def delete_voice(self, name: str) -> bool: + """ + Delete an uploaded voice. + + Args: + name: Voice name to delete + + Returns: + bool: True if successful, False if voice doesn't exist + """ + voice_name_lower = name.lower() + + # Check if voice exists in memory cache + if voice_name_lower not in self.uploaded_speakers: + logger.warning(f"Voice '{name}' not found in memory cache") + return False + + # Delete from metadata manager with file cleanup + # Pass base_dir for path validation + deleted_info = self.metadata_manager.delete_speaker(voice_name_lower) + if not deleted_info: + logger.error(f"Failed to delete voice '{name}' from metadata") + return False + + # Update in-memory cache + if voice_name_lower in self.uploaded_speakers: + del self.uploaded_speakers[voice_name_lower] + if voice_name_lower in self.supported_speakers: + self.supported_speakers.remove(voice_name_lower) + + logger.info(f"Deleted voice '{name}' and associated files") + return True + def _is_tts_model(self) -> bool: """Check if the current model is a supported TTS model.""" stage_list = getattr(self.engine_client, "stage_list", None) @@ -246,23 +486,48 @@ def _validate_tts_request(self, request: OpenAICreateSpeechRequest) -> str | Non # Validate Base task requirements if task_type == "Base": - if request.ref_audio is None: - return "Base task requires 'ref_audio' for voice cloning" - # Validate ref_audio format - if not ( - request.ref_audio.startswith(("http://", "https://")) - or request.ref_audio.startswith("data:") - or request.ref_audio.startswith("file://") - ): - return "ref_audio must be a URL (http/https), base64 data URL (data:...), or file URI (file://...)" - # In-context voice cloning (default) requires non-empty ref_text. - # x_vector_only_mode skips in-context and only uses speaker embedding. - if not request.x_vector_only_mode: - if not request.ref_text or not request.ref_text.strip(): - return ( - "Base task requires non-empty 'ref_text' (transcript of " - "the reference audio) unless 'x_vector_only_mode' is enabled" - ) + if request.voice is None: + if request.ref_audio is None: + return "Base task requires 'ref_audio' for voice cloning" + # Validate ref_audio format (include file:// from upstream) + if not ( + request.ref_audio.startswith(("http://", "https://")) + or request.ref_audio.startswith("data:") + or request.ref_audio.startswith("file://") + ): + return "ref_audio must be a URL (http/https), base64 data URL (data:...), or file URI (file://...)" + # In-context voice cloning (default) requires non-empty ref_text. + # x_vector_only_mode skips in-context and only uses speaker embedding. + if not request.x_vector_only_mode: + if not request.ref_text or not request.ref_text.strip(): + return ( + "Base task requires non-empty 'ref_text' (transcript of " + "the reference audio) unless 'x_vector_only_mode' is enabled" + ) + else: + # voice is not None + voice_lower = request.voice.lower() + if voice_lower in self.uploaded_speakers: + # Check if audio file exists for uploaded speaker + speaker_info = self.uploaded_speakers[voice_lower] + file_path = Path(speaker_info["file_path"]) + if not file_path.exists(): + return f"Audio file for uploaded speaker '{request.voice}' not found on disk" + else: + # need ref_audio for built-in speaker + if request.ref_audio is None: + return ( + f"Base task with built-in speaker '{request.voice}' requires 'ref_audio' for voice cloning" + ) + # Validate ref_audio format for built-in speaker + if not ( + request.ref_audio.startswith(("http://", "https://")) + or request.ref_audio.startswith("data:") + or request.ref_audio.startswith("file://") + ): + return ( + "ref_audio must be a URL (http/https), base64 data URL (data:...), or file URI (file://...)" + ) # Validate cross-parameter dependencies if task_type != "Base": @@ -411,6 +676,17 @@ def _build_tts_params(self, request: OpenAICreateSpeechRequest) -> dict[str, Any # Speaker (voice) if request.voice is not None: params["speaker"] = [request.voice] + + # If voice is an uploaded speaker and no ref_audio provided, auto-set it + if request.voice.lower() in self.uploaded_speakers and request.ref_audio is None: + audio_data = self._get_uploaded_audio_data(request.voice) + if audio_data: + params["ref_audio"] = [audio_data] + params["x_vector_only_mode"] = [True] + logger.info(f"Auto-set ref_audio for uploaded voice: {request.voice}") + else: + raise ValueError(f"Audio file for uploaded voice '{request.voice}' is missing or corrupted") + elif params["task_type"][0] == "CustomVoice": params["speaker"] = ["Vivian"] # Default for CustomVoice diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts.py index 4a32205a8cb..b73ab33d747 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts.py @@ -16,7 +16,6 @@ import io import urllib.request from collections.abc import Iterable -from dataclasses import dataclass from typing import Any from urllib.parse import urlparse @@ -35,6 +34,7 @@ from .configuration_qwen3_tts import Qwen3TTSConfig from .modeling_qwen3_tts import Qwen3TTSForConditionalGeneration from .processing_qwen3_tts import Qwen3TTSProcessor +from .voice_cache_manager import VoiceCacheManager, VoiceClonePromptItem logger = init_logger(__name__) @@ -59,21 +59,6 @@ def _normalize_task_type(raw: str) -> str: MaybeList = Any | list[Any] -@dataclass -class VoiceClonePromptItem: - """ - Container for one sample's voice-clone prompt information that can be fed to the model. - - Fields are aligned with `Qwen3TTSForConditionalGeneration.generate(..., voice_clone_prompt=...)`. - """ - - ref_code: torch.Tensor | None # (T, Q) or (T,) depending on tokenizer 25Hz/12Hz - ref_spk_embedding: torch.Tensor # (D,) - x_vector_only_mode: bool - icl_mode: bool - ref_text: str | None = None - - class Qwen3TTSModelForGeneration(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -373,6 +358,12 @@ def __init__( self.processor = processor self.generate_defaults = generate_defaults or {} + # Initialize voice cache manager. + # Note: this creates its own MetadataManager for the same metadata.json + # used by serving_speech.py. Sharing is not possible across model/serving + # layers, but file locking in MetadataManager ensures correctness. + self.voice_cache_manager = VoiceCacheManager() + self.device = getattr(model, "device", None) if self.device is None: try: @@ -791,6 +782,7 @@ def generate_voice_clone( self, text: str | list[str], language: str | list[str] = None, + speaker: str | None = None, # New parameter: speaker name ref_audio: AudioLike | list[AudioLike] | None = None, ref_text: str | list[str | None] | None = None, x_vector_only_mode: bool | list[bool] = False, @@ -865,7 +857,27 @@ def generate_voice_clone( self._validate_languages(languages) - if voice_clone_prompt is None: + # Cache logic: if speaker parameter is provided, try to load from cache + cache_loaded = False + cache_speaker = None + cache_audio_path = None + + if speaker: + # Use VoiceCacheManager to load cached voice prompt, passing device parameter + cached_items = self.voice_cache_manager.load_cached_voice_prompt(speaker, device=str(self.device)) + if cached_items is not None: + voice_clone_prompt = cached_items + cache_loaded = True + + # If no cache, check if cache needs to be generated + if not cache_loaded: + audio_file_path = self.voice_cache_manager.get_speaker_audio_path(speaker) + if audio_file_path: + logger.info(f"Will generate cache for speaker: {speaker} (first use)") + cache_speaker = speaker + cache_audio_path = audio_file_path + + if voice_clone_prompt is None and not cache_loaded: if ref_audio is None: # For profile run sample_rate = int(self.model.speaker_encoder_sample_rate) @@ -879,6 +891,28 @@ def generate_voice_clone( prompt_items = self.create_voice_clone_prompt( ref_audio=ref_audio, ref_text=ref_text, x_vector_only_mode=x_vector_only_mode ) + + # If cache needs to be generated, save cache file + if cache_speaker and cache_audio_path: + try: + # Use VoiceCacheManager to save cache + success = self.voice_cache_manager.save_voice_cache(cache_speaker, cache_audio_path, prompt_items) + if success: + logger.info(f"Cache generated and saved for speaker: {cache_speaker}") + else: + logger.error(f"Failed to save cache for speaker: {cache_speaker}") + except Exception as e: + logger.error(f"Failed to save cache for speaker {cache_speaker}: {e}") + + if len(prompt_items) == 1 and len(texts) > 1: + prompt_items = prompt_items * len(texts) + if len(prompt_items) != len(texts): + raise ValueError(f"Batch size mismatch: prompt={len(prompt_items)}, text={len(texts)}") + voice_clone_prompt_dict = self._prompt_items_to_voice_clone_prompt(prompt_items) + ref_texts_for_ids = [it.ref_text for it in prompt_items] + elif cache_loaded and isinstance(voice_clone_prompt, list): + # Use cached VoiceClonePromptItem + prompt_items = voice_clone_prompt if len(prompt_items) == 1 and len(texts) > 1: prompt_items = prompt_items * len(texts) if len(prompt_items) != len(texts): diff --git a/vllm_omni/model_executor/models/qwen3_tts/voice_cache_manager.py b/vllm_omni/model_executor/models/qwen3_tts/voice_cache_manager.py new file mode 100644 index 00000000000..1e26a161da4 --- /dev/null +++ b/vllm_omni/model_executor/models/qwen3_tts/voice_cache_manager.py @@ -0,0 +1,271 @@ +# Copyright 2026 The Alibaba Qwen team. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import torch +from safetensors import safe_open +from safetensors.torch import save_file +from vllm.logger import init_logger + +from vllm_omni.entrypoints.openai.metadata_manager import MetadataManager + +logger = init_logger(__name__) + + +@dataclass +class VoiceClonePromptItem: + """ + Container for one sample's voice-clone prompt information that can be fed to the model. + + Fields are aligned with `Qwen3TTSForConditionalGeneration.generate(..., voice_clone_prompt=...)`. + """ + + ref_code: torch.Tensor | None # (T, Q) or (T,) depending on tokenizer 25Hz/12Hz + ref_spk_embedding: torch.Tensor # (D,) + x_vector_only_mode: bool + icl_mode: bool + ref_text: str | None = None + + +class VoiceCacheManager: + """ + Voice cache manager, responsible for managing custom voice cache functionality. + + Main features: + 1. Load uploaded speaker information from metadata.json + 2. Manage voice clone prompt cache + 3. Update cache status to metadata.json + + Security properties: + - No pickle / torch.load + - Safetensors-only + - Cache path confined to voice samples directory + """ + + def __init__(self, speech_voice_samples_dir: str | None = None, metadata_manager: MetadataManager | None = None): + """ + Initialize the voice cache manager. + + Args: + speech_voice_samples_dir: Speech voice samples directory path, + if None, get from environment variable + metadata_manager: Optional MetadataManager instance for shared metadata access. + If not provided, will create its own (less efficient). + """ + self.speech_voice_samples_dir = speech_voice_samples_dir or os.environ.get( + "SPEECH_VOICE_SAMPLES", "/tmp/voice_samples" + ) + + # Initialize metadata manager + if metadata_manager is not None: + self.metadata_manager = metadata_manager + else: + metadata_file = Path(self.speech_voice_samples_dir) / "metadata.json" + self.metadata_manager = MetadataManager(metadata_file) + + # ------------------------------------------------------------------ + # Metadata helpers + # ------------------------------------------------------------------ + + def load_uploaded_speakers_from_metadata(self) -> dict[str, Any] | None: + """Load uploaded speakers from metadata manager.""" + try: + return self.metadata_manager.get_uploaded_speakers() + except Exception as e: + logger.warning(f"Failed to load uploaded speakers from metadata: {e}") + return None + + def update_metadata_cache_info(self, speaker: str, cache_file_path: Path, status: str = "ready") -> bool: + """ + Update cache information using metadata manager. + + Args: + speaker: Speaker name + cache_file_path: Cache file path + status: Cache status, default is "ready" + + Returns: + bool: Whether the update was successful + """ + try: + speaker_key = speaker.lower() + return self.metadata_manager.update_cache_info( + speaker_key=speaker_key, cache_file_path=cache_file_path, status=status + ) + except Exception as e: + logger.error(f"Failed to update metadata cache info: {e}") + return False + + # ------------------------------------------------------------------ + # Cache save (SAFE) + # ------------------------------------------------------------------ + + def save_voice_cache( + self, + speaker: str, + audio_file_path: Path, + prompt_items: list[VoiceClonePromptItem], + ) -> bool: + """ + Save voice cache using safetensors (no pickle, no RCE). + """ + try: + cache_file_path = audio_file_path.with_suffix(".safetensors") + + tensors: dict[str, torch.Tensor] = {} + metadata: dict[str, str] = {} + + tensors["__len__"] = torch.tensor(len(prompt_items), dtype=torch.int64) + + for i, item in enumerate(prompt_items): + prefix = f"item_{i}_" + + tensors[prefix + "ref_spk_embedding"] = item.ref_spk_embedding.detach().cpu() + + has_ref_code = item.ref_code is not None + tensors[prefix + "has_ref_code"] = torch.tensor(int(has_ref_code), dtype=torch.int8) + + if has_ref_code: + tensors[prefix + "ref_code"] = item.ref_code.detach().cpu() + + tensors[prefix + "x_vector_only_mode"] = torch.tensor(int(item.x_vector_only_mode), dtype=torch.int8) + tensors[prefix + "icl_mode"] = torch.tensor(int(item.icl_mode), dtype=torch.int8) + + if item.ref_text is not None: + metadata[prefix + "ref_text"] = item.ref_text + + save_file(tensors, str(cache_file_path), metadata=metadata) + + return self.update_metadata_cache_info( + speaker=speaker, + cache_file_path=cache_file_path, + status="ready", + ) + + except Exception as e: + logger.error(f"Failed to save safetensors cache for speaker {speaker}: {e}") + self.update_metadata_cache_info(speaker, Path(""), "failed") + return False + + # ------------------------------------------------------------------ + # Cache load (SAFE) + # ------------------------------------------------------------------ + + def load_cached_voice_prompt( + self, + speaker: str, + device: str | None = None, + ) -> list[VoiceClonePromptItem] | None: + """ + Load cached VoiceClonePromptItem list from safetensors. + """ + try: + uploaded_speakers = self.load_uploaded_speakers_from_metadata() + if not uploaded_speakers: + return None + + speaker_key = speaker.lower() + if speaker_key not in uploaded_speakers: + return None + + speaker_info = uploaded_speakers[speaker_key] + if speaker_info.get("cache_status") != "ready": + return None + + cache_file_path = Path(speaker_info.get("cache_file", "")).resolve() + + base_dir = Path(self.speech_voice_samples_dir).resolve() + + # ---- Path confinement (critical security check) + if not str(cache_file_path).startswith(str(base_dir)): + logger.error(f"Illegal cache path outside base dir: {cache_file_path}") + return None + + if not cache_file_path.exists(): + return None + + if cache_file_path.suffix != ".safetensors": + logger.error(f"Legacy or unsafe cache format rejected: {cache_file_path}") + return None + + with safe_open(cache_file_path, framework="pt", device="cpu") as f: + meta = f.metadata() + + num_items = int(f.get_tensor("__len__").item()) + result: list[VoiceClonePromptItem] = [] + + for i in range(num_items): + prefix = f"item_{i}_" + + has_ref_code = bool(f.get_tensor(prefix + "has_ref_code").item()) + + ref_code = f.get_tensor(prefix + "ref_code").to(device) if has_ref_code else None + + ref_spk_embedding = f.get_tensor(prefix + "ref_spk_embedding").to(device) + + x_vector_only_mode = bool(f.get_tensor(prefix + "x_vector_only_mode").item()) + icl_mode = bool(f.get_tensor(prefix + "icl_mode").item()) + + ref_text = meta.get(prefix + "ref_text") + + result.append( + VoiceClonePromptItem( + ref_code=ref_code, + ref_spk_embedding=ref_spk_embedding, + x_vector_only_mode=x_vector_only_mode, + icl_mode=icl_mode, + ref_text=ref_text, + ) + ) + + logger.info(f"Safetensors cache loaded for speaker: {speaker}") + return result + + except Exception as e: + logger.warning(f"Failed to load safetensors cache for speaker {speaker}: {e}") + return None + + # ------------------------------------------------------------------ + # Audio path helper + # ------------------------------------------------------------------ + + def get_speaker_audio_path(self, speaker: str) -> Path | None: + """ + Get speaker's audio file path. + + Args: + speaker: Speaker name + + Returns: + Optional[Path]: Audio file path, returns None if speaker doesn't exist + """ + uploaded_speakers = self.load_uploaded_speakers_from_metadata() + if not uploaded_speakers: + return None + + speaker_key = speaker.lower() + if speaker_key not in uploaded_speakers: + return None + + audio_file_path = Path(uploaded_speakers[speaker_key]["file_path"]) + if audio_file_path.exists(): + return audio_file_path + + logger.warning(f"Audio file not found for speaker {speaker}: {audio_file_path}") + return None