
[Perf] [TTS] Improve Fish Speech S2 Pro voice cloning TTFP #2145

Merged
linyueqian merged 5 commits into vllm-project:main from Sy0307:dev/fish_clone_opt on Mar 25, 2026

Conversation

Contributor

@Sy0307 Sy0307 commented Mar 24, 2026

Purpose

This PR improves Fish Speech S2 Pro voice cloning TTFP.

Previously, the Fish Speech voice cloning path built the full clone prompt in the API server:

  • resolve ref_audio
  • run encode_reference_audio(...)
  • convert the reference audio into semantic tokens
  • construct the full system + user prompt before sending the request to the engine

That makes the API server do the heaviest part of voice cloning synchronously in the request path, which directly increases TTFP.

This PR moves that work closer to the model side and makes the reference-audio encoding path more device-native.

What changed

  • Fish Speech clone requests no longer run encode_reference_audio(...) in the API server request path.
  • The API server now passes structured clone information (ref_text, ref_audio_path, ref_audio_sr, text, etc.) instead of constructing the full clone prompt up front.
  • Fish Speech Slow AR now builds structured voice-clone prefill embeddings inside the worker.
  • The DAC encoder path is updated to support device-aware loading and encoding, with per (model_path, device, dtype) codec caching.
  • Reference-audio preprocessing is now tensor/device-native:
    • tensor conversion happens early
    • resampling uses cached device-side kernels
    • encoding can run directly on the target device
  • Fish Speech Slow AR preloads the DAC encoder during weight loading to reduce first-request overhead.
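
The per-(model_path, device, dtype) codec caching described above can be sketched as follows; the cache key shape follows the bullet, but the function name and the stand-in "codec" value are illustrative, not the PR's actual code:

```python
# Sketch of per-(model_path, device, dtype) codec caching. The dict value
# stands in for a real DAC codec instance, which is expensive to build.
_codec_cache: dict = {}

def load_codec_cached(model_path: str, device: str, dtype: str):
    key = (model_path, device, dtype)
    if key not in _codec_cache:
        # Real code would construct and move the DAC codec here; this
        # placeholder only records what was requested.
        _codec_cache[key] = {"model_path": model_path, "device": device, "dtype": dtype}
    return _codec_cache[key]
```

Repeated requests with the same key then reuse the already-loaded codec, so only the first request on a given device/dtype pays the load cost.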

Why this helps

The main gain is removing the heavy synchronous voice-cloning preprocessing from the API server hot path.

Instead of:

  • API server CPU encode
  • API server prompt construction
  • request submission

the path becomes:

  • API server resolves reference audio and forwards structured info
  • worker handles reference-audio encoding and final clone prompt embedding construction

This reduces request-path blocking and significantly improves voice cloning first-token latency.
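
The structured clone information the API server forwards might look like the following; the field names come from the description above, the integer check mirrors the worker-side validation quoted elsewhere in this thread, and everything else (the builder function, its signature) is an assumption:

```python
# Hypothetical builder for the structured clone info the API server now
# forwards instead of a fully constructed clone prompt.
def build_clone_info(ref_text: str, ref_audio_path: str,
                     ref_audio_sr: int, text: str) -> dict:
    if not isinstance(ref_audio_sr, int):
        # The worker rejects non-integer sample rates, so fail fast
        # on the server side too.
        raise ValueError("ref_audio_sr must be an integer")
    return {
        "ref_text": ref_text,              # transcript of the reference clip
        "ref_audio_path": ref_audio_path,  # where the waveform was staged
        "ref_audio_sr": ref_audio_sr,      # sample rate of the reference audio
        "text": text,                      # text to synthesize in the cloned voice
    }
```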

Results

Measured with the same request pattern using a 5-second reference clip:

Warm voice cloning TTFP improved from about 1.47 s to about 0.125 s on an RTX 5090.

cc @linyueqian

Sy0307 added 2 commits March 24, 2026 10:26
Signed-off-by: Sy03 <1370724210@qq.com>
Signed-off-by: Sy03 <1370724210@qq.com>
@Sy0307 Sy0307 requested a review from hsliuustc0106 as a code owner March 24, 2026 17:39
@linyueqian linyueqian self-requested a review March 24, 2026 17:41

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: f5f1435015


raise ValueError("Fish Speech structured voice clone requires integer ref_audio_sr")

ref_audio_wav = np.load(ref_audio_path)
os.remove(ref_audio_path)

P1: Avoid per-rank deletion of shared ref-audio file

_build_structured_voice_clone_prefill_embeds unconditionally removes ref_audio_path right after loading it. But gpu_model_runner._preprocess calls model.preprocess(...) on every worker rank, so tensor/pipeline-parallel workers can race on the same path: once one rank deletes the file, another rank fails at np.load with FileNotFoundError, making voice-clone requests flaky or outright failing in multi-GPU runs.
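
One possible mitigation, sketched under the assumption that each worker knows its rank: load on every rank, but let only one rank delete, and tolerate an already-missing file. A collective barrier before deletion would be the fully robust fix; the function and parameter names here are hypothetical.

```python
import os

def cleanup_ref_audio(ref_audio_path: str, rank: int) -> None:
    # Only rank 0 deletes the shared staging file; other ranks leave it
    # alone, and an already-deleted file is not treated as an error.
    if rank != 0:
        return
    try:
        os.remove(ref_audio_path)
    except FileNotFoundError:
        pass
```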


Comment on lines +987 to +989
with tempfile.NamedTemporaryFile(prefix="fish_ref_", suffix=".npy", delete=False) as f:
np.save(f, np.asarray(wav_samples, dtype=np.float32))
ref_audio_path = f.name

P1: Don't serialize clone audio as a local temp path

This path writes the reference audio to a node-local temp file and passes only the filename through additional_information. The serialization layer transports that as a scalar string, not the file contents, so in deployments where the slow-AR worker does not share a filesystem with the API server (e.g., disaggregated/non-mp executors), np.load(ref_audio_path) cannot open the file and voice cloning fails.
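
One way to drop the shared-filesystem assumption is to transport the waveform inline instead of by path. A stdlib-only sketch; the function names and the little-endian float32 wire format are assumptions, not the PR's code:

```python
import base64
import struct

def pack_ref_audio(samples: list) -> str:
    # Serialize the waveform itself as base64-encoded float32 bytes so it
    # survives transport to a worker on a different machine/filesystem.
    raw = struct.pack(f"<{len(samples)}f", *samples)
    return base64.b64encode(raw).decode("ascii")

def unpack_ref_audio(payload: str) -> list:
    # Inverse of pack_ref_audio: decode base64, then unpack float32 samples.
    raw = base64.b64decode(payload)
    return list(struct.unpack(f"<{len(raw) // 4}f", raw))
```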


logger.warning("Fish Speech Fast AR compile warmup failed: %s", exc)

codec_device = self.codebook_embeddings.weight.device
codec_dtype = torch.bfloat16 if codec_device.type == "cuda" else torch.float32
Collaborator


Could this be used for other platforms as well? cc @gcanlin

Collaborator


Why not use bf16 directly? Falling back to float32 seems unnecessary.

Contributor Author


Fixed. This part was migrated from dac_encoder. Thanks for the suggestion.

Collaborator


DAC codecs are sensitive to precision: bfloat16 may produce different VQ codes, which would change voice cloning quality. This needs some testing.

Collaborator

@linyueqian linyueqian left a comment


still testing

Collaborator

@linyueqian linyueqian left a comment


Tested locally with s2-pro on A100. Voice cloning works end to end: 3.72 s of audio generated at 44.1 kHz.

Note: Fish Speech requires an architectures field in config.json (added locally for testing; this might need a fix upstream or in the model repo). I also needed to lower gpu_memory_utilization in the stage config for our GPU setup. Neither issue is introduced by this PR.

LGTM.
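
The manual architectures workaround mentioned above presumably amounts to adding the model class to the checkpoint's config.json. The class name below is a placeholder, since the actual Fish Speech identifier isn't given in this thread, and all other config.json fields are omitted:

```json
{
  "architectures": ["<FishSpeechModelClass>"]
}
```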

Signed-off-by: Sy03 <1370724210@qq.com>
Collaborator

@linyueqian linyueqian left a comment


Tested online serving locally with s2-pro on A100, voice cloning with streaming.

Run       TTFA       Total
Cold      4,978 ms   8,891 ms
Warm #1   367 ms     4,028 ms
Warm #2   374 ms     3,149 ms

Warm TTFA ~370ms on A100. Audio output sounds correct (4.13s @ 44.1kHz for non-streaming).

Note: Fish Speech currently needs an architectures field added to config.json manually to start on vLLM 0.18.0. Not related to this PR.

LGTM.

@linyueqian linyueqian added the "ready" label (trigger buildkite CI) on Mar 25, 2026
@linyueqian
Collaborator

Could you verify that bfloat16 doesn't degrade voice cloning quality? Or keep float32 for the codec to be safe. The speed gain from bfloat16 on the encoder is minimal since it only runs once per request.

@linyueqian
Collaborator

The bfloat16 change on the DAC encoder might affect voice cloning quality. I compared the token outputs:

from vllm_omni.model_executor.models.fish_speech.dac_encoder import (
    encode_reference_audio, _load_dac_codec, _codec_cache, DAC_SAMPLE_RATE, _get_resample_kernel,
)
import soundfile as sf, torch, numpy as np

wav, sr = sf.read("zero_shot_prompt.wav")
wav = np.asarray(wav, dtype=np.float32)

# bfloat16 (PR default on CUDA)
_codec_cache.clear()
ids_bf16 = encode_reference_audio("s2-pro", wav, sr, device="cuda")

# float32 on CUDA
_codec_cache.clear()
codec = _load_dac_codec("s2-pro", device="cuda", dtype=torch.float32)
wav_t = torch.as_tensor(wav).to(device="cuda", dtype=torch.float32).flatten()
resampler = _get_resample_kernel(sr, DAC_SAMPLE_RATE, "cuda", 0, "float32")
wav_t = resampler(wav_t.unsqueeze(0)).squeeze(0).unsqueeze(0).unsqueeze(0)
codes, _ = codec.encode(wav_t, torch.tensor([wav_t.shape[-1]], device="cuda"))
ids_f32 = [151678 + c for c in codes[0, 0, :].cpu().tolist()]

diffs = sum(1 for a, b in zip(ids_bf16, ids_f32) if a != b)
print(f"bf16: {len(ids_bf16)} tokens, f32: {len(ids_f32)} tokens, diffs: {diffs}/{len(ids_f32)} ({diffs/len(ids_f32)*100:.1f}%)")
# Output: bf16: 75 tokens, f32: 75 tokens, diffs: 22/75 (29.3%)

29.3% of semantic tokens differ between bfloat16 and float32. Could you verify on your side that bfloat16 doesn't degrade voice cloning quality?

Signed-off-by: Sy03 <1370724210@qq.com>
@Sy0307
Contributor Author

Sy0307 commented Mar 25, 2026

> 29.3% of semantic tokens differ between bfloat16 and float32. Could you verify on your side that bfloat16 doesn't degrade voice cloning quality?

Here I compared multiple audio samples from DAC using fp32 and bf16. With ASR, the recognition similarity rate for fp32 is significantly higher than for bf16 (97% vs. 75% on average). I also observed that fp32 audio quality is more stable, while bf16 randomly exhibits sudden changes in volume and timbre. Therefore, DAC will use fp32 by default. cc @linyueqian @gcanlin @hsliuustc0106

@linyueqian linyueqian merged commit fc32da7 into vllm-project:main Mar 25, 2026
8 checks passed
zhangj1an pushed a commit to zhangj1an/vllm-omni that referenced this pull request Mar 26, 2026
@ukemamaster

ukemamaster commented Apr 8, 2026

@Sy0307

Is this already merged in release 0.18.0?

After this PR, will the server reuse cached ref_audio (or speaker information) from a previous request?

How should I send ref_audio in the request to the server: base64 encoded or as a string path? I tried passing a string path but got this error:

Error during stream: Server error (400): {"error":{"message":"ref_audio must be a URL (http/https), base64 data URL (data:...), or file URI (file://...)","type":"BadRequestError","param":null,"code":400}}

I used this code for the inference request.

Do you have any example code for sending a request to the server that reuses the ref_audio from a previous request?
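
Per the 400 error above, the server accepts an http/https URL, a data: URL, or a file:// URI for ref_audio. A sketch of building a base64 data URL from a local wav file; the helper name and the audio/wav media type are assumptions, not part of the server's documented API:

```python
import base64

def wav_path_to_data_url(path: str) -> str:
    # Embed the file contents as a base64 data URL, one of the forms the
    # ref_audio validator accepts per the error message quoted above.
    with open(path, "rb") as f:
        payload = base64.b64encode(f.read()).decode("ascii")
    return f"data:audio/wav;base64,{payload}"
```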

lengrongfu pushed a commit to lengrongfu/vllm-omni that referenced this pull request May 1, 2026
clodaghwalsh17 pushed a commit to clodaghwalsh17/nm-vllm-omni-ent that referenced this pull request May 12, 2026