diff --git a/benchmarks/fish-speech/bench_voice_cache.py b/benchmarks/fish-speech/bench_speaker_cache.py similarity index 98% rename from benchmarks/fish-speech/bench_voice_cache.py rename to benchmarks/fish-speech/bench_speaker_cache.py index 8d465d6489f..a2ac460afbc 100644 --- a/benchmarks/fish-speech/bench_voice_cache.py +++ b/benchmarks/fish-speech/bench_speaker_cache.py @@ -3,11 +3,11 @@ Measures TTFP improvement from DAC-code caching when using uploaded voices. Setup: - 1. Start vllm-omni with Fish Speech S2 Pro (use our feat branch) + 1. Start vllm-omni with Fish Speech S2 Pro 2. Provide a reference audio file for voice cloning Usage: - python bench_voice_cache.py \ + python bench_speaker_cache.py \ --ref-audio /path/to/reference.wav \ --ref-text "Transcript of the reference audio." \ --num-prompts 20 \ diff --git a/docs/serving/speech_api.md b/docs/serving/speech_api.md index 407aa2f6e68..11813ae35c3 100644 --- a/docs/serving/speech_api.md +++ b/docs/serving/speech_api.md @@ -358,6 +358,31 @@ curl -X POST http://localhost:8091/v1/audio/speech \ }' --output cloned.wav ``` +### Voice Storage & Caching + +Uploaded voices are persisted to disk as a single `.safetensors` file per voice +(audio samples + metadata — name, consent, ref_text, sample_rate, created_at — +in the file header). On server restart the directory is scanned and all +previously uploaded voices are restored automatically, so uploads survive +process restarts. + +Uploading an existing name overwrites the previous entry (a warning is logged). + +Feature extraction artifacts (ref_code, speaker_embedding, DAC codes, etc.) +are cached in-process with a shared LRU so repeated requests with the same +`voice=...` skip the extraction pipeline. The cache is a true singleton across +all TTS model types; deleting a voice invalidates every model-type slot at +once. + +**Configuration (environment variables):** + +| Variable | Default | Description | +|----------|---------|-------------| +| `SPEAKER_SAMPLES_DIR` | `~/.cache/vllm-omni/speakers` | Directory for persisted uploaded speakers (`.safetensors` files). | +| `SPEAKER_MAX_UPLOADED` | `1000` | Maximum number of uploaded speakers kept on disk. Upload requests past the cap return 400. | + +The in-memory LRU has a fixed 512 MiB byte budget. + ## Batch Speech Generation The batch endpoint synthesizes multiple texts in a single request, returning all results as JSON with base64-encoded audio. @@ -543,6 +568,30 @@ Fish Speech uses `ref_audio` and `ref_text` for voice cloning (no `task_type` ne |-------|-------------| | `mistralai/Voxtral-4B-TTS-2603` | 3B AR + FlowMatching TTS. Supports text-to-speech with preset voices. | +### CosyVoice3 + +| Model | Description | +|-------|-------------| +| `FunAudioLLM/Fun-CosyVoice3-0.5B-2512` | Voice cloning from `ref_audio` + `ref_text`. No built-in voice presets — upload a voice or pass `ref_audio`/`ref_text` per request. | + +### OmniVoice + +| Model | Description | +|-------|-------------| +| `k2-fsa/OmniVoice` | Pure-diffusion TTS. Supports voice cloning via `ref_audio` (with optional `ref_text`); no built-in voice presets. | + +### VoxCPM2 + +| Model | Description | +|-------|-------------| +| `openbmb/VoxCPM2` | TTS + voice cloning with built-in speaker presets and uploaded-voice support. Accepts `voice` (preset or uploaded) or `ref_audio` + optional `ref_text`. | + +### MOSS-TTS-Nano + +| Model | Description | +|-------|-------------| +| `OpenMOSS-Team/MOSS-TTS-Nano` | Voice cloning only. Requires `ref_audio` (or an uploaded `voice`); no built-in voice presets. `ref_text` is accepted but ignored — upstream's `voice_clone` mode does not consume a transcript. | + ## Error Responses ### 400 Bad Request diff --git a/docs/user_guide/examples/offline_inference/voxtral_tts.md b/docs/user_guide/examples/offline_inference/voxtral_tts.md new file mode 100644 index 00000000000..c6f41ac0875 --- /dev/null +++ b/docs/user_guide/examples/offline_inference/voxtral_tts.md @@ -0,0 +1,68 @@ +# Voxtral TTS Offline Inference + +Source . + + +`end2end.py` runs Voxtral TTS end-to-end offline inference using vLLM. It supports both blocking (`Omni`) and streaming (`AsyncOmni`) generation, batched prompts with configurable concurrency, and voice selection via preset name or reference audio file. + +When `mistral_common` has `SpeechRequest` support, prompt token IDs are built via `encode_speech_request`. Otherwise, the script falls back to manual token construction. + +## Usage Examples + + +```bash +# Basic single-prompt with cheerful_female voice preset +python3 examples/offline_inference/voxtral_tts/end2end.py \ + --write-audio --voice cheerful_female \ + --model mistralai/Voxtral-4B-TTS-2603 \ + --text "That eerie silence after the first storm was just the calm before another round of chaos, wasn't it?" + +# 32 replicate prompts with cheerful_female voice preset +python3 examples/offline_inference/voxtral_tts/end2end.py \ + --num-prompts 32 --write-audio --voice cheerful_female \ + --model mistralai/Voxtral-4B-TTS-2603 \ + --text "That eerie silence after the first storm was just the calm before another round of chaos, wasn't it?" + +# Streaming with neutral_female voice preset +python3 examples/offline_inference/voxtral_tts/end2end.py \ + --streaming --write-audio --voice neutral_female \ + --model mistralai/Voxtral-4B-TTS-2603 \ + --text "That eerie silence after the first storm was just the calm before another round of chaos, wasn't it?" + +# 32 prompts, 8 concurrent requests per wave, streaming with neutral_female voice +python3 examples/offline_inference/voxtral_tts/end2end.py \ + --num-prompts 32 --concurrency 8 --streaming --write-audio --voice neutral_female \ + --model mistralai/Voxtral-4B-TTS-2603 \ + --text "That eerie silence after the first storm was just the calm before another round of chaos, wasn't it?" + +# Short debug prompt with reference audio +python3 examples/offline_inference/voxtral_tts/end2end.py \ + --write-audio \ + --model mistralai/Voxtral-4B-TTS-2603 \ + --text "This is a test message." \ + --audio-path path/to/reference_audio.wav +``` + +## Arguments + +| Argument | Description | +|---|---| +| `--model PATH` | HuggingFace repo ID or local directory path (default: `mistralai/Voxtral-4B-TTS-2603`) | +| `--text TEXT` | Text to synthesize (default: `"This is a test message."`) | +| `--audio-path PATH` | Path to reference audio file for voice cloning | +| `--output-dir DIR` | Directory to write output WAV files (default: `output_audio`) | +| `--deploy-config PATH` | Override the deploy config path. If unset, auto-loads `vllm_omni/deploy/voxtral_tts.yaml` from the HF `model_type`. | +| `--num-prompts N` | Number of replicate prompts to run for measuring performance (default: 1) | +| `--streaming` | Use streaming generation via `AsyncOmni` (default: blocking `Omni`) | +| `--concurrency N` | Max concurrent requests per wave (must be used with `--streaming`, must evenly divide `--num-prompts`) | +| `--voice NAME` | Voice preset to use instead of reference audio file (e.g., casual_female, casual_male, cheerful_female, neutral_female, neutral_male) | +| `--write-audio` | Write generated audio to WAV files | +| `--profiling-mode` | Enable profiling mode (reduces max tokens to 50) | +| `--log-stats` | Enable detailed statistics logging | + +## Example materials + +??? abstract "end2end.py" + ``````py + --8<-- "examples/offline_inference/voxtral_tts/end2end.py" + `````` diff --git a/examples/offline_inference/voxtral_tts/README.md b/examples/offline_inference/voxtral_tts/README.md new file mode 100644 index 00000000000..bbe317798a8 --- /dev/null +++ b/examples/offline_inference/voxtral_tts/README.md @@ -0,0 +1,58 @@ +# Voxtral TTS Offline Inference + +`end2end.py` runs Voxtral TTS end-to-end offline inference using vLLM. It supports both blocking (`Omni`) and streaming (`AsyncOmni`) generation, batched prompts with configurable concurrency, and voice selection via preset name or reference audio file. + +When `mistral_common` has `SpeechRequest` support, prompt token IDs are built via `encode_speech_request`. Otherwise, the script falls back to manual token construction. + +## Usage Examples + + +```bash +# Basic single-prompt with cheerful_female voice preset +python3 examples/offline_inference/voxtral_tts/end2end.py \ + --write-audio --voice cheerful_female \ + --model mistralai/Voxtral-4B-TTS-2603 \ + --text "That eerie silence after the first storm was just the calm before another round of chaos, wasn't it?" + +# 32 replicate prompts with cheerful_female voice preset +python3 examples/offline_inference/voxtral_tts/end2end.py \ + --num-prompts 32 --write-audio --voice cheerful_female \ + --model mistralai/Voxtral-4B-TTS-2603 \ + --text "That eerie silence after the first storm was just the calm before another round of chaos, wasn't it?" + +# Streaming with neutral_female voice preset +python3 examples/offline_inference/voxtral_tts/end2end.py \ + --streaming --write-audio --voice neutral_female \ + --model mistralai/Voxtral-4B-TTS-2603 \ + --text "That eerie silence after the first storm was just the calm before another round of chaos, wasn't it?" + +# 32 prompts, 8 concurrent requests per wave, streaming with neutral_female voice +python3 examples/offline_inference/voxtral_tts/end2end.py \ + --num-prompts 32 --concurrency 8 --streaming --write-audio --voice neutral_female \ + --model mistralai/Voxtral-4B-TTS-2603 \ + --text "That eerie silence after the first storm was just the calm before another round of chaos, wasn't it?" + +# Short debug prompt with reference audio +python3 examples/offline_inference/voxtral_tts/end2end.py \ + --write-audio \ + --model mistralai/Voxtral-4B-TTS-2603 \ + --text "This is a test message." \ + --audio-path path/to/reference_audio.wav +``` + +## Arguments + +| Argument | Description | +|---|---| +| `--model PATH` | HuggingFace repo ID or local directory path (default: `mistralai/Voxtral-4B-TTS-2603`) | +| `--text TEXT` | Text to synthesize (default: `"This is a test message."`) | +| `--audio-path PATH` | Path to reference audio file for voice cloning | +| `--output-dir DIR` | Directory to write output WAV files (default: `output_audio`) | +| `--deploy-config PATH` | Override the deploy config path. If unset, auto-loads `vllm_omni/deploy/voxtral_tts.yaml` from the HF `model_type`. | +| `--num-prompts N` | Number of replicate prompts to run for measuring performance (default: 1) | +| `--streaming` | Use streaming generation via `AsyncOmni` (default: blocking `Omni`) | +| `--concurrency N` | Max concurrent requests per wave (must be used with `--streaming`, must evenly divide `--num-prompts`) | +| `--voice NAME` | Voice preset to use instead of reference audio file. Check Huggingface `mistralai/Voxtral-4B-TTS-2603` to get the list of available voices | +| `--write-audio` | Write generated audio to WAV files | +| `--profiling-mode` | Enable profiling mode (reduces max tokens to 50) | +| `--log-stats` | Enable detailed statistics logging | diff --git a/tests/conftest.py b/tests/conftest.py index 77075f9525a..2f23dcd9f7c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,6 +13,7 @@ "tests.helpers.fixtures.log", "tests.helpers.fixtures.run_args", "tests.helpers.fixtures.runtime", + "tests.helpers.fixtures.speaker_cache", ) diff --git a/tests/entrypoints/openai_api/test_serving_speech.py b/tests/entrypoints/openai_api/test_serving_speech.py index f94b11dbdfa..64f2e3d09b2 100644 --- a/tests/entrypoints/openai_api/test_serving_speech.py +++ b/tests/entrypoints/openai_api/test_serving_speech.py @@ -300,6 +300,24 @@ def client(test_app): class TestSpeechAPI: + @pytest.fixture(autouse=True) + def _mock_upload_io(self, mocker: MockerFixture): + """Mock soundfile/safetensors so upload accepts fake audio bytes.""" + samples = np.zeros(88200, dtype=np.float32) # 2s @ 44.1 kHz + mocker.patch("soundfile.read", return_value=(samples, 44100)) + + def _fake_save_file(tensors, path, metadata=None): + Path(path).touch() + + mocker.patch("safetensors.torch.save_file", side_effect=_fake_save_file) + mock_ctx = mocker.MagicMock() + mock_ctx.keys.return_value = ["audio"] + mock_ctx.get_tensor.return_value = torch.zeros(88200) + mock_ctx.metadata.return_value = {"sample_rate": "44100"} + mock_safe_open = mocker.MagicMock() + mock_safe_open.return_value.__enter__.return_value = mock_ctx + mocker.patch("safetensors.safe_open", mock_safe_open) + def test_create_speech_success(self, client): payload = { "input": "Hello world", @@ -470,27 +488,17 @@ def test_upload_voice_invalid_mime_type(self, client): assert "MIME type" in result["detail"] def test_upload_voice_name_collision(self, client): - """Test voice upload with duplicate name.""" - # First upload + """Re-uploading the same name overwrites the previous entry (no 400).""" audio_content = b"fake audio content" - files = { - "audio_sample": ("test.wav", audio_content, "audio/wav"), - } - data = { - "consent": "user_consent_123", - "name": "test_voice", - } + files = {"audio_sample": ("test.wav", audio_content, "audio/wav")} + data = {"consent": "user_consent_123", "name": "test_voice"} response = client.post("/v1/audio/voices", files=files, data=data) assert response.status_code == 200 - # Second upload with same name response = client.post("/v1/audio/voices", files=files, data=data) - assert response.status_code == 400 - result = response.json() - assert "detail" in result - assert "already exists" in result["detail"] - response = client.delete("/v1/audio/voices/test_voice") + assert response.status_code == 200 + client.delete("/v1/audio/voices/test_voice") def test_upload_voice_missing_parameters(self, client): """Test voice upload with missing required parameters.""" @@ -970,7 +978,7 @@ def test_build_tts_params_with_uploaded_voice(self, speech_server, mocker: Mocke "file_path": "/tmp/voice_samples/custom_voice_consent_123.wav", "mime_type": "audio/wav", "ref_text": None, - "created_at": 1711234567.89, + "created_at": 1711234567, } } speech_server.supported_speakers = {"ryan", "vivian", "custom_voice"} @@ -983,7 +991,7 @@ def test_build_tts_params_with_uploaded_voice(self, speech_server, mocker: Mocke assert params["ref_audio"] == ["data:audio/wav;base64,ZmFrZWF1ZGlv"] assert params["x_vector_only_mode"] == [True] assert params["task_type"] == ["Base"] - assert params["voice_created_at"] == [1711234567.89] + assert params["voice_created_at"] == [1711234567] assert "ref_text" not in params def test_build_tts_params_with_uploaded_voice_ref_text(self, speech_server, mocker: MockerFixture): @@ -994,7 +1002,7 @@ def test_build_tts_params_with_uploaded_voice_ref_text(self, speech_server, mock "file_path": "/tmp/voice_samples/custom_voice_consent_123.wav", "mime_type": "audio/wav", "ref_text": "Hello world transcript", - "created_at": 1711234567.89, + "created_at": 1711234567, } } speech_server.supported_speakers = {"ryan", "vivian", "custom_voice"} @@ -1008,7 +1016,7 @@ def test_build_tts_params_with_uploaded_voice_ref_text(self, speech_server, mock assert params["x_vector_only_mode"] == [False] assert params["task_type"] == ["Base"] assert params["ref_text"] == ["Hello world transcript"] - assert params["voice_created_at"] == [1711234567.89] + assert params["voice_created_at"] == [1711234567] def test_build_tts_params_without_uploaded_voice(self, speech_server): """Test _build_tts_params does not auto-set ref_audio for non-uploaded voices.""" @@ -1051,28 +1059,29 @@ def test_build_tts_params_with_explicit_ref_audio(self, speech_server): assert "x_vector_only_mode" not in params def test_get_uploaded_audio_data(self, speech_server, mocker: MockerFixture): - """Test _get_uploaded_audio_data function.""" - # Mock file operations - mock_open = mocker.patch("builtins.open", create=True) - mock_b64encode = mocker.patch("base64.b64encode") - mock_exists = mocker.patch("pathlib.Path.exists") - mock_exists.return_value = True - mock_b64encode.return_value = b"ZmFrZWF1ZGlv" - - # Setup mock file - mock_file = mocker.MagicMock() - mock_file.read.return_value = b"fakeaudio" - mock_open.return_value.__enter__.return_value = mock_file + """Returns a data URL by loading audio via safetensors + re-encoding WAV.""" + mocker.patch("pathlib.Path.exists", return_value=True) + mocker.patch("soundfile.write") + mocker.patch("base64.b64encode", return_value=b"ZmFrZWF1ZGlv") + mock_ctx = mocker.MagicMock() + mock_ctx.keys.return_value = ["audio"] + mock_ctx.get_tensor.return_value = torch.zeros(88200) + mock_ctx.metadata.return_value = {"sample_rate": "44100"} + mock_safe_open = mocker.MagicMock() + mock_safe_open.return_value.__enter__.return_value = mock_ctx + mocker.patch("safetensors.safe_open", mock_safe_open) - # Setup uploaded speaker speech_server.uploaded_speakers = { - "test_voice": {"name": "test_voice", "file_path": "/tmp/test.wav", "mime_type": "audio/wav"} + "test_voice": { + "name": "test_voice", + "file_path": "/tmp/test.safetensors", + "mime_type": "audio/wav", + "embedding_source": "audio", + "sample_rate": 44100, + } } result = speech_server._get_uploaded_audio_data("test_voice") - assert result == "data:audio/wav;base64,ZmFrZWF1ZGlv" - mock_open.assert_called_once_with(Path("/tmp/test.wav"), "rb") - mock_b64encode.assert_called_once_with(b"fakeaudio") def test_get_uploaded_audio_data_missing_file(self, speech_server, mocker: MockerFixture): """Test _get_uploaded_audio_data when file is missing.""" @@ -1230,24 +1239,6 @@ def test_regression_1603_speaker_key_with_uploaded_embedding_voice(self, speech_ # Must NOT have ref_audio — that would fail for safetensors files assert "ref_audio" not in params - def test_validate_rejects_embedding_voice_with_pending_cache(self, speech_server, mocker: MockerFixture): - """Validation should reject embedding voices whose cache is not yet ready.""" - speech_server.uploaded_speakers = { - "myvoice": { - "name": "myvoice", - "file_path": "/tmp/myvoice.safetensors", - "mime_type": "application/x-safetensors", - "embedding_source": "direct", - "cache_status": "pending", - "cache_file": None, - } - } - req = OpenAICreateSpeechRequest.model_validate({"input": "Hello", "speaker": "myvoice", "task_type": "Base"}) - mocker.patch("pathlib.Path.exists", return_value=True) - err = speech_server._validate_qwen_tts_request(req) - assert err is not None - assert "not yet ready" in err - def test_x_vector_only_mode_not_overwritten_for_uploaded_embedding(self, speech_server, mocker: MockerFixture): """x_vector_only_mode set by uploaded embedding must not be overwritten by request field.""" speech_server.uploaded_speakers = { @@ -2294,6 +2285,7 @@ def test_prepare_speech_generation_cosyvoice3(self, cosyvoice3_server, mocker: M "mm_processor_kwargs": {"prompt_text": "ref text", "sample_rate": 24000}, } ) + cosyvoice3_server._apply_cosyvoice3_dynamic_tokens = mocker.MagicMock(side_effect=lambda spl, req: spl) request = OpenAICreateSpeechRequest( input="Hello", diff --git a/tests/helpers/fixtures/speaker_cache.py b/tests/helpers/fixtures/speaker_cache.py new file mode 100644 index 00000000000..cf7a6212212 --- /dev/null +++ b/tests/helpers/fixtures/speaker_cache.py @@ -0,0 +1,21 @@ +"""Fixtures for the process-wide speaker cache singleton.""" + +from __future__ import annotations + +import pytest + + +@pytest.fixture +def fresh_speaker_cache(): + """Reset the process-wide speaker cache singleton before and after the test.""" + import vllm_omni.utils.speaker_cache as _sc + + def _reset(): + with _sc._SINGLETON_LOCK: + if _sc._SINGLETON is not None: + _sc._SINGLETON.clear() + _sc._SINGLETON = None + + _reset() + yield + _reset() diff --git a/tests/model_executor/models/test_fish_speech_voice_cache.py b/tests/model_executor/models/test_fish_speech_voice_cache.py deleted file mode 100644 index fef4b551ab2..00000000000 --- a/tests/model_executor/models/test_fish_speech_voice_cache.py +++ /dev/null @@ -1,218 +0,0 @@ -"""Tests for Fish Speech DAC-code caching via VoiceEmbeddingCache. - -Covers: - - Cache miss → DAC encode → store - - Cache hit → skip DAC encode, reuse cached ref_codes_fq - - Inline ref_audio (no voice name) → no caching, full encode path - - Stale-cache protection via created_at - - Temp file cleanup on cache hit -""" - -import os -import tempfile - -import numpy as np -import pytest -import torch -from pytest_mock import MockerFixture - -pytestmark = [pytest.mark.core_model, pytest.mark.cpu] - - -def _make_info_dict( - *, - text: str = "Hello world", - ref_text: str = "Reference transcript", - ref_audio_sr: int = 44100, - voice_name: str | None = None, - voice_created_at: float | None = None, - ref_audio_path: str | None = None, -) -> dict: - """Build a minimal info_dict for _build_structured_voice_clone_prefill_embeds.""" - d: dict = { - "text": text, - "ref_text": ref_text, - "ref_audio_sr": ref_audio_sr, - "fish_structured_voice_clone": True, - } - if ref_audio_path is not None: - d["ref_audio_path"] = ref_audio_path - if voice_name is not None: - d["voice_name"] = voice_name - if voice_created_at is not None: - d["voice_created_at"] = voice_created_at - return d - - -def _write_temp_npy(wav: np.ndarray | None = None) -> str: - """Write a temporary .npy file with dummy audio and return its path.""" - if wav is None: - wav = np.random.randn(44100).astype(np.float32) # 1 second @ 44.1kHz - with tempfile.NamedTemporaryFile(prefix="fish_test_", suffix=".npy", delete=False) as f: - np.save(f, wav) - return f.name - - -# Fake ref_codes_fq: [frames, codebooks] -_FAKE_REF_CODES = torch.randint(0, 1024, (10, 10), dtype=torch.long) - - -class TestFishSpeechVoiceCacheIntegration: - """Test the cache-hit / cache-miss / no-cache paths in the model.""" - - @pytest.fixture - def mock_model(self, mocker: MockerFixture): - """Create a mock FishSpeechSlowARForConditionalGeneration with cache.""" - from vllm_omni.utils.voice_cache import VoiceEmbeddingCache - - model = mocker.MagicMock() - model._voice_cache = VoiceEmbeddingCache(max_entries=4) - model._semantic_begin_id = 151678 - model._num_codebooks = 10 - model._codebook_size = 4096 - model.model_path = "/fake/model" - model.codebook_embeddings = mocker.MagicMock() - model.codebook_embeddings.weight = mocker.MagicMock() - model.codebook_embeddings.weight.device = torch.device("cpu") - return model - - def test_cache_miss_stores_codes(self, mock_model): - """First request with a named voice should encode and store in cache.""" - cache = mock_model._voice_cache - voice_name = "alice" - created_at = 1712345678.0 - - # Verify cache starts empty. - key = cache.make_cache_key(voice_name, xvec_only=False, created_at=created_at) - assert cache.get(key) is None - - # Simulate a cache store (what the model does on miss). - cache.put(key, {"ref_codes_fq": _FAKE_REF_CODES.detach().cpu()}) - - # Verify it's now cached. - cached = cache.get(key) - assert cached is not None - assert torch.equal(cached["ref_codes_fq"], _FAKE_REF_CODES) - - def test_cache_hit_returns_cached_codes(self, mock_model): - """Second request with same voice should hit cache.""" - cache = mock_model._voice_cache - voice_name = "alice" - created_at = 1712345678.0 - - key = cache.make_cache_key(voice_name, xvec_only=False, created_at=created_at) - cache.put(key, {"ref_codes_fq": _FAKE_REF_CODES.detach().cpu()}) - - # Hit. - cached = cache.get(key) - assert cached is not None - ref_codes = cached["ref_codes_fq"].to(device=torch.device("cpu"), dtype=torch.long) - assert torch.equal(ref_codes, _FAKE_REF_CODES) - assert cache.stats()["hits"] >= 1 - - def test_no_voice_name_skips_cache(self, mock_model): - """Inline ref_audio without voice_name should not use cache.""" - cache = mock_model._voice_cache - - # Without voice_name, the model should not interact with cache at all. - info = _make_info_dict(voice_name=None, ref_audio_path=_write_temp_npy()) - assert info.get("voice_name") is None - # Cache should remain untouched. - assert cache.stats()["hits"] == 0 - assert cache.stats()["misses"] == 0 - - def test_stale_cache_on_reupload(self, mock_model): - """Re-uploading a voice (new created_at) should not hit old cache.""" - cache = mock_model._voice_cache - voice_name = "alice" - - key_old = cache.make_cache_key(voice_name, xvec_only=False, created_at=1000.0) - cache.put(key_old, {"ref_codes_fq": _FAKE_REF_CODES}) - - # Re-upload produces a different created_at. - key_new = cache.make_cache_key(voice_name, xvec_only=False, created_at=2000.0) - assert cache.get(key_new) is None # miss - assert cache.get(key_old) is not None # old still there - - def test_temp_file_cleaned_on_cache_hit(self): - """On cache hit, the temp .npy file written by the entrypoint should be deleted.""" - tmp_path = _write_temp_npy() - assert os.path.exists(tmp_path) - - # Simulate what the model does on cache hit: remove the temp file. - try: - os.remove(tmp_path) - except OSError: - pass - assert not os.path.exists(tmp_path) - - def test_created_at_zero_disables_cache(self, mock_model): - """created_at=0 should not create a cache key (caching disabled).""" - cache = mock_model._voice_cache - - info = _make_info_dict( - voice_name="bob", - voice_created_at=0.0, - ref_audio_path=_write_temp_npy(), - ) - # The model checks: if _created_at > 0 → enable cache. - # With 0.0, no cache interaction should happen. - _created_at = float(info.get("voice_created_at", 0)) - assert _created_at <= 0 - assert cache.stats()["hits"] == 0 - assert cache.stats()["misses"] == 0 - - -class TestFishSpeechValidatorUploadedVoice: - """Test _validate_fish_tts_request uploaded voice resolution.""" - - def test_uploaded_voice_resolves_ref_audio(self, mocker: MockerFixture): - """When voice matches an uploaded speaker, ref_audio should be auto-set.""" - request = mocker.MagicMock() - request.input = "Hello" - request.voice = "alice" - request.ref_audio = None - request.ref_text = None - request.max_new_tokens = None - - # Uploaded speaker with ref_text. - uploaded_speakers = { - "alice": { - "file_path": "/tmp/fake_audio.wav", - "ref_text": "Hi this is Alice", - "created_at": 1712345678, - }, - } - - # Simulate: voice in uploaded_speakers, file exists, get_audio returns data URL. - mocker.patch("pathlib.Path.exists", return_value=True) - voice_lower = request.voice.lower() - assert voice_lower in uploaded_speakers - - speaker_info = uploaded_speakers[voice_lower] - ref_text_from_upload = speaker_info.get("ref_text") - assert ref_text_from_upload == "Hi this is Alice" - - def test_uploaded_voice_without_ref_text_uses_request_ref_text(self, mocker: MockerFixture): - """If upload has no ref_text but request provides it, use request's.""" - request = mocker.MagicMock() - request.input = "Hello" - request.voice = "bob" - request.ref_audio = None - request.ref_text = "Request-level transcript" - request.max_new_tokens = None - - uploaded_speakers = { - "bob": { - "file_path": "/tmp/fake_audio.wav", - "ref_text": None, - "created_at": 1712345678, - }, - } - - voice_lower = request.voice.lower() - speaker_info = uploaded_speakers[voice_lower] - upload_ref_text = speaker_info.get("ref_text") - # Upload has no ref_text, so request.ref_text should remain. - assert upload_ref_text is None - assert request.ref_text == "Request-level transcript" diff --git a/tests/test_fish_speech_voice_cache.py b/tests/test_fish_speech_voice_cache.py deleted file mode 100644 index 1c299d80142..00000000000 --- a/tests/test_fish_speech_voice_cache.py +++ /dev/null @@ -1,227 +0,0 @@ -"""Tests for Fish Speech DAC-code caching via VoiceEmbeddingCache. - -Covers: - - Cache miss → DAC encode → store - - Cache hit → skip DAC encode, reuse cached ref_codes_fq - - Inline ref_audio (no voice name) → no caching, full encode path - - Stale-cache protection via created_at - - Temp file cleanup on cache hit -""" - -import os -import tempfile -from pathlib import Path - -import numpy as np -import pytest -import torch -from pytest_mock import MockerFixture - -pytestmark = [pytest.mark.core_model, pytest.mark.cpu] - - -def _make_info_dict( - *, - text: str = "Hello world", - ref_text: str = "Reference transcript", - ref_audio_sr: int = 44100, - voice_name: str | None = None, - voice_created_at: float | None = None, - ref_audio_path: str | None = None, -) -> dict: - """Build a minimal info_dict for _build_structured_voice_clone_prefill_embeds.""" - d: dict = { - "text": text, - "ref_text": ref_text, - "ref_audio_sr": ref_audio_sr, - "fish_structured_voice_clone": True, - } - if ref_audio_path is not None: - d["ref_audio_path"] = ref_audio_path - if voice_name is not None: - d["voice_name"] = voice_name - if voice_created_at is not None: - d["voice_created_at"] = voice_created_at - return d - - -def _write_temp_npy(wav: np.ndarray | None = None) -> str: - """Write a temporary .npy file with dummy audio and return its path.""" - if wav is None: - wav = np.random.randn(44100).astype(np.float32) # 1 second @ 44.1kHz - with tempfile.NamedTemporaryFile(prefix="fish_test_", suffix=".npy", delete=False) as f: - np.save(f, wav) - return f.name - - -# Fake ref_codes_fq: [frames, codebooks] -_FAKE_REF_CODES = torch.randint(0, 1024, (10, 10), dtype=torch.long) - - -class TestFishSpeechVoiceCacheIntegration: - """Test the cache-hit / cache-miss / no-cache paths in the model.""" - - @pytest.fixture - def mock_model(self, mocker: MockerFixture): - """Create a mock FishSpeechSlowARForConditionalGeneration with cache.""" - from vllm_omni.utils.voice_cache import VoiceEmbeddingCache - - model = mocker.MagicMock() - model._voice_cache = VoiceEmbeddingCache(max_entries=4) - model._semantic_begin_id = 151678 - model._num_codebooks = 10 - model._codebook_size = 4096 - model.model_path = "/fake/model" - model.codebook_embeddings = mocker.MagicMock() - model.codebook_embeddings.weight = mocker.MagicMock() - model.codebook_embeddings.weight.device = torch.device("cpu") - return model - - def test_cache_miss_stores_codes(self, mock_model): - """First request with a named voice should encode and store in cache.""" - cache = mock_model._voice_cache - voice_name = "alice" - created_at = 1712345678.0 - - # Verify cache starts empty. - key = cache.make_cache_key(voice_name, xvec_only=False, created_at=created_at) - assert cache.get(key) is None - - # Simulate a cache store (what the model does on miss). - cache.put(key, {"ref_codes_fq": _FAKE_REF_CODES.detach().cpu()}) - - # Verify it's now cached. - cached = cache.get(key) - assert cached is not None - assert torch.equal(cached["ref_codes_fq"], _FAKE_REF_CODES) - - def test_cache_hit_returns_cached_codes(self, mock_model): - """Second request with same voice should hit cache.""" - cache = mock_model._voice_cache - voice_name = "alice" - created_at = 1712345678.0 - - key = cache.make_cache_key(voice_name, xvec_only=False, created_at=created_at) - cache.put(key, {"ref_codes_fq": _FAKE_REF_CODES.detach().cpu()}) - - # Hit. - cached = cache.get(key) - assert cached is not None - ref_codes = cached["ref_codes_fq"].to(device=torch.device("cpu"), dtype=torch.long) - assert torch.equal(ref_codes, _FAKE_REF_CODES) - assert cache.stats()["hits"] >= 1 - - def test_no_voice_name_skips_cache(self, mock_model): - """Inline ref_audio without voice_name should not use cache.""" - cache = mock_model._voice_cache - - # Without voice_name, the model should not interact with cache at all. - info = _make_info_dict(voice_name=None, ref_audio_path=_write_temp_npy()) - assert info.get("voice_name") is None - # Cache should remain untouched. - assert cache.stats()["hits"] == 0 - assert cache.stats()["misses"] == 0 - - def test_stale_cache_on_reupload(self, mock_model): - """Re-uploading a voice (new created_at) should not hit old cache.""" - cache = mock_model._voice_cache - voice_name = "alice" - - key_old = cache.make_cache_key(voice_name, xvec_only=False, created_at=1000.0) - cache.put(key_old, {"ref_codes_fq": _FAKE_REF_CODES}) - - # Re-upload produces a different created_at. - key_new = cache.make_cache_key(voice_name, xvec_only=False, created_at=2000.0) - assert cache.get(key_new) is None # miss - assert cache.get(key_old) is not None # old still there - - def test_temp_file_cleaned_on_cache_hit(self): - """On cache hit, the temp .npy file written by the entrypoint should be deleted.""" - tmp_path = _write_temp_npy() - assert os.path.exists(tmp_path) - - # Simulate what the model does on cache hit: remove the temp file. - try: - os.remove(tmp_path) - except OSError: - pass - assert not os.path.exists(tmp_path) - - def test_created_at_zero_disables_cache(self, mock_model): - """created_at=0 should not create a cache key (caching disabled).""" - cache = mock_model._voice_cache - - info = _make_info_dict( - voice_name="bob", - voice_created_at=0.0, - ref_audio_path=_write_temp_npy(), - ) - # The model checks: if _created_at > 0 → enable cache. - # With 0.0, no cache interaction should happen. - _created_at = float(info.get("voice_created_at", 0)) - assert _created_at <= 0 - assert cache.stats()["hits"] == 0 - assert cache.stats()["misses"] == 0 - - -class TestFishSpeechValidatorUploadedVoice: - """Test _validate_fish_tts_request uploaded voice resolution.""" - - def test_uploaded_voice_resolves_ref_audio( - self, - monkeypatch: pytest.MonkeyPatch, - mocker: MockerFixture, - ): - """When voice matches an uploaded speaker, ref_audio should be auto-set.""" - request = mocker.MagicMock() - request.input = "Hello" - request.voice = "alice" - request.ref_audio = None - request.ref_text = None - request.max_new_tokens = None - - # Uploaded speaker with ref_text. - uploaded_speakers = { - "alice": { - "file_path": "/tmp/fake_audio.wav", - "ref_text": "Hi this is Alice", - "created_at": 1712345678, - }, - } - - # Simulate: voice in uploaded_speakers, file exists, get_audio returns data URL. - monkeypatch.setattr(Path, "exists", lambda self: True) - - voice_lower = request.voice.lower() - assert voice_lower in uploaded_speakers - - speaker_info = uploaded_speakers[voice_lower] - ref_text_from_upload = speaker_info.get("ref_text") - assert ref_text_from_upload == "Hi this is Alice" - - def test_uploaded_voice_without_ref_text_uses_request_ref_text( - self, - mocker: MockerFixture, - ): - """If upload has no ref_text but request provides it, use request's.""" - request = mocker.MagicMock() - request.input = "Hello" - request.voice = "bob" - request.ref_audio = None - request.ref_text = "Request-level transcript" - request.max_new_tokens = None - - uploaded_speakers = { - "bob": { - "file_path": "/tmp/fake_audio.wav", - "ref_text": None, - "created_at": 1712345678, - }, - } - - voice_lower = request.voice.lower() - speaker_info = uploaded_speakers[voice_lower] - upload_ref_text = speaker_info.get("ref_text") - # Upload has no ref_text, so request.ref_text should remain. - assert upload_ref_text is None - assert request.ref_text == "Request-level transcript" diff --git a/tests/test_speaker_cache.py b/tests/test_speaker_cache.py new file mode 100644 index 00000000000..373baacb4a5 --- /dev/null +++ b/tests/test_speaker_cache.py @@ -0,0 +1,150 @@ +import threading + +import pytest +import torch + +from vllm_omni.utils.speaker_cache import SpeakerEmbeddingCache, get_speaker_cache + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +@pytest.fixture +def cache(): + return SpeakerEmbeddingCache(max_bytes=10 * 1024**2) + + +def _k(model: str, name: str, created_at: int = 0) -> tuple[str, str, int]: + return (model, name, created_at) + + +class TestSpeakerEmbeddingCacheBehavior: + def test_miss_returns_none(self, cache): + assert cache.get(_k("voxcpm2", "nonexistent")) is None + + def test_put_and_hit(self, cache): + cache.put(_k("voxcpm2", "alice"), {"val": 42}) + assert cache.get(_k("voxcpm2", "alice"))["val"] == 42 + + def test_lru_access_promotes(self): + c = SpeakerEmbeddingCache(max_bytes=4 * 4096) + for k in ("a", "b", "c", "d"): + c.put(_k("m", k), {"emb": torch.zeros(1024, dtype=torch.float32)}) + c.get(_k("m", "a")) + c.put(_k("m", "e"), {"emb": torch.zeros(1024, dtype=torch.float32)}) + assert c.get(_k("m", "a")) is not None + assert c.get(_k("m", "b")) is None + + def test_put_overwrites(self, cache): + cache.put(_k("m", "k"), {"old": True}) + cache.put(_k("m", "k"), {"new": True}) + assert "new" in cache.get(_k("m", "k")) + assert cache.stats()["entries"] == 1 + + def test_make_cache_key_namespaces_model_type(self): + k1 = SpeakerEmbeddingCache.make_cache_key("alice", model_type="voxcpm2") + k2 = SpeakerEmbeddingCache.make_cache_key("alice", model_type="fish_speech") + assert k1 != k2 + assert k1 == ("voxcpm2", "alice", 0) + assert k2 == ("fish_speech", "alice", 0) + + def test_make_cache_key_created_at_isolation(self): + k_old = SpeakerEmbeddingCache.make_cache_key("alice", model_type="voxcpm2", created_at=1712000000) + k_new = SpeakerEmbeddingCache.make_cache_key("alice", model_type="voxcpm2", created_at=1712000042) + assert k_old != k_new + + def test_make_cache_key_requires_fields(self): + with pytest.raises(ValueError): + SpeakerEmbeddingCache.make_cache_key("", model_type="voxcpm2") + with pytest.raises(ValueError): + SpeakerEmbeddingCache.make_cache_key("alice", model_type="") + + def test_clear_all(self, cache): + cache.put(_k("m", "a"), {"v": 1}) + cache.put(_k("m", "b"), {"v": 2}) + assert cache.clear() == 2 + assert cache.stats()["entries"] == 0 + + def test_clear_matches_speaker_across_model_types(self, cache): + cache.put(_k("voxcpm2", "alice", 1), {"v": 1}) + cache.put(_k("fish_speech", "alice", 2), {"v": 2}) + cache.put(_k("cosyvoice3", "bob", 3), {"v": 3}) + assert cache.clear("alice") == 2 + assert cache.get(_k("voxcpm2", "alice", 1)) is None + assert cache.get(_k("fish_speech", "alice", 2)) is None + assert cache.get(_k("cosyvoice3", "bob", 3)) is not None + + def test_stale_cache_on_reupload(self, cache): + cache.put(_k("voxcpm2", "alice", 1712000000), {"emb": torch.zeros(4), "gen": "old"}) + assert cache.get(_k("voxcpm2", "alice", 1712000042)) is None + + def test_memory_bytes(self, cache): + assert cache.memory_bytes() == 0 + t = torch.zeros(1024, dtype=torch.float32) # 4096 bytes + cache.put(_k("m", "k"), {"emb": t}) + assert cache.memory_bytes() == 4096 + + def test_memory_bytes_ignores_non_tensors(self, cache): + cache.put(_k("m", "k"), {"flag": True, "name": "test", "nothing": None}) + assert cache.memory_bytes() == 0 + + def test_byte_budget_evicts(self): + c = SpeakerEmbeddingCache(max_bytes=8192) + c.put(_k("m", "a"), {"emb": torch.zeros(1024, dtype=torch.float32)}) + c.put(_k("m", "b"), {"emb": torch.zeros(1024, dtype=torch.float32)}) + c.put(_k("m", "c"), {"emb": torch.zeros(1024, dtype=torch.float32)}) + assert c.get(_k("m", "a")) is None + assert c.get(_k("m", "b")) is not None + assert c.get(_k("m", "c")) is not None + assert c.memory_bytes() <= 8192 + + def test_oversize_entry_skipped(self): + c = SpeakerEmbeddingCache(max_bytes=1024) + c.put(_k("m", "huge"), {"emb": torch.zeros(2048, dtype=torch.float32)}) + assert c.get(_k("m", "huge")) is None + assert c.stats()["entries"] == 0 + + def test_stats(self, cache): + cache.put(_k("m", "x"), {"v": 1}) + cache.get(_k("m", "x")) + cache.get(_k("m", "y")) + s = cache.stats() + assert s["hits"] == 1 + assert s["misses"] >= 1 + assert s["entries"] == 1 + + def test_thread_safety(self): + cache = SpeakerEmbeddingCache() + errors = [] + + def worker(tid): + try: + for i in range(50): + cache.put(_k("m", f"t{tid}_v{i}"), {"tid": tid}) + cache.get(_k("m", f"t{tid}_v{i}")) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=worker, args=(t,)) for t in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + assert not errors + assert cache.stats()["entries"] == 500 + + def test_empty_speaker_name_raises_error(self, cache): + with pytest.raises(ValueError, match="speaker_name cannot be an empty string"): + cache.clear("") + + def test_cpu_storage_verification(self, cache): + tensor = torch.randn(10, 128) + cache.put(_k("m", "alice"), {"emb": tensor}) + cached = cache.get(_k("m", "alice")) + assert cached["emb"].device.type == "cpu" + + +class TestSingleton: + def test_singleton_identity(self, fresh_speaker_cache): + a = get_speaker_cache() + b = get_speaker_cache() + assert a is b diff --git a/tests/test_speaker_cache_integration.py b/tests/test_speaker_cache_integration.py new file mode 100644 index 00000000000..8a24659f09c --- /dev/null +++ b/tests/test_speaker_cache_integration.py @@ -0,0 +1,59 @@ +"""Integration tests for the process-wide speaker cache across serving + models.""" + +import pytest +import torch + +from vllm_omni.utils.speaker_cache import SpeakerEmbeddingCache, get_speaker_cache + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +class TestSpeakerCacheIntegration: + def test_delete_propagates_across_model_types(self): + cache = SpeakerEmbeddingCache() + voxcpm_key = SpeakerEmbeddingCache.make_cache_key("alice", model_type="voxcpm2") + fish_key = SpeakerEmbeddingCache.make_cache_key("alice", model_type="fish_speech") + bob_key = SpeakerEmbeddingCache.make_cache_key("bob", model_type="voxcpm2") + + cache.put(voxcpm_key, {"emb": torch.zeros(4)}) + cache.put(fish_key, {"emb": torch.zeros(4)}) + cache.put(bob_key, {"emb": torch.zeros(4)}) + + removed = cache.clear("alice") + + assert removed == 2 + assert cache.get(voxcpm_key) is None + assert cache.get(fish_key) is None + assert cache.get(bob_key) is not None + + def test_singleton_shared_across_call_sites(self, fresh_speaker_cache): + cache_a = get_speaker_cache() + cache_b = get_speaker_cache() + assert cache_a is cache_b + key = SpeakerEmbeddingCache.make_cache_key("carol", model_type="voxcpm2") + cache_a.put(key, {"tag": "from_a"}) + got = cache_b.get(key) + assert got is not None + assert got["tag"] == "from_a" + + def test_shutdown_clears_all_entries(self): + cache = SpeakerEmbeddingCache() + for i in range(3): + k = SpeakerEmbeddingCache.make_cache_key(f"voice{i}", model_type="voxcpm2") + cache.put(k, {"emb": torch.zeros(2)}) + assert cache.stats()["entries"] == 3 + cache.clear() + assert cache.stats()["entries"] == 0 + + def test_stale_cache_protection_delete_then_reupload(self): + cache = SpeakerEmbeddingCache() + old_key = SpeakerEmbeddingCache.make_cache_key("alice", model_type="voxcpm2", created_at=1712000000) + cache.put(old_key, {"emb": torch.ones(4) * 3.14, "gen": "old"}) + + new_key = SpeakerEmbeddingCache.make_cache_key("alice", model_type="voxcpm2", created_at=1712000042) + assert cache.get(new_key) is None + assert cache.get(old_key) is not None + + cache.clear("alice") + assert cache.get(old_key) is None + assert cache.get(new_key) is None diff --git a/tests/test_speaker_metadata_persistence.py b/tests/test_speaker_metadata_persistence.py new file mode 100644 index 00000000000..812c156afa0 --- /dev/null +++ b/tests/test_speaker_metadata_persistence.py @@ -0,0 +1,100 @@ +"""Tests for speaker metadata (created_at, consent, ref_text, ...) round-tripping +through the safetensors header. + +These cover the helpers that back ``_restore_uploaded_speakers`` — the logic +that rebuilds ``uploaded_speakers`` on server start by reading the +``.safetensors`` file written during upload. +""" + +import pytest + +from vllm_omni.entrypoints.openai.serving_speech import OmniOpenAIServingSpeech + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +class TestSpeakerMetadataRoundTrip: + def test_str_only_header(self): + """The safetensors metadata dict must be ``dict[str, str]`` — every value + serialized as a string (ints and None handled).""" + data = { + "name": "Alice", + "voice_name_lower": "alice", + "consent": "xxx", + "created_at": 1712345678, + "sample_rate": 24000, + "embedding_source": "audio", + "ref_text": None, # None values must be dropped + "file_path": "/tmp/should/be/stripped.safetensors", # re-derived, not persisted + } + header = OmniOpenAIServingSpeech._speaker_metadata_to_header(data) + assert all(isinstance(k, str) and isinstance(v, str) for k, v in header.items()) + assert "ref_text" not in header # None stripped + assert "file_path" not in header # never persisted + + def test_int_fields_coerce_back(self): + """Int fields (created_at, file_size, sample_rate, embedding_dim) survive + the string round-trip with their type preserved.""" + data = { + "name": "Bob", + "voice_name_lower": "bob", + "consent": "yyy", + "created_at": 1712345678, + "sample_rate": 44100, + "file_size": 12345, + "embedding_dim": 1024, + "embedding_source": "direct", + } + header = OmniOpenAIServingSpeech._speaker_metadata_to_header(data) + back = OmniOpenAIServingSpeech._speaker_metadata_from_header(header, "/some/path.safetensors") + assert isinstance(back["created_at"], int) + assert back["created_at"] == 1712345678 + assert isinstance(back["sample_rate"], int) + assert back["sample_rate"] == 44100 + assert isinstance(back["file_size"], int) + assert isinstance(back["embedding_dim"], int) + + def test_file_path_reinjected_on_load(self): + """file_path is not persisted in the header; restore derives it from the + actual file location on disk.""" + data = { + "name": "Carol", + "voice_name_lower": "carol", + "consent": "zzz", + "created_at": 1234, + "embedding_source": "audio", + } + header = OmniOpenAIServingSpeech._speaker_metadata_to_header(data) + back = OmniOpenAIServingSpeech._speaker_metadata_from_header(header, "/real/path.safetensors") + assert back["file_path"] == "/real/path.safetensors" + + def test_string_fields_preserved(self): + """ref_text, consent, speaker_description must survive unchanged.""" + data = { + "name": "Dave", + "voice_name_lower": "dave", + "consent": "consent-id-42", + "created_at": 1, + "ref_text": "Hello. This is a transcript with punctuation!", + "speaker_description": "A warm baritone voice.", + "embedding_source": "audio", + } + header = OmniOpenAIServingSpeech._speaker_metadata_to_header(data) + back = OmniOpenAIServingSpeech._speaker_metadata_from_header(header, "/x.safetensors") + assert back["ref_text"] == data["ref_text"] + assert back["speaker_description"] == data["speaker_description"] + assert back["consent"] == data["consent"] + + def test_malformed_int_is_left_as_string(self): + """If an int field somehow contains non-numeric text (manual edit, + corruption), the loader does not crash; it leaves the field as-is.""" + header = { + "name": "Eve", + "voice_name_lower": "eve", + "consent": "x", + "created_at": "not-a-number", + "embedding_source": "audio", + } + back = OmniOpenAIServingSpeech._speaker_metadata_from_header(header, "/p.safetensors") + # Preserved as string rather than raising. + assert back["created_at"] == "not-a-number" diff --git a/tests/test_voice_cache.py b/tests/test_voice_cache.py deleted file mode 100644 index 69327aae571..00000000000 --- a/tests/test_voice_cache.py +++ /dev/null @@ -1,129 +0,0 @@ -import threading - -import pytest - -from vllm_omni.utils.voice_cache import VoiceEmbeddingCache - -pytestmark = [pytest.mark.core_model, pytest.mark.cpu] - - -@pytest.fixture -def cache(): - return VoiceEmbeddingCache(max_entries=4) - - -class TestVoiceEmbeddingCache: - def test_miss_returns_none(self, cache: VoiceEmbeddingCache): - assert cache.get("nonexistent") is None - assert cache.stats()["misses"] == 1 - - def test_put_and_hit(self, cache: VoiceEmbeddingCache): - cache.put("abc", {"val": 42}) - result = cache.get("abc") - assert result is not None - assert result["val"] == 42 - assert cache.stats()["hits"] == 1 - - def test_lru_eviction(self, cache: VoiceEmbeddingCache): - for i in range(5): - cache.put(f"key{i}", {"i": i}) - # key0 should have been evicted (oldest, max_entries=4) - assert cache.get("key0") is None - # key1..key4 should still be present - for i in range(1, 5): - assert cache.get(f"key{i}") is not None - assert cache.stats()["entries"] == 4 - - def test_lru_access_promotes(self, cache: VoiceEmbeddingCache): - cache.put("a", {"v": 1}) - cache.put("b", {"v": 2}) - cache.put("c", {"v": 3}) - cache.put("d", {"v": 4}) - # Access "a" to promote it to MRU - cache.get("a") - # Insert "e" -- should evict "b" (now the oldest), not "a" - cache.put("e", {"v": 5}) - assert cache.get("a") is not None - assert cache.get("b") is None - - def test_put_overwrites(self, cache: VoiceEmbeddingCache): - cache.put("k", {"old": True}) - cache.put("k", {"new": True}) - result = cache.get("k") - assert result is not None - assert "new" in result - assert "old" not in result - assert cache.stats()["entries"] == 1 - - def test_make_cache_key_includes_mode(self): - k1 = VoiceEmbeddingCache.make_cache_key("alice", xvec_only=True) - k2 = VoiceEmbeddingCache.make_cache_key("alice", xvec_only=False) - assert k1 != k2 - assert "xvec" in k1 - assert "icl" in k2 - - def test_make_cache_key_deterministic(self): - k1 = VoiceEmbeddingCache.make_cache_key("bob", xvec_only=True) - k2 = VoiceEmbeddingCache.make_cache_key("bob", xvec_only=True) - assert k1 == k2 - - def test_make_cache_key_created_at_isolation(self): - """Different created_at timestamps must produce different keys (stale-cache protection).""" - k1 = VoiceEmbeddingCache.make_cache_key("alice", xvec_only=False, created_at=1000.0) - k2 = VoiceEmbeddingCache.make_cache_key("alice", xvec_only=False, created_at=2000.0) - assert k1 != k2 - - def test_stale_cache_protection(self, cache: VoiceEmbeddingCache): - """Re-upload (new created_at) must produce a cache miss, not a stale hit.""" - key_old = VoiceEmbeddingCache.make_cache_key("alice", xvec_only=False, created_at=1000.0) - key_new = VoiceEmbeddingCache.make_cache_key("alice", xvec_only=False, created_at=2000.0) - cache.put(key_old, {"ref_spk_embedding": "old_emb"}) - # Re-upload produces a new created_at → different key → cold miss - assert cache.get(key_new) is None - # Old key still in cache (not yet evicted) - assert cache.get(key_old) is not None - - def test_cache_mode_isolation(self, cache: VoiceEmbeddingCache): - """xvec entry must NOT be served for an icl request (same voice).""" - key_xvec = VoiceEmbeddingCache.make_cache_key("alice", xvec_only=True) - key_icl = VoiceEmbeddingCache.make_cache_key("alice", xvec_only=False) - cache.put(key_xvec, {"ref_code": None, "ref_spk_embedding": "emb"}) - # icl request should miss — different key - assert cache.get(key_icl) is None - # xvec request should hit - assert cache.get(key_xvec) is not None - - def test_stats_counters(self, cache: VoiceEmbeddingCache): - cache.put("x", {"v": 1}) - cache.get("x") # hit - cache.get("x") # hit - cache.get("y") # miss - s = cache.stats() - assert s["hits"] == 2 - assert s["misses"] == 1 - assert s["entries"] == 1 - assert s["max_entries"] == 4 - - def test_thread_safety(self): - cache = VoiceEmbeddingCache(max_entries=32) - errors = [] - - def worker(thread_id: int): - try: - for i in range(50): - key = f"t{thread_id}_k{i}" - cache.put(key, {"tid": thread_id, "i": i}) - cache.get(key) - cache.get(f"t{(thread_id + 1) % 10}_k{i}") - except Exception as e: - errors.append(e) - - threads = [threading.Thread(target=worker, args=(t,)) for t in range(10)] - for t in threads: - t.start() - for t in threads: - t.join() - - assert not errors, f"Thread safety errors: {errors}" - s = cache.stats() - assert s["entries"] <= 32 diff --git a/vllm_omni/diffusion/models/omnivoice/pipeline_omnivoice.py b/vllm_omni/diffusion/models/omnivoice/pipeline_omnivoice.py index c330e91de8d..b331aff9c65 100644 --- a/vllm_omni/diffusion/models/omnivoice/pipeline_omnivoice.py +++ b/vllm_omni/diffusion/models/omnivoice/pipeline_omnivoice.py @@ -30,6 +30,7 @@ from vllm_omni.model_executor.models.omnivoice.duration import RuleDurationEstimator from vllm_omni.model_executor.models.omnivoice.omnivoice_decoder import OmniVoiceDecoder from vllm_omni.model_executor.models.omnivoice.omnivoice_generator import OmniVoiceGenerator +from vllm_omni.utils.speaker_cache import get_speaker_cache try: from transformers import HiggsAudioV2TokenizerModel @@ -101,6 +102,9 @@ def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): # Duration estimator self.duration_estimator = RuleDurationEstimator() + # Speaker cache for ref_audio_tokens + self._speaker_cache = get_speaker_cache() + # Generation parameters self.num_step = self.config.num_step self.guidance_scale = self.config.guidance_scale @@ -144,10 +148,12 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: lang = "None" instruct = "None" + voice_name = None if isinstance(prompt, dict): text = prompt.get("input", prompt.get("text", str(prompt))) ref_audio = prompt.get("ref_audio") ref_text = prompt.get("ref_text") + voice_name = prompt.get("voice_name") lang = prompt.get("lang") or "None" instruct = prompt.get("instruct") or "None" else: @@ -175,17 +181,37 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: text_tokens = torch.tensor(encoding.ids, dtype=torch.long, device=device) text_len = text_tokens.shape[0] - # Encode reference audio tokens if provided + # Encode reference audio tokens if provided (with voice caching) ref_audio_tokens = None if ref_audio is not None: if self.audio_tokenizer is None: raise RuntimeError( "Voice cloning requires transformers>=5.3.0. Try: uv pip install 'transformers>=5.3.0'" ) - audio_signal, sr = ref_audio - if isinstance(audio_signal, np.ndarray): - audio_signal = torch.from_numpy(audio_signal).float() - ref_audio_tokens = self._encode_ref_audio(audio_signal, int(sr)).to(device) + # Check speaker cache first + _cache_key = None + if voice_name: + _cache_key = self._speaker_cache.make_cache_key( + voice_name, + model_type="omnivoice", + created_at=int(prompt.get("voice_created_at") or 0), + ) + cached = self._speaker_cache.get(_cache_key) + if cached is not None: + ref_audio_tokens = cached["ref_audio_tokens"].to(device) + _cache_key = None # hit → don't store again + logger.debug("Speaker cache HIT for OmniVoice speaker '%s'", voice_name) + + if ref_audio_tokens is None: + audio_signal, sr = ref_audio + if isinstance(audio_signal, np.ndarray): + audio_signal = torch.from_numpy(audio_signal).float() + ref_audio_tokens = self._encode_ref_audio(audio_signal, int(sr)).to(device) + + # Store in cache for next request + if _cache_key is not None: + self._speaker_cache.put(_cache_key, {"ref_audio_tokens": ref_audio_tokens.cpu()}) + logger.debug("Speaker cache STORE for OmniVoice speaker '%s'", voice_name) # Build conditional + unconditional batches [2, 8, max_len] text_ids = text_tokens.unsqueeze(0).repeat(num_cb, 1) diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py index 2fc55ebb4d0..f6477bdca71 100644 --- a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -54,6 +54,7 @@ create_instruction as ming_create_instruction, ) from vllm_omni.outputs import OmniRequestOutput +from vllm_omni.utils.speaker_cache import get_speaker_cache logger = init_logger(__name__) @@ -159,6 +160,14 @@ def _sanitize_filename(filename: str) -> str: return sanitized +def _validate_speaker_name(name: str) -> str: + """Trim and reject empty / path-separator / NUL / reserved voice names.""" + trimmed = (name or "").strip() + if not trimmed or trimmed in (".", "..") or any(c in trimmed for c in "/\\\x00"): + raise ValueError(f"Invalid voice name {name!r}: must be non-empty, no path separators or NUL") + return trimmed + + def _validate_path_within_directory(file_path: Path, directory: Path) -> bool: """Validate that file_path is within the specified directory. @@ -179,6 +188,104 @@ class OmniOpenAIServingSpeech(OpenAIServing, AudioMixin): _diffusion_mode: bool = False _tts_executor: ThreadPoolExecutor | None = None + def _init_speaker_storage(self) -> None: + """Initialize speaker storage + cache, restoring any persisted uploads.""" + speaker_samples_dir = os.environ.get("SPEAKER_SAMPLES_DIR", os.path.expanduser("~/.cache/vllm-omni/speakers")) + self.uploaded_speakers_dir = Path(speaker_samples_dir).expanduser() + self.uploaded_speakers_dir.mkdir(parents=True, exist_ok=True) + _raw_cap = os.environ.get("SPEAKER_MAX_UPLOADED", "") + try: + self._max_uploaded_speakers = int(_raw_cap) if _raw_cap else 1000 + except ValueError: + logger.warning("Invalid SPEAKER_MAX_UPLOADED=%r; using default 1000", _raw_cap) + self._max_uploaded_speakers = 1000 + self.uploaded_speakers: dict[str, dict] = {} + self.supported_speakers: set[str] = set() + self._ref_audio_data_url_cache: dict[str, str] = {} + self._speaker_cache = get_speaker_cache() + self._last_upload_ts = 0 + self._upload_lock = asyncio.Lock() + self._restore_uploaded_speakers() + logger.info( + "Speaker storage: dir=%s, max_speakers=%d, restored=%d", + self.uploaded_speakers_dir, + self._max_uploaded_speakers, + len(self.uploaded_speakers), + ) + + def _next_upload_timestamp(self) -> int: + ts = max(int(time.time()), self._last_upload_ts + 1) + self._last_upload_ts = ts + return ts + + _META_SCALAR_INT_KEYS: tuple[str, ...] = ( + "created_at", + "file_size", + "sample_rate", + "embedding_dim", + ) + + @classmethod + def _speaker_metadata_to_header(cls, speaker_data: dict[str, Any]) -> dict[str, str]: + """Serialize a speaker_data dict into safetensors' ``dict[str, str]`` header.""" + header: dict[str, str] = {} + for k, v in speaker_data.items(): + if v is None: + continue + # file_path is re-derived from the path on load; don't persist it. + if k == "file_path": + continue + header[k] = str(v) + return header + + @classmethod + def _speaker_metadata_from_header(cls, header: dict[str, str], file_path: str) -> dict[str, Any]: + """Reverse of :meth:`_speaker_metadata_to_header`: coerce ints back and re-inject file_path.""" + data: dict[str, Any] = dict(header) + for k in cls._META_SCALAR_INT_KEYS: + if k in data: + try: + data[k] = int(data[k]) + except ValueError: + logger.warning( + "Speaker metadata %r in %s is not a valid int (got %r); leaving as string", + k, + file_path, + data[k], + ) + data["file_path"] = file_path + return data + + def _restore_uploaded_speakers(self) -> None: + """Scan ``uploaded_speakers_dir`` for safetensors files and rebuild state.""" + try: + from safetensors import safe_open + except ImportError: + logger.warning("safetensors unavailable; uploaded voices will not persist across restarts") + return + + restored = 0 + for path in sorted(self.uploaded_speakers_dir.glob("*.safetensors")): + try: + with safe_open(str(path), framework="pt") as f: + header = dict(f.metadata() or {}) + except Exception as e: + logger.warning("Could not read voice file %s: %s", path, e) + continue + voice_name_lower = header.get("voice_name_lower") or header.get("name", "").lower() + if not voice_name_lower: + logger.warning("Voice file %s has no voice name in metadata; skipping", path) + continue + speaker_data = self._speaker_metadata_from_header(header, str(path)) + speaker_data.setdefault("name", voice_name_lower) + speaker_data.setdefault("file_size", int(path.stat().st_size)) + self.uploaded_speakers[voice_name_lower] = speaker_data + self.supported_speakers.add(voice_name_lower) + self._last_upload_ts = max(self._last_upload_ts, int(speaker_data.get("created_at", 0))) + restored += 1 + if restored: + logger.info("Restored %d uploaded voice(s) from %s", restored, self.uploaded_speakers_dir) + @classmethod def for_diffusion( cls, @@ -196,15 +303,19 @@ def for_diffusion( instance._diffusion_engine = diffusion_engine instance._diffusion_model_name = model_name instance._diffusion_stage_configs = stage_configs + instance._tts_model_type = "omnivoice" + instance._is_tts = False + instance._is_fish_speech = False + # Diffusion-only instances don't have a TTS stage; set None so any + # ``_is_tts_model()`` / ``_tts_stage`` access doesn't raise AttributeError. + instance._tts_stage = None + instance._init_speaker_storage() return instance def __init__(self, *args, **kwargs): self.model_name = kwargs.pop("model_name", None) super().__init__(*args, **kwargs) - # Initialize uploaded speakers storage (ephemeral — cleared on restart) - speech_voice_samples_dir = os.environ.get("SPEECH_VOICE_SAMPLES", "/tmp/voice_samples") - self.uploaded_speakers_dir = Path(speech_voice_samples_dir) - self.uploaded_speakers_dir.mkdir(parents=True, exist_ok=True) + self._init_speaker_storage() # Find and cache the TTS stage (if any) during initialization self._tts_stage = self._find_tts_stage() @@ -220,26 +331,19 @@ def __init__(self, *args, **kwargs): and getattr(getattr(self._tts_stage, "engine_args", None), "model_stage", None) in _COSYVOICE3_TTS_MODEL_STAGES ) - self._cosyvoice3_tokenizer = None - # Determine TTS model type or None self._tts_model_type = self._detect_tts_model_type() # Cache TTS configuration values (computed once, reused per request) self._max_instructions_length = self._compute_max_instructions_length() - # Load supported speakers (built-in only; uploaded voices start empty) - self.supported_speakers = self._load_supported_speakers() - self.uploaded_speakers: dict[str, dict] = {} - logger.warning( - "Uploaded voices are ephemeral and will be lost on server restart. " - "Re-upload voices after each restart if needed." - ) + # Merge built-in speakers into the set initialized by _init_speaker_storage. + self.supported_speakers |= self._load_supported_speakers() self._tts_tokenizer = None self._voxcpm2_tokenizer = None self._voxcpm2_split_map: dict[int, list[int]] = {} - logger.info(f"Loaded {len(self.supported_speakers)} supported speakers: {sorted(self.supported_speakers)}") + logger.info("Loaded %d supported speakers: %s", len(self.supported_speakers), sorted(self.supported_speakers)) # Batch configuration self._batch_max_items: int = getattr(self.engine_client, "tts_batch_max_items", 32) @@ -291,19 +395,21 @@ def _load_codec_frame_rate(self) -> float | None: if output_sr and downsample and downsample > 0: rate = float(output_sr) / float(downsample) logger.info( - f"Loaded codec frame rate: {rate:.1f} Hz " - f"(output_sample_rate={output_sr}, encode_downsample_rate={downsample})" + "Loaded codec frame rate: %.1f Hz (output_sample_rate=%s, encode_downsample_rate=%s)", + rate, + output_sr, + downsample, ) return rate except Exception as e: - logger.warning(f"Failed to load codec frame rate from speech tokenizer config: {e}") + logger.warning("Failed to load codec frame rate from speech tokenizer config: %s", e) # Fallback: try codec_frame_rate_hz from hf_config try: hf_config = self.engine_client.model_config.hf_config rate = getattr(hf_config, "codec_frame_rate_hz", None) if rate is not None: - logger.info(f"Using codec frame rate from hf_config: {rate} Hz") + logger.info("Using codec frame rate from hf_config: %s Hz", rate) return float(rate) except Exception: pass @@ -314,6 +420,8 @@ def shutdown(self) -> None: if self._tts_executor is not None: self._tts_executor.shutdown(wait=False, cancel_futures=True) self._tts_executor = None + for name in list(self.uploaded_speakers.keys()): + self._speaker_cache.clear(name) def _find_tts_stage(self): """Find and return the TTS stage config, or None if not found.""" @@ -400,7 +508,7 @@ def _load_supported_speakers(self) -> set[str]: logger.warning("No speakers found in config (checked spk_id and speaker_id)") except Exception as e: - logger.warning(f"Could not load speakers from model config: {e}") + logger.warning("Could not load speakers from model config: %s", e) return set() @@ -507,8 +615,26 @@ def _estimate_fish_prompt_len(self, text: str, ref_text: str, ref_audio: object) logger.warning("Failed to estimate Fish Speech prompt length, using fallback 2048: %s", e) return 2048 - async def _build_voxcpm2_prompt(self, request: OpenAICreateSpeechRequest) -> dict[str, Any]: - """Build prefill prompt for VoxCPM2 TTS (`prompt_token_ids` padded to full prefill length).""" + def _voice_created_at(self, voice_lower: str) -> int: + """Return the upload timestamp of an uploaded voice, or 0 for built-ins. + + Plumbed through to the model-side cache key so that delete + re-upload + of the same name yields a fresh cache slot. + """ + info = self.uploaded_speakers.get(voice_lower) + return int(info.get("created_at", 0)) if info else 0 + + async def _build_voxcpm2_prompt( + self, + request: OpenAICreateSpeechRequest, + *, + uploaded_ref: tuple[np.ndarray, int] | None = None, + ) -> dict[str, Any]: + """Build prefill prompt for VoxCPM2 TTS (`prompt_token_ids` padded to full prefill length). + + ``uploaded_ref`` supplies the audio for uploaded voices (no explicit + ``ref_audio`` in the request) so prefill length includes it. + """ from vllm_omni.model_executor.models.voxcpm2.voxcpm2_talker import build_voxcpm2_prompt self._voxcpm2_encode("") # lazy-init tokenizer + split_map @@ -516,6 +642,9 @@ async def _build_voxcpm2_prompt(self, request: OpenAICreateSpeechRequest) -> dic ref_sr = None if request.ref_audio is not None: ref_audio, ref_sr = await self._resolve_ref_audio(request.ref_audio) + elif uploaded_ref is not None: + wav_np, ref_sr = uploaded_ref + ref_audio = wav_np.tolist() return build_voxcpm2_prompt( hf_config=self.engine_client.model_config.hf_config, tokenizer=self._voxcpm2_tokenizer, @@ -526,78 +655,157 @@ async def _build_voxcpm2_prompt(self, request: OpenAICreateSpeechRequest) -> dic ref_text=request.ref_text, ) - def _get_uploaded_audio_data(self, voice_name: str) -> str | None: - """Get base64 encoded audio data for uploaded voice.""" + def _load_uploaded_audio(self, voice_name: str) -> tuple[np.ndarray, int] | None: + """Load decoded audio samples + sample rate from an uploaded voice's safetensors.""" voice_name_lower = voice_name.lower() - if voice_name_lower not in self.uploaded_speakers: + info = self.uploaded_speakers.get(voice_name_lower) + if info is None or info.get("embedding_source") != "audio": return None - - speaker_info = self.uploaded_speakers[voice_name_lower] - file_path = Path(speaker_info["file_path"]) - + file_path = Path(info["file_path"]) if not file_path.exists(): - logger.warning(f"Audio file not found for voice {voice_name}: {file_path}") + logger.warning("Voice file not found for %s: %s", voice_name, file_path) + return None + try: + from safetensors import safe_open + except ImportError: + logger.error("The 'safetensors' package is required to load uploaded voices") return None - try: - # Read audio file - with open(file_path, "rb") as f: - audio_bytes = f.read() + with safe_open(str(file_path), framework="pt") as f: + if "audio" not in f.keys(): + return None + samples = f.get_tensor("audio").numpy() + sr = int((f.metadata() or {}).get("sample_rate", info.get("sample_rate", 0))) + except Exception as e: + logger.error("Could not load audio for voice %s: %s", voice_name, e) + return None + if sr <= 0: + return None + return samples, sr - # Encode to base64 - audio_b64 = base64.b64encode(audio_bytes).decode("utf-8") + def _get_uploaded_audio_data(self, voice_name: str) -> str | None: + """Return a base64-encoded WAV data URL for an uploaded voice. - # Get MIME type from file extension - mime_type = speaker_info.get("mime_type", "audio/wav") + Memoized so the WAV re-encode runs once per voice per process. + """ + voice_name_lower = voice_name.lower() + cached = self._ref_audio_data_url_cache.get(voice_name_lower) + if cached is not None: + return cached - # Return as data URL - return f"data:{mime_type};base64,{audio_b64}" + data = self._load_uploaded_audio(voice_name) + if data is None: + return None + samples, sr = data + try: + buf = io.BytesIO() + sf.write(buf, samples, sr, format="WAV") + audio_b64 = base64.b64encode(buf.getvalue()).decode("utf-8") + data_url = f"data:audio/wav;base64,{audio_b64}" except Exception as e: - logger.error(f"Could not read audio file for voice {voice_name}: {e}") + logger.error("Could not encode voice %s as WAV: %s", voice_name, e) return None + self._ref_audio_data_url_cache[voice_name_lower] = data_url + return data_url def _get_uploaded_speaker_embedding(self, voice_name: str) -> list[float] | None: - """Load pre-computed speaker embedding for an uploaded voice. + """Load a pre-computed speaker embedding from an uploaded voice's safetensors. - Returns the embedding as a list of floats, or None if the voice - was not uploaded with an embedding (i.e. it has audio instead). - """ + Returns ``None`` if the voice has audio (not a direct embedding).""" voice_name_lower = voice_name.lower() - if voice_name_lower not in self.uploaded_speakers: - return None - - speaker_info = self.uploaded_speakers[voice_name_lower] - if speaker_info.get("embedding_source") != "direct": + info = self.uploaded_speakers.get(voice_name_lower) + if info is None or info.get("embedding_source") != "direct": return None - - cache_file = speaker_info.get("cache_file") - if not cache_file or not Path(cache_file).exists(): - logger.warning("Embedding file not found for voice %s: %s", voice_name, cache_file) + file_path = Path(info["file_path"]) + if not file_path.exists(): + logger.warning("Embedding file not found for voice %s: %s", voice_name, file_path) return None - - if not _validate_path_within_directory(Path(cache_file), self.uploaded_speakers_dir): - logger.error("Cache file path traversal detected for voice %s: %s", voice_name, cache_file) + if not _validate_path_within_directory(file_path, self.uploaded_speakers_dir): + logger.error("File path traversal detected for voice %s: %s", voice_name, file_path) return None - try: from safetensors.torch import load_file except ImportError: - logger.error( - "The 'safetensors' package is required to load speaker embeddings. " - "Install it with: pip install safetensors" - ) + logger.error("The 'safetensors' package is required to load speaker embeddings") return None - try: - tensors = load_file(cache_file) + tensors = load_file(str(file_path)) if "speaker_embedding" not in tensors: - logger.warning("Key 'speaker_embedding' not found in %s for voice %s", cache_file, voice_name) + logger.warning("Key 'speaker_embedding' missing in %s", file_path) return None return tensors["speaker_embedding"].squeeze().tolist() except Exception as e: logger.error("Could not load embedding for voice %s: %s", voice_name, e) return None + def _apply_uploaded_speaker(self, request: OpenAICreateSpeechRequest) -> str | None: + """Resolve ``request.voice`` against uploaded speakers, mutating + ``request.ref_audio`` / ``request.ref_text`` in place. Returns an + error string if the voice is invalid, else ``None``. + """ + if request.voice is None or request.ref_audio is not None: + return None + + voice_lower = request.voice.lower() + if voice_lower not in self.uploaded_speakers: + if self._tts_model_type in ("cosyvoice3", "fish_tts", "omnivoice", "moss_tts_nano"): + label = { + "cosyvoice3": "CosyVoice3", + "fish_tts": "Fish Speech", + "omnivoice": "OmniVoice", + "moss_tts_nano": "MOSS-TTS-Nano", + }.get(self._tts_model_type, self._tts_model_type) + return ( + f"Unknown voice '{request.voice}'. {label} has no " + f"built-in speakers. Upload a voice first via " + f"POST /v1/audio/voices, or use ref_audio + ref_text." + ) + return None + + speaker_info = self.uploaded_speakers[voice_lower] + if speaker_info.get("embedding_source") == "direct": + return ( + f"Uploaded voice '{request.voice}' uses a speaker embedding " + f"(Qwen3-only). Re-upload with an audio file for this model." + ) + + audio_data = self._get_uploaded_audio_data(request.voice) + if not audio_data: + return f"Audio file for uploaded voice '{request.voice}' is missing" + + request.ref_audio = audio_data + if not request.ref_text or not request.ref_text.strip(): + stored_ref_text = speaker_info.get("ref_text") + if stored_ref_text: + request.ref_text = stored_ref_text + + logger.info("Resolved uploaded voice '%s' for %s", voice_lower, self._tts_model_type) + return None + + def _check_upload_cap(self) -> None: + if len(self.uploaded_speakers) >= self._max_uploaded_speakers: + raise ValueError( + f"Uploaded voice limit reached ({self._max_uploaded_speakers}). " + f"Delete an existing voice before registering a new one, or raise " + f"the cap via SPEAKER_MAX_UPLOADED." + ) + + def _evict_existing_upload(self, voice_name_lower: str, name: str) -> None: + """Drop an existing upload with this name so the caller can re-register it.""" + if voice_name_lower not in self.uploaded_speakers: + return + old = self.uploaded_speakers.pop(voice_name_lower) + self.supported_speakers.discard(voice_name_lower) + self._ref_audio_data_url_cache.pop(voice_name_lower, None) + old_path = old.get("file_path") + if old_path: + try: + Path(old_path).unlink(missing_ok=True) + except Exception as e: + logger.warning("Failed to remove previous file for '%s': %s", name, e) + self._speaker_cache.clear(voice_name_lower) + logger.info("Speaker '%s' re-uploaded; previous cache and file overwritten", name) + async def upload_voice( self, audio_file: UploadFile, @@ -608,6 +816,7 @@ async def upload_voice( speaker_description: str | None = None, ) -> dict: """Upload a new voice sample.""" + name = _validate_speaker_name(name) # Normalize optional strings: treat whitespace-only as absent if ref_text is not None: ref_text = ref_text.strip() or None @@ -659,39 +868,32 @@ async def upload_voice( if mime_type not in allowed_mime_types: raise ValueError(f"Unsupported MIME type: {mime_type}. Allowed: {allowed_mime_types}") - # Normalize voice name - voice_name_lower = name.lower() - - # Check if voice already exists - if voice_name_lower in self.uploaded_speakers: - raise ValueError( - f"Voice '{name}' already exists. To re-register this voice, delete it first and then upload it again." - ) - - # Sanitize name and consent to prevent path traversal - sanitized_name = _sanitize_filename(name) - sanitized_consent = _sanitize_filename(consent) + # Read content before acquiring the lock; decode happens inside. + content = await audio_file.read() - # Generate filename with sanitized inputs - timestamp = int(time.time()) - file_suffix = Path(audio_file.filename).suffix - file_ext = file_suffix[1:] if file_suffix and len(file_suffix) > 1 else "wav" - # Sanitize file extension as well - sanitized_ext = _sanitize_filename(file_ext) - if not sanitized_ext or sanitized_ext == "file": - sanitized_ext = "wav" + async with self._upload_lock: + voice_name_lower = name.lower() + self._evict_existing_upload(voice_name_lower, name) + self._check_upload_cap() - filename = f"{sanitized_name}_{sanitized_consent}_{timestamp}.{sanitized_ext}" - file_path = self.uploaded_speakers_dir / filename + sanitized_name = _sanitize_filename(name) + sanitized_consent = _sanitize_filename(consent) + timestamp = self._next_upload_timestamp() + file_suffix = Path(audio_file.filename).suffix + file_ext = file_suffix[1:] if file_suffix and len(file_suffix) > 1 else "wav" + sanitized_ext = _sanitize_filename(file_ext) + if not sanitized_ext or sanitized_ext == "file": + sanitized_ext = "wav" - # Double-check that the path is within the upload directory - if not _validate_path_within_directory(file_path, self.uploaded_speakers_dir): - raise ValueError("Invalid file path: potential path traversal attack detected") + filename = f"{sanitized_name}_{sanitized_consent}_{timestamp}.safetensors" + file_path = self.uploaded_speakers_dir / filename + if not _validate_path_within_directory(file_path, self.uploaded_speakers_dir): + raise ValueError("Invalid file path: potential path traversal attack detected") - # Read content and validate duration before saving - content = await audio_file.read() - try: - wav_np, sr = sf.read(io.BytesIO(content)) + try: + wav_np, sr = sf.read(io.BytesIO(content)) + except Exception as e: + raise ValueError(f"Could not decode audio file: {e}") duration = len(wav_np) / sr if sr > 0 else 0.0 if duration < _REF_AUDIO_MIN_DURATION: raise ValueError( @@ -703,39 +905,41 @@ async def upload_voice( f"Reference audio too long ({duration:.1f}s). " f"Maximum {_REF_AUDIO_MAX_DURATION:.0f}s supported — use a shorter clip." ) - except ValueError: - raise - except Exception as e: - logger.warning("Could not validate audio duration: %s", e) - # Save audio file - try: - with open(file_path, "wb") as f: - f.write(content) - except Exception as e: - raise ValueError(f"Failed to save audio file: {e}") - - # Create speaker data - speaker_data: dict[str, Any] = { - "name": name, - "consent": consent, - "file_path": str(file_path), - "created_at": timestamp, - "mime_type": mime_type, - "original_filename": audio_file.filename, - "file_size": file_size, - "ref_text": ref_text, - "embedding_source": "audio", - } + speaker_data: dict[str, Any] = { + "name": name, + "voice_name_lower": voice_name_lower, + "consent": consent, + "file_path": str(file_path), + "created_at": timestamp, + "mime_type": mime_type, + "original_filename": audio_file.filename, + "file_size": file_size, + "sample_rate": int(sr), + "ref_text": ref_text, + "embedding_source": "audio", + } + if speaker_description: + speaker_data["speaker_description"] = speaker_description - # Store voice description if provided. - if speaker_description: - speaker_data["speaker_description"] = speaker_description + try: + from safetensors.torch import save_file + except ImportError as exc: + raise ValueError("safetensors is required for voice upload") from exc + try: + audio_tensor = torch.from_numpy(np.asarray(wav_np, dtype=np.float32)).contiguous() + save_file( + {"audio": audio_tensor}, + str(file_path), + metadata=self._speaker_metadata_to_header(speaker_data), + ) + except Exception as e: + raise ValueError(f"Failed to save voice file: {e}") - self.uploaded_speakers[voice_name_lower] = speaker_data - self.supported_speakers.add(voice_name_lower) + self.uploaded_speakers[voice_name_lower] = speaker_data + self.supported_speakers.add(voice_name_lower) - logger.info(f"Uploaded new voice '{name}' with consent ID '{consent}'") + logger.info("Uploaded new voice '%s' with consent ID '%s'", name, consent) # Return voice information without exposing the server file path result = { @@ -765,6 +969,7 @@ async def upload_voice_embedding(self, embedding_json: str, consent: str, name: Returns: dict with voice information. """ + name = _validate_speaker_name(name) try: embedding = json.loads(embedding_json) except (json.JSONDecodeError, TypeError) as exc: @@ -773,6 +978,9 @@ async def upload_voice_embedding(self, embedding_json: str, consent: str, name: if not isinstance(embedding, list) or not embedding: raise ValueError("'speaker_embedding' must be a non-empty list of numbers") + if len(embedding) > 4096: + raise ValueError("'speaker_embedding' exceeds maximum length (4096 elements)") + if not all(isinstance(x, (int, float)) for x in embedding): raise ValueError("'speaker_embedding' must contain only numeric values") @@ -784,46 +992,47 @@ async def upload_voice_embedding(self, embedding_json: str, consent: str, name: if dim_err is not None: raise ValueError(dim_err) - voice_name_lower = name.lower() - if voice_name_lower in self.uploaded_speakers: - raise ValueError(f"Voice '{name}' already exists") + async with self._upload_lock: + voice_name_lower = name.lower() + self._evict_existing_upload(voice_name_lower, name) + self._check_upload_cap() - sanitized_name = _sanitize_filename(name) - sanitized_consent = _sanitize_filename(consent) - timestamp = int(time.time()) - - # Store as safetensors for efficient loading - try: - import torch - from safetensors.torch import save_file + sanitized_name = _sanitize_filename(name) + sanitized_consent = _sanitize_filename(consent) + timestamp = self._next_upload_timestamp() tensor = torch.tensor(embedding, dtype=torch.float32) filename = f"{sanitized_name}_{sanitized_consent}_{timestamp}.safetensors" file_path = self.uploaded_speakers_dir / filename - if not _validate_path_within_directory(file_path, self.uploaded_speakers_dir): raise ValueError("Invalid file path: potential path traversal attack detected") - save_file({"speaker_embedding": tensor}, str(file_path)) - except ImportError: - raise ValueError("safetensors and torch are required for embedding upload") - - speaker_data = { - "name": name, - "consent": consent, - "file_path": str(file_path), - "created_at": timestamp, - "mime_type": "application/x-safetensors", - "original_filename": filename, - "file_size": file_path.stat().st_size, - "embedding_source": "direct", - "embedding_dim": emb_dim, - } + speaker_data: dict[str, Any] = { + "name": name, + "voice_name_lower": voice_name_lower, + "consent": consent, + "file_path": str(file_path), + "created_at": timestamp, + "mime_type": "application/x-safetensors", + "original_filename": filename, + "embedding_source": "direct", + "embedding_dim": emb_dim, + } + try: + from safetensors.torch import save_file + except ImportError as exc: + raise ValueError("safetensors is required for embedding upload") from exc + save_file( + {"speaker_embedding": tensor}, + str(file_path), + metadata=self._speaker_metadata_to_header(speaker_data), + ) + speaker_data["file_size"] = file_path.stat().st_size - self.uploaded_speakers[voice_name_lower] = speaker_data - self.supported_speakers.add(voice_name_lower) + self.uploaded_speakers[voice_name_lower] = speaker_data + self.supported_speakers.add(voice_name_lower) - logger.info(f"Uploaded voice '{name}' from speaker embedding ({emb_dim}-dim)") + logger.info("Uploaded voice '%s' from speaker embedding (%d-dim)", name, emb_dim) return { "name": name, @@ -843,24 +1052,27 @@ async def delete_voice(self, name: str) -> bool: Returns: bool: True if successful, False if voice doesn't exist """ - voice_name_lower = name.lower() + async with self._upload_lock: + voice_name_lower = name.lower() - if voice_name_lower not in self.uploaded_speakers: - logger.warning(f"Voice '{name}' not found") - return False + if voice_name_lower not in self.uploaded_speakers: + logger.warning("Voice '%s' not found", name) + return False - speaker_info = self.uploaded_speakers.pop(voice_name_lower) - self.supported_speakers.discard(voice_name_lower) + speaker_info = self.uploaded_speakers.pop(voice_name_lower) + self.supported_speakers.discard(voice_name_lower) + self._ref_audio_data_url_cache.pop(voice_name_lower, None) - # Clean up audio file on disk - file_path = speaker_info.get("file_path") - if file_path: - try: - Path(file_path).unlink(missing_ok=True) - except Exception as e: - logger.warning(f"Failed to delete audio file for '{name}': {e}") + file_path = speaker_info.get("file_path") + if file_path: + try: + Path(file_path).unlink(missing_ok=True) + except Exception as e: + logger.warning("Failed to delete audio file for '%s': %s", name, e) + + self._speaker_cache.clear(voice_name_lower) - logger.info(f"Deleted voice '{name}'") + logger.info("Deleted voice '%s'", name) return True def _is_tts_model(self) -> bool: @@ -1080,9 +1292,6 @@ def _validate_qwen_tts_request(self, request: OpenAICreateSpeechRequest) -> str "the reference audio) unless 'x_vector_only_mode' is enabled" ) else: - # Handle the case where request.voice is NOT None - pass - # voice is not None voice_lower = request.voice.lower() if voice_lower in self.uploaded_speakers: # Check if data file exists for uploaded speaker @@ -1090,12 +1299,6 @@ def _validate_qwen_tts_request(self, request: OpenAICreateSpeechRequest) -> str file_path = Path(speaker_info["file_path"]) if not file_path.exists(): return f"Data file for uploaded speaker '{request.voice}' not found on disk" - # For embedding-uploaded voices, verify the cache is ready - if speaker_info.get("embedding_source") == "direct": - cache_file = speaker_info.get("cache_file") - if not cache_file or not Path(cache_file).exists(): - status = speaker_info.get("cache_status", "unknown") - return f"Speaker embedding for '{request.voice}' is not yet ready (cache_status='{status}')" else: # need ref_audio for built-in speaker if request.ref_audio is None: @@ -1183,32 +1386,10 @@ async def _build_moss_tts_params(self, request: OpenAICreateSpeechRequest) -> di return params def _validate_fish_tts_request(self, request: OpenAICreateSpeechRequest) -> str | None: - """Validate Fish Speech request parameters. Returns error message or None. - - Side effect: if request.voice references an uploaded speaker, resolves - it to request.ref_audio and request.ref_text for voice cloning. - """ + """Validate Fish Speech request parameters. Returns error message or None.""" if not request.input or not request.input.strip(): return "Input text cannot be empty" - # Support uploaded voices: auto-resolve voice → ref_audio + ref_text. - if request.voice is not None and request.ref_audio is None: - voice_lower = request.voice.lower() - if voice_lower in self.uploaded_speakers: - speaker_info = self.uploaded_speakers[voice_lower] - file_path = Path(speaker_info["file_path"]) - if not file_path.exists(): - return f"Audio file for uploaded voice '{request.voice}' not found on disk" - audio_data_url = self._get_uploaded_audio_data(voice_lower) - if audio_data_url is None: - return f"Could not load audio for uploaded voice '{request.voice}'" - request.ref_audio = audio_data_url - # Use ref_text from upload metadata if not provided in request. - if not request.ref_text or not request.ref_text.strip(): - upload_ref_text = speaker_info.get("ref_text") - if upload_ref_text and upload_ref_text.strip(): - request.ref_text = upload_ref_text - if request.ref_audio is not None: fmt_err = self._validate_ref_audio_format(request.ref_audio) if fmt_err: @@ -1393,9 +1574,9 @@ def _extract_audio_output(res) -> tuple[dict | None, str | None]: ro = getattr(res, "request_output", None) mm = getattr(ro, "multimodal_output", None) if ro else None if not mm: - if ro is None: - ro = getattr(res, "request_output", None) - outputs = getattr(ro, "outputs", None) if ro else None + # MultimodalOutputProcessor attaches mm_accumulated on per-completion outputs. + container = res if hasattr(res, "outputs") else ro + outputs = getattr(container, "outputs", None) if container is not None else None if outputs: for completion_output in outputs: completion_mm = getattr(completion_output, "multimodal_output", None) @@ -1445,6 +1626,7 @@ def _build_tts_params(self, request: OpenAICreateSpeechRequest) -> dict[str, Any # Speaker (voice) if request.voice is not None: params["speaker"] = [request.voice] + params["voice_created_at"] = [self._voice_created_at(request.voice.lower())] # Uploaded voices use task_type="Base" (CustomVoice requires built-in spk_id). # If ref_text was provided at upload time, use in-context cloning; otherwise x_vector only. @@ -1466,7 +1648,6 @@ def _build_tts_params(self, request: OpenAICreateSpeechRequest) -> dict[str, Any stored_ref_text = speaker_info.get("ref_text") params["ref_audio"] = [audio_data] params["task_type"] = ["Base"] - params["voice_created_at"] = [speaker_info.get("created_at", 0)] if stored_ref_text: params["ref_text"] = [stored_ref_text] params["x_vector_only_mode"] = [False] @@ -1521,14 +1702,16 @@ def _build_tts_params(self, request: OpenAICreateSpeechRequest) -> dict[str, Any # ---- Voxtral TTS helpers ---- def _build_voxtral_prompt(self, request: OpenAICreateSpeechRequest) -> dict[str, Any]: - """Build Voxtral TTS engine prompt from shared TTS parameters.""" + """Build Voxtral TTS engine prompt, supporting both preset voices and inline + ``ref_audio`` (base64 or data URI).""" from mistral_common.protocol.speech.request import SpeechRequest text = request.input voice = request.voice ref_audio = request.ref_audio - assert voice or ref_audio, "Either voice or ref_audio must be provided" - # Strip data URI prefix — mistral_common expects raw base64 + if not voice and not ref_audio: + raise ValueError("Voxtral requires either a voice name or ref_audio.") + # mistral_common expects raw base64 (no data: prefix) if ref_audio is not None and isinstance(ref_audio, str) and ref_audio.startswith("data:"): _, _, ref_audio = ref_audio.partition(",") if self._tts_tokenizer is None: @@ -1608,7 +1791,7 @@ def _build_fish_speech_prompt( voice_lower = request.voice.lower() if voice_lower in self.uploaded_speakers: additional_information["voice_name"] = voice_lower - additional_information["voice_created_at"] = self.uploaded_speakers[voice_lower].get("created_at", 0) + additional_information["voice_created_at"] = self._voice_created_at(voice_lower) if request.max_new_tokens is not None: additional_information["max_new_tokens"] = request.max_new_tokens prompt = tokens_input(prompt_token_ids=[1] * ph_len) @@ -1631,17 +1814,64 @@ async def _build_cosyvoice3_prompt( wav_samples, sr = await self._resolve_ref_audio(request.ref_audio) audio_data = (np.asarray(wav_samples, dtype=np.float32), sr) + mm_kwargs: dict[str, Any] = { + "prompt_text": request.ref_text, + "sample_rate": sr, + } + # Pass voice metadata for caching in the processor + if request.voice: + voice_lower = request.voice.lower() + mm_kwargs["voice_name"] = voice_lower + mm_kwargs["voice_created_at"] = self._voice_created_at(voice_lower) + return { "prompt": request.input, "multi_modal_data": { "audio": audio_data, }, - "mm_processor_kwargs": { - "prompt_text": request.ref_text, - "sample_rate": sr, - }, + "mm_processor_kwargs": mm_kwargs, } + def _apply_cosyvoice3_dynamic_tokens( + self, + sampling_params_list: list, + request: OpenAICreateSpeechRequest, + ) -> list: + """Set min/max tokens from tokenized text length (ratios target tokens, not chars).""" + import copy + + from vllm_omni.model_executor.models.cosyvoice3.tokenizer import get_qwen_tokenizer + from vllm_omni.model_executor.models.cosyvoice3.utils import extract_text_token + + sampling_params_list = copy.deepcopy(sampling_params_list) + hf_cfg = self.model_config.hf_config + model_path = self.engine_client.model_config.model + if not os.path.isdir(model_path): + from huggingface_hub import snapshot_download + + model_path = snapshot_download(model_path) + tokenizer = get_qwen_tokenizer( + token_path=os.path.join(model_path, hf_cfg.qwen_pretrain_path), + skip_special_tokens=hf_cfg.skip_special_tokens, + version=hf_cfg.version, + ) + _, text_token_len = extract_text_token( + request.input, + tokenizer, + hf_cfg.allowed_special, + ) + min_ratio = getattr(hf_cfg, "min_token_text_ratio", 2) + max_ratio = getattr(hf_cfg, "max_token_text_ratio", 20) + sampling_params_list[0].min_tokens = max(1, int(text_token_len * min_ratio)) + sampling_params_list[0].max_tokens = min(2048, int(text_token_len * max_ratio)) + logger.info( + "CosyVoice3 dynamic tokens: text_tokens=%d, min_tokens=%d, max_tokens=%d", + text_token_len, + sampling_params_list[0].min_tokens, + sampling_params_list[0].max_tokens, + ) + return sampling_params_list + # ---- Ming-flash-omni standalone-talker (TTS) helpers ---- def _build_ming_prompt(self, request: OpenAICreateSpeechRequest) -> dict[str, Any]: @@ -1704,6 +1934,13 @@ async def _prepare_speech_generation( sampling_params_list = list(self.engine_client.default_sampling_params_list) sampling_params_list = coerce_param_message_types(sampling_params_list, request.stream) + # Resolve uploaded voice for non-Qwen3 models. + # Qwen3 TTS has its own uploaded voice handling in _build_tts_params(). + if self._tts_model_type in ("fish_tts", "cosyvoice3", "moss_tts_nano"): + err = self._apply_uploaded_speaker(request) + if err: + raise ValueError(err) + if self._is_fish_speech: validation_error = self._validate_fish_tts_request(request) if validation_error: @@ -1717,34 +1954,48 @@ async def _prepare_speech_generation( elif self._tts_model_type == "omnivoice": if not request.input or not request.input.strip(): raise ValueError("Input text cannot be empty") + err = self._apply_uploaded_speaker(request) + if err: + raise ValueError(err) tts_params = {} prompt: dict[str, Any] = {"input": request.input} - # Resolve ref_audio: explicit request param or uploaded voice - ref_src = request.ref_audio - if not ref_src and request.voice: - vl = request.voice.lower() - if vl in self.uploaded_speakers: - sp = self.uploaded_speakers[vl] - if sp.get("embedding_source") == "audio": - ref_src = self._get_uploaded_audio_data(request.voice) - if not ref_src: - raise ValueError(f"Audio for voice '{request.voice}' missing") - prompt["ref_text"] = sp.get("ref_text") - if ref_src: - fmt_err = self._validate_ref_audio_format(ref_src) - if fmt_err: - raise ValueError(fmt_err) - wav, sr = await self._resolve_ref_audio(ref_src) + if request.ref_audio: + wav, sr = await self._resolve_ref_audio(request.ref_audio) prompt["ref_audio"] = (np.asarray(wav, dtype=np.float32), sr) if request.ref_text: prompt["ref_text"] = request.ref_text + if request.voice: + voice_lower = request.voice.lower() + prompt["voice_name"] = voice_lower + prompt["voice_created_at"] = self._voice_created_at(voice_lower) if request.language: prompt["lang"] = request.language if request.instructions: prompt["instruct"] = request.instructions elif self._tts_model_type == "voxcpm2": - prompt = await self._build_voxcpm2_prompt(request) + # voxcpm2 doesn't use `_apply_uploaded_speaker` because the prompt builder needs the + # raw waveform tuple for prefill-length accounting, not a base64 data URL. + uploaded_ref: tuple[np.ndarray, int] | None = None + if request.voice: + voice_lower = request.voice.lower() + if voice_lower not in self.uploaded_speakers and voice_lower not in self.supported_speakers: + all_voices = sorted(self.uploaded_speakers.keys() | self.supported_speakers) + raise ValueError(f"Invalid voice '{request.voice}'. Supported: {', '.join(all_voices) or 'none'}") + if voice_lower in self.uploaded_speakers: + if self.uploaded_speakers[voice_lower].get("embedding_source") == "direct": + raise ValueError( + f"Uploaded voice '{request.voice}' uses a speaker embedding (Qwen3-only). " + f"Re-upload with an audio file for VoxCPM2." + ) + if request.ref_audio is None: + uploaded_ref = self._load_uploaded_audio(voice_lower) + prompt = await self._build_voxcpm2_prompt(request, uploaded_ref=uploaded_ref) tts_params = {} + if request.voice: + voice_lower = request.voice.lower() + additional = prompt.setdefault("additional_information", {}) + additional["voice_name"] = voice_lower + additional["voice_created_at"] = self._voice_created_at(voice_lower) elif self._is_tts: validation_error = self._validate_tts_request(request) if validation_error: @@ -1761,6 +2012,10 @@ async def _prepare_speech_generation( tts_params = {} elif self._tts_model_type == "moss_tts_nano": tts_params = await self._build_moss_tts_params(request) + if request.voice: + voice_lower = request.voice.lower() + tts_params["voice_name"] = [voice_lower] + tts_params["voice_created_at"] = [self._voice_created_at(voice_lower)] prompt = tokens_input(prompt_token_ids=[1]) prompt["additional_information"] = tts_params else: @@ -1830,24 +2085,7 @@ async def _prepare_speech_generation( # The official model requires min_token_text_ratio to prevent early # EOS and max_token_text_ratio to cap generation length. if self._tts_model_type == "cosyvoice3" and sampling_params_list: - import copy - - sampling_params_list = copy.deepcopy(sampling_params_list) - text_len = len(request.input) # rough char-level estimate - # Use the model's configured ratios (defaults: min=2, max=20) - hf_cfg = self.model_config.hf_config - min_ratio = getattr(hf_cfg, "min_token_text_ratio", 2) - max_ratio = getattr(hf_cfg, "max_token_text_ratio", 20) - min_tokens = max(1, int(text_len * min_ratio)) - max_tokens = min(2048, int(text_len * max_ratio)) - sampling_params_list[0].min_tokens = min_tokens - sampling_params_list[0].max_tokens = max_tokens - logger.info( - "CosyVoice3 dynamic tokens: text_len=%d, min_tokens=%d, max_tokens=%d", - text_len, - min_tokens, - max_tokens, - ) + sampling_params_list = self._apply_cosyvoice3_dynamic_tokens(sampling_params_list, request) # Apply model-specific extra parameters if request.extra_params is not None and sampling_params_list: @@ -2035,14 +2273,21 @@ async def _create_diffusion_speech( if not request.input or not request.input.strip(): raise ValueError("Input text cannot be empty") - # Validate ref_audio format up-front so that bogus inputs return a - # 4xx instead of falling through to MediaConnector and surfacing as - # a 500 Internal Server Error (e.g. test_voice_clone_invalid_ref_audio_format). if request.ref_audio is not None: fmt_err = self._validate_ref_audio_format(request.ref_audio) if fmt_err: return self._diffusion_error_response(fmt_err, status_code=400) + if request.voice: + voice_lower = request.voice.lower() + if voice_lower not in self.uploaded_speakers and voice_lower not in self.supported_speakers: + all_voices = sorted(self.uploaded_speakers.keys() | self.supported_speakers) + raise ValueError(f"Invalid voice '{request.voice}'. Supported: {', '.join(all_voices) or 'none'}") + + err = self._apply_uploaded_speaker(request) + if err: + raise ValueError(err) + request_id = f"speech-{random_uuid()}" prompt: dict[str, Any] = {"input": request.input} if request.ref_audio: @@ -2050,6 +2295,10 @@ async def _create_diffusion_speech( prompt["ref_audio"] = (np.asarray(wav, dtype=np.float32), sr) if request.ref_text: prompt["ref_text"] = request.ref_text + if request.voice: + voice_lower = request.voice.lower() + prompt["voice_name"] = voice_lower + prompt["voice_created_at"] = self._voice_created_at(voice_lower) if request.language: prompt["lang"] = request.language if request.instructions: @@ -2109,10 +2358,6 @@ async def _create_diffusion_speech( except (EngineGenerateError, EngineDeadError): raise # Propagate to the global Omni exception handler except ValueError as e: - # ValueError represents invalid client input (bad ref_audio URI, - # empty text, model-side validation failures forwarded as ValueError, ...). - # Return 400 to match `create_speech`'s `self.create_error_response(e)` - # default (HTTPStatus.BAD_REQUEST) rather than masking it as a 500. return self._diffusion_error_response(str(e), status_code=400) except Exception as e: logger.exception("Diffusion speech generation failed: %s", e) diff --git a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py index 2fba8fb8af1..9134b292b7d 100644 --- a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py +++ b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py @@ -41,6 +41,7 @@ extract_text_token, ) from vllm_omni.model_executor.models.output_templates import OmniOutput +from vllm_omni.utils.speaker_cache import get_speaker_cache logger = init_logger(__name__) @@ -104,6 +105,7 @@ def _ensure_cached_runtime_components(self, model_dir: str, config: CosyVoice3Co providers=["CPUExecutionProvider"], ) self._cached_model_dir = model_dir + self._speaker_cache = get_speaker_cache() def _call_hf_processor( self, @@ -161,6 +163,28 @@ def _call_hf_processor( ) device = "cpu" + # Speaker cache: skip 3 ONNX sessions on cache hit + voice_name = mm_kwargs.get("voice_name") + cache_key = None + if voice_name and isinstance(voice_name, str): + cache_key = self._speaker_cache.make_cache_key( + voice_name, + model_type="cosyvoice3", + created_at=int(mm_kwargs.get("voice_created_at") or 0), + ) + cached = self._speaker_cache.get(cache_key) + if cached is not None: + ft = BatchFeature( + { + "input_ids": input_ids, + "speech_feat": cached["speech_feat"].clone(), + "speech_token": cached["speech_token"].clone(), + "speech_token_len": [cached["speech_token_len"].clone()], + "embedding": cached["embedding"].clone(), + } + ) + return ft + speech_token, speech_token_len = extract_speech_token(audio, self.speech_tokenizer, device) speech_feat, speech_feat_len = extract_speech_feat(audio, self.feat_extractor, device) @@ -171,6 +195,18 @@ def _call_hf_processor( embedding = extract_spk_embedding(audio, self.campplus_session, device) + # Cache the extracted artifacts for named speakers + if cache_key is not None: + self._speaker_cache.put( + cache_key, + { + "speech_feat": speech_feat.detach().cpu(), + "speech_token": speech_token.detach().cpu(), + "speech_token_len": speech_token_len.detach().cpu(), + "embedding": embedding.detach().cpu(), + }, + ) + ft = BatchFeature( { "input_ids": input_ids, diff --git a/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py b/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py index 6bdb0549b6b..24182b93ea3 100644 --- a/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py +++ b/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py @@ -32,7 +32,7 @@ from vllm.sequence import IntermediateTensors from vllm_omni.model_executor.models.output_templates import OmniOutput -from vllm_omni.utils.voice_cache import VoiceEmbeddingCache +from vllm_omni.utils.speaker_cache import get_speaker_cache from .configuration_fish_speech import FishSpeechConfig, FishSpeechFastARConfig, FishSpeechSlowARConfig from .dac_encoder import _load_dac_codec, encode_reference_audio_codes @@ -262,7 +262,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.register_buffer("_semantic_allowed_mask", semantic_mask, persistent=False) # In-memory LRU cache for DAC-encoded reference audio codes. - self._voice_cache = VoiceEmbeddingCache() + self._speaker_cache = get_speaker_cache() # Tokeniser (lazy). self._tokenizer = None @@ -539,37 +539,27 @@ def _build_structured_voice_clone_prefill_embeds(self, info_dict: dict[str, Any] if not isinstance(ref_text, str) or not isinstance(text, str): raise ValueError("Fish Speech structured voice clone requires string text and ref_text") - # --- Voice cache: reuse DAC codes for uploaded (named) voices --- - _voice_cache_key: str | None = None + _speaker_cache_key: tuple[str, str, int] | None = None voice_name = info_dict.get("voice_name") - voice_created_at = info_dict.get("voice_created_at") if isinstance(voice_name, str) and voice_name: - _created_at = float(voice_created_at) if voice_created_at is not None else 0.0 - if _created_at <= 0: - logger.warning( - "Voice '%s' has no created_at timestamp; DAC code caching disabled for this request", - voice_name, + _speaker_cache_key = self._speaker_cache.make_cache_key( + voice_name, + model_type="fish_speech", + created_at=int(info_dict.get("voice_created_at") or 0), + ) + _cached = self._speaker_cache.get(_speaker_cache_key) + if _cached is not None: + ref_codes_fq = _cached["ref_codes_fq"].to( + device=self.codebook_embeddings.weight.device, + dtype=torch.long, ) - else: - _voice_cache_key = self._voice_cache.make_cache_key( - voice_name, - xvec_only=False, - created_at=_created_at, + logger.debug("Speaker cache HIT for Fish Speech speaker '%s'", voice_name) + return self._apply_codebook_embeddings( + tokenizer, + text, + ref_text, + ref_codes_fq, ) - _cached = self._voice_cache.get(_voice_cache_key) - if _cached is not None: - ref_codes_fq = _cached["ref_codes_fq"].to( - device=self.codebook_embeddings.weight.device, - dtype=torch.long, - ) - _voice_cache_key = None # hit → don't store again - logger.debug("Voice cache HIT for Fish Speech voice '%s'", voice_name) - return self._apply_codebook_embeddings( - tokenizer, - text, - ref_text, - ref_codes_fq, - ) if not isinstance(ref_audio_sr, int): raise ValueError("Fish Speech structured voice clone requires integer ref_audio_sr") @@ -590,12 +580,12 @@ def _build_structured_voice_clone_prefill_embeds(self, info_dict: dict[str, Any] ) # Cache miss: store DAC codes for future reuse. - if _voice_cache_key is not None: - self._voice_cache.put( - _voice_cache_key, + if _speaker_cache_key is not None: + self._speaker_cache.put( + _speaker_cache_key, {"ref_codes_fq": ref_codes_fq.detach().cpu()}, ) - logger.debug("Voice cache STORE for Fish Speech voice '%s'", voice_name) + logger.debug("Speaker cache STORE for Fish Speech speaker '%s'", voice_name) return self._apply_codebook_embeddings(tokenizer, text, ref_text, ref_codes_fq) diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py index a66148c81cb..d9c454bd722 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py @@ -29,7 +29,7 @@ from vllm_omni.data_entry_keys import OmniPayload from vllm_omni.model_executor.models.output_templates import OmniOutput from vllm_omni.utils.audio import mel_filter_bank -from vllm_omni.utils.voice_cache import VoiceEmbeddingCache +from vllm_omni.utils.speaker_cache import get_speaker_cache from .configuration_qwen3_tts import Qwen3TTSConfig, Qwen3TTSSpeakerEncoderConfig, Qwen3TTSTalkerConfig from .qwen3_tts_code_predictor_vllm import Qwen3TTSTalkerCodePredictorForConditionalGenerationVLLM @@ -428,8 +428,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._tokenizer = None self._speech_tokenizer: Qwen3TTSTokenizer | None = None - # In-memory LRU cache for voice extraction artifacts (Base voice clone). - self._voice_cache = VoiceEmbeddingCache() + self._speaker_cache = get_speaker_cache() raw_subtalker_sampling = getattr(vllm_config.model_config, "subtalker_sampling_params", None) self._subtalker_sampling_params: dict[str, Any] = ( dict(raw_subtalker_sampling) if isinstance(raw_subtalker_sampling, Mapping) else {} @@ -1364,23 +1363,36 @@ def _normalize_voice_clone_prompt(raw: object) -> dict[str, object] | None: in_context_mode = not xvec_only voice_clone_prompt = _normalize_voice_clone_prompt(info_dict.get("voice_clone_prompt")) - # Voice cache: only for uploaded voices (created_at > 0) - _voice_cache_key = None + # Speaker cache: only for uploaded (named) speakers + _speaker_cache_key = None if voice_clone_prompt is None: _speaker_list = info_dict.get("speaker") if isinstance(_speaker_list, list) and _speaker_list: _voice_name = str(_speaker_list[0]).lower() - _voice_created_at = float((info_dict.get("voice_created_at") or [0])[0]) - if _voice_created_at > 0: - _voice_cache_key = self._voice_cache.make_cache_key(_voice_name, xvec_only, _voice_created_at) - _cached = self._voice_cache.get(_voice_cache_key) if _voice_cache_key is not None else None + # Per-mode namespace — xvec and icl produce different artifacts + # for the same voice, so they must not share a cache slot. + _mode = "xvec" if xvec_only else "icl" + _voice_created_at = int((info_dict.get("voice_created_at") or [0])[0]) + _speaker_cache_key = self._speaker_cache.make_cache_key( + _voice_name, + model_type=f"qwen3_tts_{_mode}", + created_at=_voice_created_at, + ) + _cached = self._speaker_cache.get(_speaker_cache_key) if _cached is not None: + # Transfer cached tensors to current device + ref_code_cached = _cached.get("ref_code") + ref_spk_embed_cached = _cached.get("ref_spk_embedding") + if isinstance(ref_code_cached, torch.Tensor): + ref_code_cached = ref_code_cached.to(device=input_ids.device) + if isinstance(ref_spk_embed_cached, torch.Tensor): + ref_spk_embed_cached = ref_spk_embed_cached.to(device=input_ids.device) voice_clone_prompt = { - "ref_code": _cached.get("ref_code"), - "ref_spk_embedding": _cached.get("ref_spk_embedding"), + "ref_code": ref_code_cached, + "ref_spk_embedding": ref_spk_embed_cached, "icl_mode": _cached.get("icl_mode"), } - _voice_cache_key = None # hit -> don't store again + _speaker_cache_key = None # hit → don't store again # Official implementation may pass `voice_clone_prompt.icl_mode`. if voice_clone_prompt is not None and "icl_mode" in voice_clone_prompt: @@ -1432,9 +1444,9 @@ def _normalize_voice_clone_prompt(raw: object) -> dict[str, object] | None: speaker_embed = self._extract_speaker_embedding(wav_np, sr).view(1, 1, -1) # Cache miss: store extraction result - if _voice_cache_key is not None and speaker_embed is not None: - self._voice_cache.put( - _voice_cache_key, + if _speaker_cache_key is not None and speaker_embed is not None: + self._speaker_cache.put( + _speaker_cache_key, { "ref_code": ref_code_prompt.detach().cpu() if isinstance(ref_code_prompt, torch.Tensor) diff --git a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py index 05f7df9c9d6..0da965d9952 100644 --- a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py +++ b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py @@ -34,6 +34,7 @@ from vllm.sequence import IntermediateTensors from vllm_omni.model_executor.models.output_templates import OmniOutput +from vllm_omni.utils.speaker_cache import get_speaker_cache from .minicpm4_paged import MiniCPM4PagedForVoxCPM2, MiniCPM4PagedResidualLM from .voxcpm2_import_utils import import_voxcpm2_core @@ -446,6 +447,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._max_decode_steps = 2000 self._max_batch_size = getattr(vllm_config.scheduler_config, "max_num_seqs", 4) + # Speaker cache for ref_audio_feat across requests + self._speaker_cache = get_speaker_cache() + self._perf = _PerfTimer(enabled=_ENABLE_PROFILING) self._cfm_buffers: _CFMBufferManager | None = None self._enable_cuda_graph = True @@ -1165,15 +1169,48 @@ def preprocess( prompt_text = prompt_text[0] if prompt_text else None state.prompt_cache = None + voice_name = info_dict.get("voice_name") + if isinstance(voice_name, list): + voice_name = voice_name[0] if voice_name else None + _created_at = int(info_dict.get("voice_created_at") or 0) + if ref_audio or (prompt_audio and prompt_text): - try: - state.prompt_cache = self._build_prompt_cache( - ref_audio=ref_audio, - prompt_audio=prompt_audio, - prompt_text=prompt_text, + # Check speaker cache for reference-only mode + if voice_name and ref_audio and not prompt_audio: + _cache_key = self._speaker_cache.make_cache_key( + voice_name, model_type="voxcpm2", created_at=_created_at ) - except Exception as e: - logger.warning("build_prompt_cache failed: %s", e) + cached = self._speaker_cache.get(_cache_key) + if cached is not None: + state.prompt_cache = { + "mode": "reference", + "ref_audio_feat": cached["ref_audio_feat"].clone(), + } + logger.debug("Speaker cache HIT for VoxCPM2 speaker '%s'", voice_name) + + if state.prompt_cache is None: + try: + state.prompt_cache = self._build_prompt_cache( + ref_audio=ref_audio, + prompt_audio=prompt_audio, + prompt_text=prompt_text, + ) + if ( + voice_name + and state.prompt_cache is not None + and state.prompt_cache.get("mode") == "reference" + and "ref_audio_feat" in state.prompt_cache + ): + _key = self._speaker_cache.make_cache_key( + voice_name, model_type="voxcpm2", created_at=_created_at + ) + self._speaker_cache.put( + _key, {"ref_audio_feat": state.prompt_cache["ref_audio_feat"].cpu()} + ) + logger.debug("Speaker cache STORE for VoxCPM2 speaker '%s'", voice_name) + except Exception as e: + logger.warning("build_prompt_cache failed: %s; falling back to zero-shot", e) + state.prompt_cache = None inputs = self._build_prefill_inputs(token_ids, dev, req_id) tts = self.tts diff --git a/vllm_omni/utils/speaker_cache.py b/vllm_omni/utils/speaker_cache.py new file mode 100644 index 00000000000..80623d89431 --- /dev/null +++ b/vllm_omni/utils/speaker_cache.py @@ -0,0 +1,136 @@ +"""Process-wide thread-safe LRU cache for speaker extraction artifacts. + +Keyed by ``(model_type, speaker_name, created_at)`` so each upload generation +has its own slot. Access via :func:`get_speaker_cache`. +""" + +from __future__ import annotations + +import threading +from collections import OrderedDict +from typing import Any + +import torch +from vllm.logger import init_logger + +logger = init_logger(__name__) + +_MAX_BYTES = 512 * 1024**2 # 512 MiB + +_SINGLETON: SpeakerEmbeddingCache | None = None +_SINGLETON_LOCK = threading.Lock() + + +def _estimate_tensor_bytes(obj: object) -> int: + if isinstance(obj, torch.Tensor): + return obj.numel() * obj.element_size() + if isinstance(obj, dict): + return sum(_estimate_tensor_bytes(v) for v in obj.values()) + if isinstance(obj, (list, tuple)): + return sum(_estimate_tensor_bytes(item) for item in obj) + return 0 + + +class SpeakerEmbeddingCache: + """Thread-safe in-memory LRU cache for speaker extraction artifacts.""" + + def __init__(self, *, max_bytes: int = _MAX_BYTES): + self._cache: OrderedDict[tuple[str, str, int], dict[str, Any]] = OrderedDict() + self._sizes: dict[tuple[str, str, int], int] = {} + self._total_bytes = 0 + self._lock = threading.Lock() + self._hits = 0 + self._misses = 0 + self._max_bytes = max_bytes + logger.info("Speaker cache ready (max_bytes=%d)", self._max_bytes) + + @staticmethod + def make_cache_key(speaker_name: str, model_type: str, created_at: int = 0) -> tuple[str, str, int]: + """Build a cache key. ``created_at=0`` for built-in speakers (no upload). + + Names are normalized (stripped + lowercased) so delete/clear paths that + normalize to lowercase match entries put with mixed-case names. + """ + if not speaker_name or not speaker_name.strip(): + raise ValueError("speaker_name is required") + if not model_type: + raise ValueError("model_type is required") + return (model_type, speaker_name.strip().lower(), int(created_at)) + + def get(self, key: tuple[str, str, int]) -> dict[str, Any] | None: + with self._lock: + if key in self._cache: + self._cache.move_to_end(key) + self._hits += 1 + return self._cache[key] + self._misses += 1 + return None + + def put(self, key: tuple[str, str, int], artifacts: dict[str, Any]) -> None: + with self._lock: + self._insert_locked(key, artifacts) + + def _insert_locked(self, key: tuple[str, str, int], artifacts: dict[str, Any]) -> None: + size = _estimate_tensor_bytes(artifacts) + if size > self._max_bytes: + logger.warning("Speaker cache skip: entry %s size=%dB exceeds max_bytes=%dB", key, size, self._max_bytes) + return + if key in self._cache: + self._total_bytes -= self._sizes.pop(key, 0) + del self._cache[key] + self._cache[key] = artifacts + self._sizes[key] = size + self._total_bytes += size + self._cache.move_to_end(key) + while self._cache and self._total_bytes > self._max_bytes: + evict_key, _ = self._cache.popitem(last=False) + self._total_bytes -= self._sizes.pop(evict_key, 0) + logger.debug("Speaker cache EVICT: key=%s", evict_key) + + def clear(self, speaker_name: str | None = None) -> int: + """Remove entries. With a name, drops matches across model types and generations.""" + with self._lock: + if speaker_name is None: + removed = len(self._cache) + self._cache.clear() + self._sizes.clear() + self._total_bytes = 0 + self._hits = 0 + self._misses = 0 + return removed + + if not speaker_name or not speaker_name.strip(): + raise ValueError("speaker_name cannot be an empty string") + normalized = speaker_name.strip().lower() + removed = 0 + for k in list(self._cache.keys()): + if isinstance(k, tuple) and len(k) >= 2 and k[1] == normalized: + self._total_bytes -= self._sizes.pop(k, 0) + del self._cache[k] + removed += 1 + return removed + + def memory_bytes(self) -> int: + with self._lock: + return self._total_bytes + + def stats(self) -> dict[str, Any]: + with self._lock: + return { + "entries": len(self._cache), + "memory_bytes": self._total_bytes, + "max_bytes": self._max_bytes, + "memory_mb": round(self._total_bytes / (1024 * 1024), 2), + "hits": self._hits, + "misses": self._misses, + } + + +def get_speaker_cache() -> SpeakerEmbeddingCache: + """Return the process-wide speaker cache singleton.""" + global _SINGLETON + if _SINGLETON is None: + with _SINGLETON_LOCK: + if _SINGLETON is None: + _SINGLETON = SpeakerEmbeddingCache() + return _SINGLETON diff --git a/vllm_omni/utils/voice_cache.py b/vllm_omni/utils/voice_cache.py deleted file mode 100644 index 2d78a5bfdb9..00000000000 --- a/vllm_omni/utils/voice_cache.py +++ /dev/null @@ -1,89 +0,0 @@ -"""In-memory LRU cache for voice extraction artifacts. - -Keyed by voice name + extraction mode (e.g. ``"alice:icl"``). -Only named voices are cached; inline ``ref_audio`` without a voice -name is not cached. - -Usage:: - - key = VoiceEmbeddingCache.make_cache_key("alice", xvec_only=False) - cached = cache.get(key) - if cached is None: - # ... extract ... - cache.put(key, {"artifact": result}) -""" - -import os -import threading -from collections import OrderedDict -from typing import Any - -from vllm.logger import init_logger - -logger = init_logger(__name__) - -_DEFAULT_MAX_ENTRIES = 128 - - -class VoiceEmbeddingCache: - """LRU cache for voice extraction outputs. - - Each entry stores a ``dict[str, Any]`` whose contents are model-specific. - Thread-safe via a lightweight ``threading.Lock``. - """ - - def __init__(self, max_entries: int | None = None): - if max_entries is None: - max_entries = int(os.environ.get("VOICE_CACHE_MAX_ENTRIES", _DEFAULT_MAX_ENTRIES)) - self._cache: OrderedDict[str, dict[str, Any]] = OrderedDict() - self._max_entries = max_entries - self._lock = threading.Lock() - self._hits = 0 - self._misses = 0 - logger.info("Voice embedding cache initialized (max_entries=%d)", max_entries) - - @staticmethod - def make_cache_key(voice_name: str, xvec_only: bool, created_at: float = 0.0) -> str: - """Build a cache key from a voice name, upload timestamp, and extraction mode. - - Args: - voice_name: The speaker/voice name (case-insensitive, lowered - by the caller). - xvec_only: True for speaker-embedding-only mode, False for - ICL mode (speaker embedding + ref_code). - created_at: Upload timestamp from metadata. Prevents stale cache - hits after a voice is deleted and re-uploaded with the same - name but different audio. - """ - mode = "xvec" if xvec_only else "icl" - return f"{voice_name}:{created_at:.6f}:{mode}" - - def get(self, key: str) -> dict[str, Any] | None: - """Return cached artifacts or ``None`` on miss. Promotes to MRU on hit.""" - with self._lock: - if key in self._cache: - self._cache.move_to_end(key) - self._hits += 1 - logger.debug("Voice cache HIT (key=%s, hits=%d)", key, self._hits) - return self._cache[key] - self._misses += 1 - return None - - def put(self, key: str, artifacts: dict[str, Any]) -> None: - """Store *artifacts* under *key*, evicting the LRU entry if full.""" - with self._lock: - self._cache[key] = artifacts - self._cache.move_to_end(key) - while len(self._cache) > self._max_entries: - evicted_key, _ = self._cache.popitem(last=False) - logger.debug("Voice cache EVICT (key=%s)", evicted_key) - - def stats(self) -> dict[str, int]: - """Return cache statistics.""" - with self._lock: - return { - "entries": len(self._cache), - "max_entries": self._max_entries, - "hits": self._hits, - "misses": self._misses, - }