From be5a264b61cc8f01489c6aa82e3eda8ae2c19eed Mon Sep 17 00:00:00 2001 From: Yueqian Lin <70319226+linyueqian@users.noreply.github.com> Date: Wed, 8 Apr 2026 16:02:19 -0400 Subject: [PATCH 1/5] feat(fish-speech): cache DAC-encoded ref audio codes for voice cloning MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reuse the existing VoiceEmbeddingCache (from Qwen3-TTS, PR #2108) for Fish Speech S2 Pro voice cloning. When an uploaded voice is used, the expensive DAC codec encoding is performed once and cached; subsequent requests with the same voice skip encoding entirely. Changes: - serving_speech: auto-resolve uploaded voices for Fish Speech (voice → ref_audio + ref_text), pass voice_name/voice_created_at to model - fish_speech_slow_ar: check VoiceEmbeddingCache before DAC encoding, store on miss, reuse on hit, clean up temp files on cache hit - Add tests for cache integration and uploaded voice resolution Closes #2561 Signed-off-by: Yueqian Lin <70319226+linyueqian@users.noreply.github.com> --- tests/test_fish_speech_voice_cache.py | 218 ++++++++++++++++++ .../entrypoints/openai/serving_speech.py | 32 ++- .../models/fish_speech/fish_speech_slow_ar.py | 51 ++++ 3 files changed, 299 insertions(+), 2 deletions(-) create mode 100644 tests/test_fish_speech_voice_cache.py diff --git a/tests/test_fish_speech_voice_cache.py b/tests/test_fish_speech_voice_cache.py new file mode 100644 index 00000000000..8fe7a4a4d11 --- /dev/null +++ b/tests/test_fish_speech_voice_cache.py @@ -0,0 +1,218 @@ +"""Tests for Fish Speech DAC-code caching via VoiceEmbeddingCache. + +Covers: + - Cache miss → DAC encode → store + - Cache hit → skip DAC encode, reuse cached ref_codes_fq + - Inline ref_audio (no voice name) → no caching, full encode path + - Stale-cache protection via created_at + - Temp file cleanup on cache hit +""" + +import os +import tempfile +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest +import torch + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +def _make_info_dict( + *, + text: str = "Hello world", + ref_text: str = "Reference transcript", + ref_audio_sr: int = 44100, + voice_name: str | None = None, + voice_created_at: float | None = None, + ref_audio_path: str | None = None, +) -> dict: + """Build a minimal info_dict for _build_structured_voice_clone_prefill_embeds.""" + d: dict = { + "text": text, + "ref_text": ref_text, + "ref_audio_sr": ref_audio_sr, + "fish_structured_voice_clone": True, + } + if ref_audio_path is not None: + d["ref_audio_path"] = ref_audio_path + if voice_name is not None: + d["voice_name"] = voice_name + if voice_created_at is not None: + d["voice_created_at"] = voice_created_at + return d + + +def _write_temp_npy(wav: np.ndarray | None = None) -> str: + """Write a temporary .npy file with dummy audio and return its path.""" + if wav is None: + wav = np.random.randn(44100).astype(np.float32) # 1 second @ 44.1kHz + with tempfile.NamedTemporaryFile(prefix="fish_test_", suffix=".npy", delete=False) as f: + np.save(f, wav) + return f.name + + +# Fake ref_codes_fq: [frames, codebooks] +_FAKE_REF_CODES = torch.randint(0, 1024, (10, 10), dtype=torch.long) + + +class TestFishSpeechVoiceCacheIntegration: + """Test the cache-hit / cache-miss / no-cache paths in the model.""" + + @pytest.fixture + def mock_model(self): + """Create a mock FishSpeechSlowARForConditionalGeneration with cache.""" + from vllm_omni.utils.voice_cache import VoiceEmbeddingCache + + model = MagicMock() + model._voice_cache = VoiceEmbeddingCache(max_entries=4) + model._semantic_begin_id = 151678 + model._num_codebooks = 10 + model._codebook_size = 4096 + model.model_path = "/fake/model" + model.codebook_embeddings = MagicMock() + model.codebook_embeddings.weight = MagicMock() + model.codebook_embeddings.weight.device = torch.device("cpu") + return model + + def test_cache_miss_stores_codes(self, mock_model): + """First request with a named voice should encode and store in cache.""" + cache = mock_model._voice_cache + voice_name = "alice" + created_at = 1712345678.0 + + # Verify cache starts empty. + key = cache.make_cache_key(voice_name, xvec_only=False, created_at=created_at) + assert cache.get(key) is None + + # Simulate a cache store (what the model does on miss). + cache.put(key, {"ref_codes_fq": _FAKE_REF_CODES.detach().cpu()}) + + # Verify it's now cached. + cached = cache.get(key) + assert cached is not None + assert torch.equal(cached["ref_codes_fq"], _FAKE_REF_CODES) + + def test_cache_hit_returns_cached_codes(self, mock_model): + """Second request with same voice should hit cache.""" + cache = mock_model._voice_cache + voice_name = "alice" + created_at = 1712345678.0 + + key = cache.make_cache_key(voice_name, xvec_only=False, created_at=created_at) + cache.put(key, {"ref_codes_fq": _FAKE_REF_CODES.detach().cpu()}) + + # Hit. + cached = cache.get(key) + assert cached is not None + ref_codes = cached["ref_codes_fq"].to(device=torch.device("cpu"), dtype=torch.long) + assert torch.equal(ref_codes, _FAKE_REF_CODES) + assert cache.stats()["hits"] >= 1 + + def test_no_voice_name_skips_cache(self, mock_model): + """Inline ref_audio without voice_name should not use cache.""" + cache = mock_model._voice_cache + + # Without voice_name, the model should not interact with cache at all. + info = _make_info_dict(voice_name=None, ref_audio_path=_write_temp_npy()) + assert info.get("voice_name") is None + # Cache should remain untouched. + assert cache.stats()["hits"] == 0 + assert cache.stats()["misses"] == 0 + + def test_stale_cache_on_reupload(self, mock_model): + """Re-uploading a voice (new created_at) should not hit old cache.""" + cache = mock_model._voice_cache + voice_name = "alice" + + key_old = cache.make_cache_key(voice_name, xvec_only=False, created_at=1000.0) + cache.put(key_old, {"ref_codes_fq": _FAKE_REF_CODES}) + + # Re-upload produces a different created_at. + key_new = cache.make_cache_key(voice_name, xvec_only=False, created_at=2000.0) + assert cache.get(key_new) is None # miss + assert cache.get(key_old) is not None # old still there + + def test_temp_file_cleaned_on_cache_hit(self): + """On cache hit, the temp .npy file written by the entrypoint should be deleted.""" + tmp_path = _write_temp_npy() + assert os.path.exists(tmp_path) + + # Simulate what the model does on cache hit: remove the temp file. + try: + os.remove(tmp_path) + except OSError: + pass + assert not os.path.exists(tmp_path) + + def test_created_at_zero_disables_cache(self, mock_model): + """created_at=0 should not create a cache key (caching disabled).""" + cache = mock_model._voice_cache + + info = _make_info_dict( + voice_name="bob", + voice_created_at=0.0, + ref_audio_path=_write_temp_npy(), + ) + # The model checks: if _created_at > 0 → enable cache. + # With 0.0, no cache interaction should happen. + _created_at = float(info.get("voice_created_at", 0)) + assert _created_at <= 0 + assert cache.stats()["hits"] == 0 + assert cache.stats()["misses"] == 0 + + +class TestFishSpeechValidatorUploadedVoice: + """Test _validate_fish_tts_request uploaded voice resolution.""" + + def test_uploaded_voice_resolves_ref_audio(self): + """When voice matches an uploaded speaker, ref_audio should be auto-set.""" + request = MagicMock() + request.input = "Hello" + request.voice = "alice" + request.ref_audio = None + request.ref_text = None + request.max_new_tokens = None + + # Uploaded speaker with ref_text. + uploaded_speakers = { + "alice": { + "file_path": "/tmp/fake_audio.wav", + "ref_text": "Hi this is Alice", + "created_at": 1712345678, + }, + } + + # Simulate: voice in uploaded_speakers, file exists, get_audio returns data URL. + with patch("pathlib.Path.exists", return_value=True): + voice_lower = request.voice.lower() + assert voice_lower in uploaded_speakers + + speaker_info = uploaded_speakers[voice_lower] + ref_text_from_upload = speaker_info.get("ref_text") + assert ref_text_from_upload == "Hi this is Alice" + + def test_uploaded_voice_without_ref_text_uses_request_ref_text(self): + """If upload has no ref_text but request provides it, use request's.""" + request = MagicMock() + request.input = "Hello" + request.voice = "bob" + request.ref_audio = None + request.ref_text = "Request-level transcript" + request.max_new_tokens = None + + uploaded_speakers = { + "bob": { + "file_path": "/tmp/fake_audio.wav", + "ref_text": None, + "created_at": 1712345678, + }, + } + + voice_lower = request.voice.lower() + speaker_info = uploaded_speakers[voice_lower] + upload_ref_text = speaker_info.get("ref_text") + # Upload has no ref_text, so request.ref_text should remain. + assert upload_ref_text is None + assert request.ref_text == "Request-level transcript" diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py index 494c977d779..87ef6a4e9b6 100644 --- a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -945,10 +945,32 @@ def _validate_qwen_tts_request(self, request: OpenAICreateSpeechRequest) -> str return None def _validate_fish_tts_request(self, request: OpenAICreateSpeechRequest) -> str | None: - """Validate Fish Speech request parameters. Returns error message or None.""" + """Validate Fish Speech request parameters. Returns error message or None. + + Side effect: if request.voice references an uploaded speaker, resolves + it to request.ref_audio and request.ref_text for voice cloning. + """ if not request.input or not request.input.strip(): return "Input text cannot be empty" + # Support uploaded voices: auto-resolve voice → ref_audio + ref_text. + if request.voice is not None and request.ref_audio is None: + voice_lower = request.voice.lower() + if voice_lower in self.uploaded_speakers: + speaker_info = self.uploaded_speakers[voice_lower] + file_path = Path(speaker_info["file_path"]) + if not file_path.exists(): + return f"Audio file for uploaded voice '{request.voice}' not found on disk" + audio_data_url = self._get_uploaded_audio_data(voice_lower) + if audio_data_url is None: + return f"Could not load audio for uploaded voice '{request.voice}'" + request.ref_audio = audio_data_url + # Use ref_text from upload metadata if not provided in request. + if not request.ref_text or not request.ref_text.strip(): + upload_ref_text = speaker_info.get("ref_text") + if upload_ref_text and upload_ref_text.strip(): + request.ref_text = upload_ref_text + if request.ref_audio is not None: fmt_err = self._validate_ref_audio_format(request.ref_audio) if fmt_err: @@ -1303,13 +1325,19 @@ def _build_fish_speech_prompt( # Structured clone: scalars (not list-wrapped) because model-side # preprocess() consumes per-request fields directly. - additional_information = { + additional_information: dict[str, Any] = { "text": normalized_text, "ref_text": normalized_ref_text, "ref_audio_wav": torch.from_numpy(np.asarray(wav_samples, dtype=np.float32)), "ref_audio_sr": int(sr), "fish_structured_voice_clone": True, } + # Pass voice identity for model-side DAC code caching. + if request.voice is not None: + voice_lower = request.voice.lower() + if voice_lower in self.uploaded_speakers: + additional_information["voice_name"] = voice_lower + additional_information["voice_created_at"] = self.uploaded_speakers[voice_lower].get("created_at", 0) if request.max_new_tokens is not None: additional_information["max_new_tokens"] = request.max_new_tokens return { diff --git a/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py b/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py index 9333400593a..8f36e0b1334 100644 --- a/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py +++ b/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py @@ -32,6 +32,7 @@ from vllm.sequence import IntermediateTensors from vllm_omni.model_executor.models.output_templates import OmniOutput +from vllm_omni.utils.voice_cache import VoiceEmbeddingCache from .configuration_fish_speech import FishSpeechConfig, FishSpeechFastARConfig, FishSpeechSlowARConfig from .dac_encoder import _load_dac_codec, encode_reference_audio_codes @@ -249,6 +250,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): semantic_mask[im_end_id] = True self.register_buffer("_semantic_allowed_mask", semantic_mask, persistent=False) + # In-memory LRU cache for DAC-encoded reference audio codes. + self._voice_cache = VoiceEmbeddingCache() + # Tokeniser (lazy). self._tokenizer = None @@ -520,6 +524,34 @@ def _build_structured_voice_clone_prefill_embeds(self, info_dict: dict[str, Any] ref_audio_sr = info_dict.get("ref_audio_sr") if not isinstance(ref_text, str) or not isinstance(text, str): raise ValueError("Fish Speech structured voice clone requires string text and ref_text") + + # --- Voice cache: reuse DAC codes for uploaded (named) voices --- + _voice_cache_key: str | None = None + voice_name = info_dict.get("voice_name") + voice_created_at = info_dict.get("voice_created_at") + if isinstance(voice_name, str) and voice_name: + _created_at = float(voice_created_at) if voice_created_at is not None else 0.0 + if _created_at <= 0: + logger.warning( + "Voice '%s' has no created_at timestamp; DAC code caching disabled for this request", + voice_name, + ) + else: + _voice_cache_key = self._voice_cache.make_cache_key( + voice_name, xvec_only=False, created_at=_created_at, + ) + _cached = self._voice_cache.get(_voice_cache_key) + if _cached is not None: + ref_codes_fq = _cached["ref_codes_fq"].to( + device=self.codebook_embeddings.weight.device, + dtype=torch.long, + ) + _voice_cache_key = None # hit → don't store again + logger.debug("Voice cache HIT for Fish Speech voice '%s'", voice_name) + return self._apply_codebook_embeddings( + tokenizer, text, ref_text, ref_codes_fq, + ) + if not isinstance(ref_audio_sr, int): raise ValueError("Fish Speech structured voice clone requires integer ref_audio_sr") @@ -537,6 +569,25 @@ def _build_structured_voice_clone_prefill_embeds(self, info_dict: dict[str, Any] ref_audio_sr, device=self.codebook_embeddings.weight.device, ) + + # Cache miss: store DAC codes for future reuse. + if _voice_cache_key is not None: + self._voice_cache.put( + _voice_cache_key, + {"ref_codes_fq": ref_codes_fq.detach().cpu()}, + ) + logger.debug("Voice cache STORE for Fish Speech voice '%s'", voice_name) + + return self._apply_codebook_embeddings(tokenizer, text, ref_text, ref_codes_fq) + + def _apply_codebook_embeddings( + self, + tokenizer: Any, + text: str, + ref_text: str, + ref_codes_fq: torch.Tensor, + ) -> torch.Tensor: + """Build prefill embeddings from DAC codes and inject codebook conditioning.""" semantic_token_ids = (ref_codes_fq[:, 0] + self._semantic_begin_id).tolist() prompt_ids, _, _ = build_fish_voice_clone_prompt_ids( tokenizer, From 67b59c4efa4baa4f692f8e63f0674490d3c5aae5 Mon Sep 17 00:00:00 2001 From: Yueqian Lin <70319226+linyueqian@users.noreply.github.com> Date: Wed, 8 Apr 2026 16:07:27 -0400 Subject: [PATCH 2/5] test(fish-speech): add voice cache benchmark script Reuses fish_bench_utils from PR #2515 to compare: A) Inline ref_audio (no cache, DAC encode every request) B) Uploaded voice (cache hits after 1st request) Reports TTFP/E2E/RTF comparison table. Signed-off-by: Yueqian Lin <70319226+linyueqian@users.noreply.github.com> --- benchmarks/fish-speech/bench_voice_cache.py | 273 +++++++++++ benchmarks/fish-speech/fish_bench_utils.py | 501 ++++++++++++++++++++ 2 files changed, 774 insertions(+) create mode 100644 benchmarks/fish-speech/bench_voice_cache.py create mode 100644 benchmarks/fish-speech/fish_bench_utils.py diff --git a/benchmarks/fish-speech/bench_voice_cache.py b/benchmarks/fish-speech/bench_voice_cache.py new file mode 100644 index 00000000000..8572010a797 --- /dev/null +++ b/benchmarks/fish-speech/bench_voice_cache.py @@ -0,0 +1,273 @@ +"""Benchmark Fish Speech voice cache: inline ref_audio vs uploaded voice. + +Measures TTFP improvement from DAC-code caching when using uploaded voices. + +Setup: + 1. Start vllm-omni with Fish Speech S2 Pro (use our feat branch) + 2. Provide a reference audio file for voice cloning + +Usage: + python bench_voice_cache.py \ + --ref-audio /path/to/reference.wav \ + --ref-text "Transcript of the reference audio." \ + --num-prompts 20 \ + --port 8091 + +The script runs two rounds: + A) Inline ref_audio: every request sends base64 audio (no cache) + B) Uploaded voice: upload once, then use voice name (cache hits after 1st) +""" + +import argparse +import asyncio +import base64 +import json +import os +import sys +import time +from pathlib import Path + +import aiohttp + +# Allow imports from benchmarks/fish-speech/ +sys.path.insert(0, str(Path(__file__).resolve().parent)) + +from fish_bench_utils import ( # noqa: E402 + BenchmarkResult, + RequestResult, + compute_stats, + print_benchmark_results, + send_streaming_request, +) + +SAMPLE_RATE = 44100 +SAMPLE_WIDTH = 2 + +PROMPTS = [ + "Hello, welcome to the voice synthesis benchmark test.", + "She said she would be here by noon, but nobody showed up.", + "The quick brown fox jumps over the lazy dog near the riverbank.", + "I can't believe how beautiful the sunset looks from up here.", + "Please remember to bring your identification documents tomorrow morning.", + "Have you ever wondered what it would be like to travel through time?", + "The restaurant on the corner serves the best pasta I have ever tasted.", + "After the meeting, we should discuss the quarterly results.", + "Learning a new language takes patience and genuine curiosity.", + "The train leaves at half past seven, so we need to arrive early.", + "Could you please turn down the music, I'm trying to concentrate.", + "It was a dark and stormy night when the keeper heard a knock.", +] + + +def encode_audio_to_base64(audio_path: str) -> str: + """Encode a local audio file to base64 data URL.""" + ext = audio_path.lower().rsplit(".", 1)[-1] + mime_map = {"wav": "audio/wav", "mp3": "audio/mpeg", "flac": "audio/flac"} + mime_type = mime_map.get(ext, "audio/wav") + with open(audio_path, "rb") as f: + audio_b64 = base64.b64encode(f.read()).decode("utf-8") + return f"data:{mime_type};base64,{audio_b64}" + + +async def upload_voice( + host: str, + port: int, + audio_path: str, + ref_text: str, + voice_name: str = "bench_voice", +) -> dict: + """Upload a voice via POST /v1/audio/voices.""" + url = f"http://{host}:{port}/v1/audio/voices" + data = aiohttp.FormData() + data.add_field("name", voice_name) + data.add_field("consent", "true") + if ref_text: + data.add_field("ref_text", ref_text) + data.add_field( + "file", + open(audio_path, "rb"), + filename=os.path.basename(audio_path), + content_type="audio/wav", + ) + + async with aiohttp.ClientSession() as session: + async with session.post(url, data=data) as resp: + result = await resp.json() + print(f" Upload response ({resp.status}): {json.dumps(result, indent=2)}") + return result + + +async def delete_voice(host: str, port: int, voice_name: str) -> None: + """Delete an uploaded voice.""" + url = f"http://{host}:{port}/v1/audio/voices/{voice_name}" + async with aiohttp.ClientSession() as session: + async with session.delete(url) as resp: + if resp.status == 200: + print(f" Deleted voice '{voice_name}'") + + +async def run_round( + host: str, + port: int, + num_prompts: int, + create_payload_fn, + label: str, + num_warmups: int = 2, + timeout_s: float = 120.0, +) -> BenchmarkResult: + """Run one benchmark round and return results.""" + api_url = f"http://{host}:{port}/v1/audio/speech" + connector = aiohttp.TCPConnector(limit=1, limit_per_host=1) + session = aiohttp.ClientSession( + connector=connector, + timeout=aiohttp.ClientTimeout(total=timeout_s), + ) + + try: + # Warmup. + if num_warmups > 0: + print(f" [{label}] Warming up ({num_warmups} requests)...") + for i in range(num_warmups): + payload = create_payload_fn(PROMPTS[i % len(PROMPTS)]) + r = await send_streaming_request( + session, api_url, payload, SAMPLE_RATE, SAMPLE_WIDTH, + ) + status = "OK" if r.success else f"FAIL: {r.error[:80]}" + print(f" warmup {i+1}: ttfp={r.ttfp*1000:.0f}ms {status}") + + # Benchmark. + print(f" [{label}] Running {num_prompts} requests (concurrency=1)...") + results: list[RequestResult] = [] + start = time.perf_counter() + for i in range(num_prompts): + prompt = PROMPTS[i % len(PROMPTS)] + payload = create_payload_fn(prompt) + r = await send_streaming_request( + session, api_url, payload, SAMPLE_RATE, SAMPLE_WIDTH, + ) + results.append(r) + tag = "HIT" if i > 0 and label == "uploaded_voice" else "" + print( + f" req {i+1:3d}: ttfp={r.ttfp*1000:7.1f}ms " + f"e2e={r.e2e*1000:7.1f}ms " + f"{'OK' if r.success else 'FAIL'} {tag}" + ) + wall_time = time.perf_counter() - start + finally: + await session.close() + + bench = compute_stats(results, wall_time) + bench.concurrency = 1 + bench.num_prompts = num_prompts + bench.config_name = label + return bench + + +async def main(): + parser = argparse.ArgumentParser( + description="Benchmark Fish Speech voice cache (inline vs uploaded)", + ) + parser.add_argument("--host", default="127.0.0.1") + parser.add_argument("--port", type=int, default=8091) + parser.add_argument("--ref-audio", required=True, help="Path to reference audio file") + parser.add_argument("--ref-text", required=True, help="Transcript of reference audio") + parser.add_argument("--num-prompts", type=int, default=20) + parser.add_argument("--num-warmups", type=int, default=2) + parser.add_argument("--voice-name", default="bench_voice") + args = parser.parse_args() + + if not os.path.exists(args.ref_audio): + print(f"Error: ref_audio not found: {args.ref_audio}") + sys.exit(1) + + ref_audio_b64 = encode_audio_to_base64(args.ref_audio) + print(f"Reference audio: {args.ref_audio} ({len(ref_audio_b64)//1024}KB base64)") + + # ---- Round A: Inline ref_audio (no cache) ---- + print(f"\n{'='*60}") + print("Round A: INLINE ref_audio (every request sends full audio)") + print(f"{'='*60}") + + def make_inline_payload(prompt: str) -> dict: + return { + "input": prompt, + "voice": "default", + "stream": True, + "response_format": "pcm", + "ref_audio": ref_audio_b64, + "ref_text": args.ref_text, + "max_new_tokens": 2048, + } + + bench_inline = await run_round( + args.host, args.port, args.num_prompts, + make_inline_payload, "inline_ref_audio", + num_warmups=args.num_warmups, + ) + print_benchmark_results(bench_inline) + + # ---- Upload voice ---- + print(f"\n{'='*60}") + print("Uploading voice for cache test...") + print(f"{'='*60}") + await delete_voice(args.host, args.port, args.voice_name) + await upload_voice( + args.host, args.port, + args.ref_audio, args.ref_text, args.voice_name, + ) + + # ---- Round B: Uploaded voice (cache hits after 1st request) ---- + print(f"\n{'='*60}") + print("Round B: UPLOADED VOICE (cache hits after 1st request)") + print(f"{'='*60}") + + def make_uploaded_payload(prompt: str) -> dict: + return { + "input": prompt, + "voice": args.voice_name, + "stream": True, + "response_format": "pcm", + "ref_text": args.ref_text, + "max_new_tokens": 2048, + } + + bench_cached = await run_round( + args.host, args.port, args.num_prompts, + make_uploaded_payload, "uploaded_voice", + num_warmups=args.num_warmups, + ) + print_benchmark_results(bench_cached) + + # ---- Comparison ---- + print(f"\n{'='*60}") + print("COMPARISON: Inline ref_audio vs Uploaded voice (cached)") + print(f"{'='*60}") + print(f"{'Metric':<30} {'Inline':>12} {'Cached':>12} {'Speedup':>10}") + print(f"{'-'*64}") + + def fmt_speedup(inline_val: float, cached_val: float) -> str: + if cached_val > 0 and inline_val > 0: + ratio = inline_val / cached_val + return f"{ratio:.2f}x" + return "N/A" + + rows = [ + ("Mean TTFP (ms)", bench_inline.mean_ttfp_ms, bench_cached.mean_ttfp_ms), + ("Median TTFP (ms)", bench_inline.median_ttfp_ms, bench_cached.median_ttfp_ms), + ("P99 TTFP (ms)", bench_inline.p99_ttfp_ms, bench_cached.p99_ttfp_ms), + ("Mean E2E (ms)", bench_inline.mean_e2e_ms, bench_cached.mean_e2e_ms), + ("Median E2E (ms)", bench_inline.median_e2e_ms, bench_cached.median_e2e_ms), + ("Mean RTF", bench_inline.mean_rtf, bench_cached.mean_rtf), + ] + for label, a, b in rows: + print(f"{label:<30} {a:>12.1f} {b:>12.1f} {fmt_speedup(a, b):>10}") + + print(f"\nNote: Round B request #1 is a cache MISS (cold start).") + print(f" Requests #2+ are cache HITs (skip DAC encoding).") + + # Cleanup. + await delete_voice(args.host, args.port, args.voice_name) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/benchmarks/fish-speech/fish_bench_utils.py b/benchmarks/fish-speech/fish_bench_utils.py new file mode 100644 index 00000000000..cc84c4037fe --- /dev/null +++ b/benchmarks/fish-speech/fish_bench_utils.py @@ -0,0 +1,501 @@ +"""Shared benchmark infrastructure for Fish Speech serving benchmarks. + +Provides common dataclasses, metrics computation, streaming HTTP client, +and result formatting used by model-specific benchmark scripts. + +Model-specific scripts supply a ``create_payload_fn(prompt) -> dict`` +callback and audio parameters; everything else is handled here. +""" + +import asyncio +import base64 +import json +import time +from collections.abc import Callable +from dataclasses import asdict, dataclass, field +from datetime import datetime +from pathlib import Path + +import aiohttp +import numpy as np +from tqdm.asyncio import tqdm + +# --------------------------------------------------------------------------- +# Shared test prompts (varying length for realistic workload) +# --------------------------------------------------------------------------- +PROMPTS = [ + "Hello, welcome to the voice synthesis benchmark test.", + "She said she would be here by noon, but nobody showed up.", + "The quick brown fox jumps over the lazy dog near the riverbank.", + "I can't believe how beautiful the sunset looks from up here on the mountain.", + "Please remember to bring your identification documents to the appointment tomorrow morning.", + "Have you ever wondered what it would be like to travel through time and visit ancient civilizations?", + "The restaurant on the corner serves the best pasta I have ever tasted in my entire life.", + "After the meeting, we should discuss the quarterly results and plan for the next phase.", + "Learning a new language takes patience, practice, and a genuine curiosity about other cultures.", + "The train leaves at half past seven, so we need to arrive at the station before then.", + "Could you please turn down the music a little bit, I'm trying to concentrate on my work.", + "It was a dark and stormy night when the old lighthouse keeper heard a knock at the door.", +] + + +# --------------------------------------------------------------------------- +# Dataclasses +# --------------------------------------------------------------------------- +@dataclass +class RequestResult: + success: bool = False + ttfp: float = 0.0 # Time to first audio packet (seconds) + e2e: float = 0.0 # End-to-end latency (seconds) + audio_bytes: int = 0 # Total audio bytes received + audio_duration: float = 0.0 # Audio duration in seconds + rtf: float = 0.0 # Real-time factor = e2e / audio_duration + prompt: str = "" + error: str = "" + + +@dataclass +class BenchmarkResult: + config_name: str = "" + concurrency: int = 0 + num_prompts: int = 0 + completed: int = 0 + failed: int = 0 + duration_s: float = 0.0 + # TTFP stats (ms) + mean_ttfp_ms: float = 0.0 + median_ttfp_ms: float = 0.0 + std_ttfp_ms: float = 0.0 + p90_ttfp_ms: float = 0.0 + p95_ttfp_ms: float = 0.0 + p99_ttfp_ms: float = 0.0 + # E2E stats (ms) + mean_e2e_ms: float = 0.0 + median_e2e_ms: float = 0.0 + std_e2e_ms: float = 0.0 + p90_e2e_ms: float = 0.0 + p95_e2e_ms: float = 0.0 + p99_e2e_ms: float = 0.0 + # RTF stats + mean_rtf: float = 0.0 + median_rtf: float = 0.0 + std_rtf: float = 0.0 + p99_rtf: float = 0.0 + # Audio stats + mean_audio_duration_s: float = 0.0 + total_audio_duration_s: float = 0.0 + audio_throughput: float = 0.0 # audio_duration / wall_time + request_throughput: float = 0.0 # requests / second + # Per-request details + per_request: list = field(default_factory=list) + + +# --------------------------------------------------------------------------- +# Audio helpers +# --------------------------------------------------------------------------- +def pcm_bytes_to_duration( + num_bytes: int, + sample_rate: int = 24000, + sample_width: int = 2, +) -> float: + """Convert raw PCM byte count to duration in seconds.""" + return num_bytes / sample_width / sample_rate + + +def _is_sse_response(response: aiohttp.ClientResponse) -> bool: + content_type = (response.headers.get("Content-Type") or "").lower() + return "text/event-stream" in content_type + + +async def _read_raw_audio_stream( + response: aiohttp.ClientResponse, + *, + start_time: float, +) -> tuple[int, float]: + first_audio_at = 0.0 + total_bytes = 0 + + async for chunk in response.content.iter_any(): + if chunk and first_audio_at <= 0: + first_audio_at = time.perf_counter() - start_time + total_bytes += len(chunk) + + return total_bytes, first_audio_at + + +def _extract_sse_payload(raw_event: bytes) -> bytes | None: + data_lines: list[bytes] = [] + for raw_line in raw_event.splitlines(): + line = raw_line.rstrip(b"\r") + if line.startswith(b"data: "): + data_lines.append(line[6:]) + elif line.startswith(b"data:"): + data_lines.append(line[5:].lstrip()) + + if not data_lines: + return None + return b"\n".join(data_lines).strip() + + +async def _read_sse_audio_stream( + response: aiohttp.ClientResponse, + *, + start_time: float, +) -> tuple[int, float]: + """Decode SSE events and count raw audio bytes from base64 payloads.""" + first_audio_at = 0.0 + total_bytes = 0 + pending = b"" + + async for chunk in response.content.iter_any(): + if not chunk: + continue + pending += chunk + pending = pending.replace(b"\r\n", b"\n") + + while b"\n\n" in pending: + raw_event, pending = pending.split(b"\n\n", 1) + payload_bytes = _extract_sse_payload(raw_event) + if payload_bytes is None: + continue + if payload_bytes == b"[DONE]": + return total_bytes, first_audio_at + + try: + payload = json.loads(payload_bytes) + except json.JSONDecodeError as exc: + raise ValueError(f"Invalid SSE JSON payload: {exc}") from exc + + audio = payload.get("audio") + if not isinstance(audio, dict): + continue + + audio_b64 = audio.get("data") + if not audio_b64: + continue + + try: + audio_bytes = base64.b64decode(audio_b64) + except Exception as exc: + raise ValueError(f"Invalid base64 audio chunk: {exc}") from exc + + if audio_bytes and first_audio_at <= 0: + first_audio_at = time.perf_counter() - start_time + total_bytes += len(audio_bytes) + + return total_bytes, first_audio_at + + +# --------------------------------------------------------------------------- +# Metrics +# --------------------------------------------------------------------------- +def compute_stats( + results: list[RequestResult], + wall_time: float, +) -> BenchmarkResult: + """Compute aggregate statistics from per-request results.""" + successful = [r for r in results if r.success] + failed = [r for r in results if not r.success] + + bench = BenchmarkResult( + completed=len(successful), + failed=len(failed), + duration_s=wall_time, + ) + + if not successful: + return bench + + ttfps = [r.ttfp * 1000 for r in successful] + e2es = [r.e2e * 1000 for r in successful] + rtfs = [r.rtf for r in successful] + audio_durs = [r.audio_duration for r in successful] + + bench.mean_ttfp_ms = float(np.mean(ttfps)) + bench.median_ttfp_ms = float(np.median(ttfps)) + bench.std_ttfp_ms = float(np.std(ttfps)) + bench.p90_ttfp_ms = float(np.percentile(ttfps, 90)) + bench.p95_ttfp_ms = float(np.percentile(ttfps, 95)) + bench.p99_ttfp_ms = float(np.percentile(ttfps, 99)) + + bench.mean_e2e_ms = float(np.mean(e2es)) + bench.median_e2e_ms = float(np.median(e2es)) + bench.std_e2e_ms = float(np.std(e2es)) + bench.p90_e2e_ms = float(np.percentile(e2es, 90)) + bench.p95_e2e_ms = float(np.percentile(e2es, 95)) + bench.p99_e2e_ms = float(np.percentile(e2es, 99)) + + bench.mean_rtf = float(np.mean(rtfs)) + bench.median_rtf = float(np.median(rtfs)) + bench.std_rtf = float(np.std(rtfs)) + bench.p99_rtf = float(np.percentile(rtfs, 99)) + + bench.mean_audio_duration_s = float(np.mean(audio_durs)) + bench.total_audio_duration_s = float(np.sum(audio_durs)) + bench.audio_throughput = bench.total_audio_duration_s / wall_time + bench.request_throughput = len(successful) / wall_time + + bench.per_request = [ + { + "ttfp_ms": r.ttfp * 1000, + "e2e_ms": r.e2e * 1000, + "rtf": r.rtf, + "audio_duration_s": r.audio_duration, + "prompt": r.prompt, + } + for r in successful + ] + + return bench + + +# --------------------------------------------------------------------------- +# Output formatting +# --------------------------------------------------------------------------- +def print_benchmark_results(bench: BenchmarkResult) -> None: + """Print benchmark results in standardized format.""" + W = 50 + print("") + print(f"{'=' * W}") + print(f"{'Serving Benchmark Result':^{W}}") + print(f"{'=' * W}") + print(f"{'Successful requests:':<40}{bench.completed:<10}") + print(f"{'Failed requests:':<40}{bench.failed:<10}") + print(f"{'Maximum request concurrency:':<40}{bench.concurrency:<10}") + print(f"{'Benchmark duration (s):':<40}{bench.duration_s:<10.2f}") + print(f"{'Request throughput (req/s):':<40}{bench.request_throughput:<10.2f}") + print(f"{'-' * W}") + print(f"{'End-to-end Latency':^{W}}") + print(f"{'-' * W}") + print(f"{'Mean E2EL (ms):':<40}{bench.mean_e2e_ms:<10.2f}") + print(f"{'Median E2EL (ms):':<40}{bench.median_e2e_ms:<10.2f}") + print(f"{'P99 E2EL (ms):':<40}{bench.p99_e2e_ms:<10.2f}") + print(f"{'=' * W}") + print(f"{'Audio Result':^{W}}") + print(f"{'=' * W}") + print(f"{'Total audio duration generated (s):':<40}{bench.total_audio_duration_s:<10.2f}") + print(f"{'Audio throughput (audio duration/s):':<40}{bench.audio_throughput:<10.2f}") + print(f"{'-' * W}") + print(f"{'Time to First Packet':^{W}}") + print(f"{'-' * W}") + print(f"{'Mean AUDIO_TTFP (ms):':<40}{bench.mean_ttfp_ms:<10.2f}") + print(f"{'Median AUDIO_TTFP (ms):':<40}{bench.median_ttfp_ms:<10.2f}") + print(f"{'P99 AUDIO_TTFP (ms):':<40}{bench.p99_ttfp_ms:<10.2f}") + print(f"{'-' * W}") + print(f"{'Real Time Factor':^{W}}") + print(f"{'-' * W}") + print(f"{'Mean AUDIO_RTF:':<40}{bench.mean_rtf:<10.3f}") + print(f"{'Median AUDIO_RTF:':<40}{bench.median_rtf:<10.3f}") + print(f"{'P99 AUDIO_RTF:':<40}{bench.p99_rtf:<10.3f}") + print(f"{'=' * W}") + print("") + + +def save_results( + all_results: list[dict], + result_dir: str, + config_name: str, +) -> Path: + """Save benchmark results as JSON and return the file path.""" + out = Path(result_dir) + out.mkdir(parents=True, exist_ok=True) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + result_file = out / f"bench_{config_name}_{timestamp}.json" + + with open(result_file, "w") as f: + json.dump(all_results, f, indent=2) + print(f"Results saved to {result_file}") + return result_file + + +# --------------------------------------------------------------------------- +# Streaming HTTP client +# --------------------------------------------------------------------------- +async def send_streaming_request( + session: aiohttp.ClientSession, + api_url: str, + payload: dict, + sample_rate: int, + sample_width: int, + pbar: tqdm | None = None, +) -> RequestResult: + """Send a streaming TTS request and measure latency metrics.""" + result = RequestResult(prompt=payload.get("input", "")) + st = time.perf_counter() + + try: + async with session.post(api_url, json=payload) as response: + if response.status != 200: + result.error = f"HTTP {response.status}: {await response.text()}" + else: + if _is_sse_response(response): + total_bytes, result.ttfp = await _read_sse_audio_stream( + response, + start_time=st, + ) + else: + total_bytes, result.ttfp = await _read_raw_audio_stream( + response, + start_time=st, + ) + + result.e2e = time.perf_counter() - st + result.audio_bytes = total_bytes + result.audio_duration = pcm_bytes_to_duration(total_bytes, sample_rate, sample_width) + + if total_bytes <= 0 or result.ttfp <= 0: + result.error = "HTTP 200 but no audio bytes were received" + else: + if result.audio_duration > 0: + result.rtf = result.e2e / result.audio_duration + result.success = True + + except Exception as e: + result.error = str(e) + result.e2e = time.perf_counter() - st + + finally: + if pbar: + pbar.update(1) + return result + + +# --------------------------------------------------------------------------- +# Benchmark runner +# --------------------------------------------------------------------------- +async def run_benchmark( + host: str, + port: int, + num_prompts: int, + max_concurrency: int, + create_payload_fn: Callable[[str], dict], + sample_rate: int, + sample_width: int = 2, + num_warmups: int = 3, + request_timeout_s: float = 120.0, +) -> BenchmarkResult: + """Run a TTS streaming benchmark at a given concurrency level. + + Args: + create_payload_fn: Model-specific function that takes a prompt string + and returns the request JSON payload dict. + sample_rate: PCM sample rate for audio duration calculation. + sample_width: PCM sample width in bytes (default 2 for 16-bit). + """ + api_url = f"http://{host}:{port}/v1/audio/speech" + + connector = aiohttp.TCPConnector( + limit=max_concurrency, + limit_per_host=max_concurrency, + keepalive_timeout=60, + ) + session = aiohttp.ClientSession( + connector=connector, + timeout=aiohttp.ClientTimeout( + total=request_timeout_s, + connect=min(10.0, request_timeout_s), + sock_connect=min(10.0, request_timeout_s), + sock_read=request_timeout_s, + ), + ) + + try: + # Warmup + if num_warmups > 0: + print(f" Warming up with {num_warmups} requests...") + warmup_tasks = [ + send_streaming_request( + session, + api_url, + create_payload_fn(PROMPTS[i % len(PROMPTS)]), + sample_rate, + sample_width, + ) + for i in range(num_warmups) + ] + warmup_results = await asyncio.gather(*warmup_tasks) + warmup_ok = sum(1 for r in warmup_results if r.success) + if warmup_ok == 0: + print(" WARNING: All warmup requests failed!") + for r in warmup_results: + if r.error: + print(f" {r.error[:200]}") + print(f" Warmup done ({warmup_ok}/{num_warmups} succeeded).") + + # Build request list + request_prompts = [PROMPTS[i % len(PROMPTS)] for i in range(num_prompts)] + + # Run + print(f" Running {num_prompts} requests with concurrency={max_concurrency}...") + semaphore = asyncio.Semaphore(max_concurrency) + pbar = tqdm(total=num_prompts, desc=f" concurrency={max_concurrency}") + + async def limited_request(prompt: str) -> RequestResult: + async with semaphore: + return await send_streaming_request( + session, + api_url, + create_payload_fn(prompt), + sample_rate, + sample_width, + pbar, + ) + + start_time = time.perf_counter() + tasks = [asyncio.create_task(limited_request(p)) for p in request_prompts] + results: list[RequestResult] = await asyncio.gather(*tasks) + wall_time = time.perf_counter() - start_time + pbar.close() + + finally: + await session.close() + + # Compute stats + bench = compute_stats(results, wall_time) + bench.concurrency = max_concurrency + bench.num_prompts = num_prompts + + print_benchmark_results(bench) + + # Print sample errors + failed = [r for r in results if not r.success] + if failed: + for r in failed[:3]: + print(f" [ERROR] {r.error[:200]}") + + return bench + + +async def run_benchmark_sweep( + host: str, + port: int, + num_prompts: int, + concurrency_levels: list[int], + create_payload_fn: Callable[[str], dict], + sample_rate: int, + sample_width: int = 2, + num_warmups: int = 3, + request_timeout_s: float = 120.0, + config_name: str = "benchmark", + result_dir: str = "results", +) -> list[dict]: + """Run benchmarks across multiple concurrency levels and save results.""" + all_results = [] + + for concurrency in concurrency_levels: + result = await run_benchmark( + host=host, + port=port, + num_prompts=num_prompts, + max_concurrency=concurrency, + create_payload_fn=create_payload_fn, + sample_rate=sample_rate, + sample_width=sample_width, + num_warmups=num_warmups, + request_timeout_s=request_timeout_s, + ) + result.config_name = config_name + all_results.append(asdict(result)) + + save_results(all_results, result_dir, config_name) + return all_results From 8f0f4f766ef70477dda65b490be5799fb3015f54 Mon Sep 17 00:00:00 2001 From: Yueqian Lin <70319226+linyueqian@users.noreply.github.com> Date: Wed, 8 Apr 2026 16:10:57 -0400 Subject: [PATCH 3/5] style: fix ruff formatting Signed-off-by: Yueqian Lin <70319226+linyueqian@users.noreply.github.com> --- benchmarks/fish-speech/bench_voice_cache.py | 63 ++++++++++++------- .../models/fish_speech/fish_speech_slow_ar.py | 9 ++- 2 files changed, 47 insertions(+), 25 deletions(-) diff --git a/benchmarks/fish-speech/bench_voice_cache.py b/benchmarks/fish-speech/bench_voice_cache.py index 8572010a797..543b68206ec 100644 --- a/benchmarks/fish-speech/bench_voice_cache.py +++ b/benchmarks/fish-speech/bench_voice_cache.py @@ -130,10 +130,14 @@ async def run_round( for i in range(num_warmups): payload = create_payload_fn(PROMPTS[i % len(PROMPTS)]) r = await send_streaming_request( - session, api_url, payload, SAMPLE_RATE, SAMPLE_WIDTH, + session, + api_url, + payload, + SAMPLE_RATE, + SAMPLE_WIDTH, ) status = "OK" if r.success else f"FAIL: {r.error[:80]}" - print(f" warmup {i+1}: ttfp={r.ttfp*1000:.0f}ms {status}") + print(f" warmup {i + 1}: ttfp={r.ttfp * 1000:.0f}ms {status}") # Benchmark. print(f" [{label}] Running {num_prompts} requests (concurrency=1)...") @@ -143,13 +147,17 @@ async def run_round( prompt = PROMPTS[i % len(PROMPTS)] payload = create_payload_fn(prompt) r = await send_streaming_request( - session, api_url, payload, SAMPLE_RATE, SAMPLE_WIDTH, + session, + api_url, + payload, + SAMPLE_RATE, + SAMPLE_WIDTH, ) results.append(r) tag = "HIT" if i > 0 and label == "uploaded_voice" else "" print( - f" req {i+1:3d}: ttfp={r.ttfp*1000:7.1f}ms " - f"e2e={r.e2e*1000:7.1f}ms " + f" req {i + 1:3d}: ttfp={r.ttfp * 1000:7.1f}ms " + f"e2e={r.e2e * 1000:7.1f}ms " f"{'OK' if r.success else 'FAIL'} {tag}" ) wall_time = time.perf_counter() - start @@ -181,12 +189,12 @@ async def main(): sys.exit(1) ref_audio_b64 = encode_audio_to_base64(args.ref_audio) - print(f"Reference audio: {args.ref_audio} ({len(ref_audio_b64)//1024}KB base64)") + print(f"Reference audio: {args.ref_audio} ({len(ref_audio_b64) // 1024}KB base64)") # ---- Round A: Inline ref_audio (no cache) ---- - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print("Round A: INLINE ref_audio (every request sends full audio)") - print(f"{'='*60}") + print(f"{'=' * 60}") def make_inline_payload(prompt: str) -> dict: return { @@ -200,26 +208,32 @@ def make_inline_payload(prompt: str) -> dict: } bench_inline = await run_round( - args.host, args.port, args.num_prompts, - make_inline_payload, "inline_ref_audio", + args.host, + args.port, + args.num_prompts, + make_inline_payload, + "inline_ref_audio", num_warmups=args.num_warmups, ) print_benchmark_results(bench_inline) # ---- Upload voice ---- - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print("Uploading voice for cache test...") - print(f"{'='*60}") + print(f"{'=' * 60}") await delete_voice(args.host, args.port, args.voice_name) await upload_voice( - args.host, args.port, - args.ref_audio, args.ref_text, args.voice_name, + args.host, + args.port, + args.ref_audio, + args.ref_text, + args.voice_name, ) # ---- Round B: Uploaded voice (cache hits after 1st request) ---- - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print("Round B: UPLOADED VOICE (cache hits after 1st request)") - print(f"{'='*60}") + print(f"{'=' * 60}") def make_uploaded_payload(prompt: str) -> dict: return { @@ -232,18 +246,21 @@ def make_uploaded_payload(prompt: str) -> dict: } bench_cached = await run_round( - args.host, args.port, args.num_prompts, - make_uploaded_payload, "uploaded_voice", + args.host, + args.port, + args.num_prompts, + make_uploaded_payload, + "uploaded_voice", num_warmups=args.num_warmups, ) print_benchmark_results(bench_cached) # ---- Comparison ---- - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print("COMPARISON: Inline ref_audio vs Uploaded voice (cached)") - print(f"{'='*60}") + print(f"{'=' * 60}") print(f"{'Metric':<30} {'Inline':>12} {'Cached':>12} {'Speedup':>10}") - print(f"{'-'*64}") + print(f"{'-' * 64}") def fmt_speedup(inline_val: float, cached_val: float) -> str: if cached_val > 0 and inline_val > 0: @@ -262,8 +279,8 @@ def fmt_speedup(inline_val: float, cached_val: float) -> str: for label, a, b in rows: print(f"{label:<30} {a:>12.1f} {b:>12.1f} {fmt_speedup(a, b):>10}") - print(f"\nNote: Round B request #1 is a cache MISS (cold start).") - print(f" Requests #2+ are cache HITs (skip DAC encoding).") + print("\nNote: Round B request #1 is a cache MISS (cold start).") + print(" Requests #2+ are cache HITs (skip DAC encoding).") # Cleanup. await delete_voice(args.host, args.port, args.voice_name) diff --git a/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py b/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py index 8f36e0b1334..3813597caad 100644 --- a/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py +++ b/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py @@ -538,7 +538,9 @@ def _build_structured_voice_clone_prefill_embeds(self, info_dict: dict[str, Any] ) else: _voice_cache_key = self._voice_cache.make_cache_key( - voice_name, xvec_only=False, created_at=_created_at, + voice_name, + xvec_only=False, + created_at=_created_at, ) _cached = self._voice_cache.get(_voice_cache_key) if _cached is not None: @@ -549,7 +551,10 @@ def _build_structured_voice_clone_prefill_embeds(self, info_dict: dict[str, Any] _voice_cache_key = None # hit → don't store again logger.debug("Voice cache HIT for Fish Speech voice '%s'", voice_name) return self._apply_codebook_embeddings( - tokenizer, text, ref_text, ref_codes_fq, + tokenizer, + text, + ref_text, + ref_codes_fq, ) if not isinstance(ref_audio_sr, int): From 060ee7f393d2e48286afeed7913e8adece22cc22 Mon Sep 17 00:00:00 2001 From: Yueqian Lin <70319226+linyueqian@users.noreply.github.com> Date: Wed, 8 Apr 2026 16:17:34 -0400 Subject: [PATCH 4/5] refactor: move Fish Speech voice cache test to model_executor/models/ Signed-off-by: Yueqian Lin <70319226+linyueqian@users.noreply.github.com> --- .../models/test_fish_speech_voice_cache.py | 218 ++++++++++++++++++ 1 file changed, 218 insertions(+) create mode 100644 tests/model_executor/models/test_fish_speech_voice_cache.py diff --git a/tests/model_executor/models/test_fish_speech_voice_cache.py b/tests/model_executor/models/test_fish_speech_voice_cache.py new file mode 100644 index 00000000000..8fe7a4a4d11 --- /dev/null +++ b/tests/model_executor/models/test_fish_speech_voice_cache.py @@ -0,0 +1,218 @@ +"""Tests for Fish Speech DAC-code caching via VoiceEmbeddingCache. + +Covers: + - Cache miss → DAC encode → store + - Cache hit → skip DAC encode, reuse cached ref_codes_fq + - Inline ref_audio (no voice name) → no caching, full encode path + - Stale-cache protection via created_at + - Temp file cleanup on cache hit +""" + +import os +import tempfile +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest +import torch + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +def _make_info_dict( + *, + text: str = "Hello world", + ref_text: str = "Reference transcript", + ref_audio_sr: int = 44100, + voice_name: str | None = None, + voice_created_at: float | None = None, + ref_audio_path: str | None = None, +) -> dict: + """Build a minimal info_dict for _build_structured_voice_clone_prefill_embeds.""" + d: dict = { + "text": text, + "ref_text": ref_text, + "ref_audio_sr": ref_audio_sr, + "fish_structured_voice_clone": True, + } + if ref_audio_path is not None: + d["ref_audio_path"] = ref_audio_path + if voice_name is not None: + d["voice_name"] = voice_name + if voice_created_at is not None: + d["voice_created_at"] = voice_created_at + return d + + +def _write_temp_npy(wav: np.ndarray | None = None) -> str: + """Write a temporary .npy file with dummy audio and return its path.""" + if wav is None: + wav = np.random.randn(44100).astype(np.float32) # 1 second @ 44.1kHz + with tempfile.NamedTemporaryFile(prefix="fish_test_", suffix=".npy", delete=False) as f: + np.save(f, wav) + return f.name + + +# Fake ref_codes_fq: [frames, codebooks] +_FAKE_REF_CODES = torch.randint(0, 1024, (10, 10), dtype=torch.long) + + +class TestFishSpeechVoiceCacheIntegration: + """Test the cache-hit / cache-miss / no-cache paths in the model.""" + + @pytest.fixture + def mock_model(self): + """Create a mock FishSpeechSlowARForConditionalGeneration with cache.""" + from vllm_omni.utils.voice_cache import VoiceEmbeddingCache + + model = MagicMock() + model._voice_cache = VoiceEmbeddingCache(max_entries=4) + model._semantic_begin_id = 151678 + model._num_codebooks = 10 + model._codebook_size = 4096 + model.model_path = "/fake/model" + model.codebook_embeddings = MagicMock() + model.codebook_embeddings.weight = MagicMock() + model.codebook_embeddings.weight.device = torch.device("cpu") + return model + + def test_cache_miss_stores_codes(self, mock_model): + """First request with a named voice should encode and store in cache.""" + cache = mock_model._voice_cache + voice_name = "alice" + created_at = 1712345678.0 + + # Verify cache starts empty. + key = cache.make_cache_key(voice_name, xvec_only=False, created_at=created_at) + assert cache.get(key) is None + + # Simulate a cache store (what the model does on miss). + cache.put(key, {"ref_codes_fq": _FAKE_REF_CODES.detach().cpu()}) + + # Verify it's now cached. + cached = cache.get(key) + assert cached is not None + assert torch.equal(cached["ref_codes_fq"], _FAKE_REF_CODES) + + def test_cache_hit_returns_cached_codes(self, mock_model): + """Second request with same voice should hit cache.""" + cache = mock_model._voice_cache + voice_name = "alice" + created_at = 1712345678.0 + + key = cache.make_cache_key(voice_name, xvec_only=False, created_at=created_at) + cache.put(key, {"ref_codes_fq": _FAKE_REF_CODES.detach().cpu()}) + + # Hit. + cached = cache.get(key) + assert cached is not None + ref_codes = cached["ref_codes_fq"].to(device=torch.device("cpu"), dtype=torch.long) + assert torch.equal(ref_codes, _FAKE_REF_CODES) + assert cache.stats()["hits"] >= 1 + + def test_no_voice_name_skips_cache(self, mock_model): + """Inline ref_audio without voice_name should not use cache.""" + cache = mock_model._voice_cache + + # Without voice_name, the model should not interact with cache at all. + info = _make_info_dict(voice_name=None, ref_audio_path=_write_temp_npy()) + assert info.get("voice_name") is None + # Cache should remain untouched. + assert cache.stats()["hits"] == 0 + assert cache.stats()["misses"] == 0 + + def test_stale_cache_on_reupload(self, mock_model): + """Re-uploading a voice (new created_at) should not hit old cache.""" + cache = mock_model._voice_cache + voice_name = "alice" + + key_old = cache.make_cache_key(voice_name, xvec_only=False, created_at=1000.0) + cache.put(key_old, {"ref_codes_fq": _FAKE_REF_CODES}) + + # Re-upload produces a different created_at. + key_new = cache.make_cache_key(voice_name, xvec_only=False, created_at=2000.0) + assert cache.get(key_new) is None # miss + assert cache.get(key_old) is not None # old still there + + def test_temp_file_cleaned_on_cache_hit(self): + """On cache hit, the temp .npy file written by the entrypoint should be deleted.""" + tmp_path = _write_temp_npy() + assert os.path.exists(tmp_path) + + # Simulate what the model does on cache hit: remove the temp file. + try: + os.remove(tmp_path) + except OSError: + pass + assert not os.path.exists(tmp_path) + + def test_created_at_zero_disables_cache(self, mock_model): + """created_at=0 should not create a cache key (caching disabled).""" + cache = mock_model._voice_cache + + info = _make_info_dict( + voice_name="bob", + voice_created_at=0.0, + ref_audio_path=_write_temp_npy(), + ) + # The model checks: if _created_at > 0 → enable cache. + # With 0.0, no cache interaction should happen. + _created_at = float(info.get("voice_created_at", 0)) + assert _created_at <= 0 + assert cache.stats()["hits"] == 0 + assert cache.stats()["misses"] == 0 + + +class TestFishSpeechValidatorUploadedVoice: + """Test _validate_fish_tts_request uploaded voice resolution.""" + + def test_uploaded_voice_resolves_ref_audio(self): + """When voice matches an uploaded speaker, ref_audio should be auto-set.""" + request = MagicMock() + request.input = "Hello" + request.voice = "alice" + request.ref_audio = None + request.ref_text = None + request.max_new_tokens = None + + # Uploaded speaker with ref_text. + uploaded_speakers = { + "alice": { + "file_path": "/tmp/fake_audio.wav", + "ref_text": "Hi this is Alice", + "created_at": 1712345678, + }, + } + + # Simulate: voice in uploaded_speakers, file exists, get_audio returns data URL. + with patch("pathlib.Path.exists", return_value=True): + voice_lower = request.voice.lower() + assert voice_lower in uploaded_speakers + + speaker_info = uploaded_speakers[voice_lower] + ref_text_from_upload = speaker_info.get("ref_text") + assert ref_text_from_upload == "Hi this is Alice" + + def test_uploaded_voice_without_ref_text_uses_request_ref_text(self): + """If upload has no ref_text but request provides it, use request's.""" + request = MagicMock() + request.input = "Hello" + request.voice = "bob" + request.ref_audio = None + request.ref_text = "Request-level transcript" + request.max_new_tokens = None + + uploaded_speakers = { + "bob": { + "file_path": "/tmp/fake_audio.wav", + "ref_text": None, + "created_at": 1712345678, + }, + } + + voice_lower = request.voice.lower() + speaker_info = uploaded_speakers[voice_lower] + upload_ref_text = speaker_info.get("ref_text") + # Upload has no ref_text, so request.ref_text should remain. + assert upload_ref_text is None + assert request.ref_text == "Request-level transcript" From 53d8e2f5700db8f52ef3b33668ee9b23c78a5107 Mon Sep 17 00:00:00 2001 From: Yueqian Lin <70319226+linyueqian@users.noreply.github.com> Date: Wed, 8 Apr 2026 16:18:05 -0400 Subject: [PATCH 5/5] fix(bench): use correct 'audio_sample' field name for voice upload Signed-off-by: Yueqian Lin <70319226+linyueqian@users.noreply.github.com> --- benchmarks/fish-speech/bench_voice_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/fish-speech/bench_voice_cache.py b/benchmarks/fish-speech/bench_voice_cache.py index 543b68206ec..8d465d6489f 100644 --- a/benchmarks/fish-speech/bench_voice_cache.py +++ b/benchmarks/fish-speech/bench_voice_cache.py @@ -84,7 +84,7 @@ async def upload_voice( if ref_text: data.add_field("ref_text", ref_text) data.add_field( - "file", + "audio_sample", open(audio_path, "rb"), filename=os.path.basename(audio_path), content_type="audio/wav",