[Perf][Bugfix] cache hot buffers in qwen3_tts talker; fall back on evicted state #3688

Merged: linyueqian merged 4 commits into vllm-project:main from JuanPZuluaga:qwen3tts-talker on May 28, 2026

Conversation

@JuanPZuluaga
Contributor

@JuanPZuluaga JuanPZuluaga commented May 18, 2026

Purpose

A few optimizations for the talker stage of Qwen3TTS, which give some benefit at concurrency=64 and below:

  1. Per-call allocations. mel_spectrogram reallocated the mel filter bank and Hann window on every call; the voice-cloning path constructed a fresh AudioResampler per request.
  2. .item() host sync in ref_code_len emit. Forced a D2H stall per request span.
  3. Eviction crash. tts_pad_embed and hidden_states['last'] could be evicted by _update_states' finished_req_ids cleanup before the talker's final decode step, raising RuntimeError("Missing tts_pad_embed in additional_information; prefill must run first.") and killing in-flight streams.
  4. Q-1 serial embedding kernels in talker_mtp's residual codebook path: Q-1 separate embedding calls per step instead of a single gather over a stacked weight.
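
The index arithmetic behind item 4 can be sketched without any framework. This is an illustration of the general technique, not the PR's code: it assumes Q-1 residual codebooks with the same vocabulary size, stacks their tables row-wise, and turns each (codebook, code) pair into one offset index so a single lookup pass replaces Q-1 separate ones. The names `stack_tables`, `gather_stacked`, `gather_serial`, and the toy sizes are hypothetical; in PyTorch the final lookup would typically be one embedding call over the concatenated weight.

```python
from typing import List

Table = List[List[float]]  # one embedding table: vocab_size rows of dim floats


def stack_tables(tables: List[Table]) -> Table:
    """Concatenate per-codebook tables row-wise into one stacked table."""
    stacked: Table = []
    for table in tables:
        stacked.extend(table)
    return stacked


def gather_stacked(stacked: Table, codes: List[int], vocab_size: int) -> List[List[float]]:
    """Look up codes[q] from codebook q with one pass over offset indices."""
    # The row for code c of codebook q sits at q * vocab_size + c in the stacked table.
    flat_idx = [q * vocab_size + c for q, c in enumerate(codes)]
    return [stacked[i] for i in flat_idx]


def gather_serial(tables: List[Table], codes: List[int]) -> List[List[float]]:
    """Reference: one separate lookup per codebook (the pattern being replaced)."""
    return [tables[q][c] for q, c in enumerate(codes)]


# Toy setup (hypothetical sizes): Q-1 = 3 residual codebooks, vocab 4, dim 2.
Q_MINUS_1, VOCAB, DIM = 3, 4, 2
tables = [
    [[float(100 * q + 10 * v + d) for d in range(DIM)] for v in range(VOCAB)]
    for q in range(Q_MINUS_1)
]
stacked = stack_tables(tables)
codes = [2, 0, 3]  # one code per residual codebook for a single decode step
assert gather_stacked(stacked, codes, VOCAB) == gather_serial(tables, codes)
```

The stacked table is built once, so the per-step cost is the offset computation plus a single gather.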

Test Plan

Test Result

| c | reqs ok | TTFT p99 (ms) | E2EL p99 (ms) | Underrun p99 (ms) | Underrun rate | RTF p99 |
|---|---|---|---|---|---|---|
| 1 | 8/8 | 46 | 985 | 0 | 0.00 | 0.16 |
| 4 | 8/8 | 82 | 1143 | 0 | 0.00 | 0.21 |

I used bench_tts_continuity.py from @linyueqian for the underrun rate.

Comparison with qwen3_tts_high_concurrency.yaml

I used https://github.com/vllm-project/vllm-omni/blob/main/vllm_omni/deploy/qwen3_tts_high_concurrency.yaml, but modified the config so that both stages run on the same GPU (GPU 0), as I don't have access to two.

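| c | main underrun rate | PR underrun rate | main TTFT p99 (ms) | PR TTFT p99 (ms) | main E2EL p99 (ms) | PR E2EL p99 (ms) | main RTF p99 | PR RTF p99 |
|---|---|---|---|---|---|---|---|---|
| 1 | 0.00 | 0.00 | 51 | 52 | 1101 | 1045 | 0.17 | 0.17 |
| 4 | 0.00 | 0.00 | 115 | 103 | 1582 | 1152 | 0.22 | 0.21 |
| 8 | 0.00 | 0.00 | 137 | 138 | 2229 | 1576 | 0.27 | 0.27 |
| 16 | 0.00 | 0.00 | 222 | 264 | 2852 | 2511 | 0.37 | 0.36 |
| 32 | 0.00 | 0.00 | 431 | 496 | 4361 | 4052 | 0.64 | 0.65 |
| 48 | 0.00 | 0.00 | 969 | 843 | 5197 | 6062 | 0.85 | 0.86 |
| 64 | 0.11 | 0.13 | 1613 | 1360 | 6652 | 6720 | 1.11 | 1.11 |
| 96 | 0.03 | 0.09 | 4894 | 4677 | 10497 | 11269 | 1.89 | 1.78 |
| 128 | 0.09 | 0.04 | 9163 | 7029 | 14861 | 13241 | 2.69 | 2.35 |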

ASR quality at c=16: 16/16 pass at CER 0.012 (non-stream), 16/16 at 0.011 (stream). No regression.

Headline gains (decode-heavy regime):

  • c=128 underrun rate: 0.09 → 0.04 (more than halved)
  • c=128 TTFT p99: 9163 → 7029 ms (-23%)
  • c=128 RTF p99: 2.69 → 2.35 (-13%)
  • c=128 E2EL p99: 14.9s → 13.2s (-11%)
  • c=4/8/16 E2EL drops 11–29% (warm-up improves)
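
The percentages in this list are relative changes, (PR - main) / main. A small sketch that recomputes the c=128 figures from the comparison table; the helper name `rel_change` is ours, not part of the benchmark script.

```python
def rel_change(base: float, new: float) -> float:
    """Relative change from base to new, in percent (negative means lower)."""
    return (new - base) / base * 100.0


# c=128 row of the comparison table, as (main, PR) pairs.
c128 = {
    "TTFT p99 (ms)": (9163, 7029),
    "RTF p99": (2.69, 2.35),
    "E2EL p99 (ms)": (14861, 13241),
}

for name, (main, pr) in c128.items():
    print(f"{name}: {rel_change(main, pr):+.0f}%")
# TTFT p99 (ms): -23%
# RTF p99: -13%
# E2EL p99 (ms): -11%
```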

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 11f6dc321f

Comment on lines +2137 to +2139
    key = (orig_sr, target_sr)
    if key not in self._resampler_cache:
        self._resampler_cache[key] = AudioResampler(target_sr=target_sr)

[P2] Bound resampler cache growth

The new _resampler_cache never evicts entries, and its key includes the request-provided source sample rate (orig_sr), so a long-lived service can accumulate one AudioResampler per distinct input rate indefinitely. In contexts where users can upload arbitrary audio metadata, this becomes an unbounded memory growth path that did not exist before this change (previously the resampler was ephemeral per call). Consider capping this cache (LRU) or normalizing/whitelisting accepted sample rates before caching.
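
One way to cap such a cache, sketched with the standard library only. `BoundedCache` and `make_resampler` are hypothetical names, and the factory merely stands in for constructing an AudioResampler; this shows the LRU option from the suggestion, not what the PR ended up doing.

```python
from collections import OrderedDict
from typing import Callable, Generic, Hashable, TypeVar

V = TypeVar("V")


class BoundedCache(Generic[V]):
    """Tiny LRU cache: evicts the least recently used entry past max_size."""

    def __init__(self, max_size: int) -> None:
        if max_size < 1:
            raise ValueError("max_size must be >= 1")
        self._max_size = max_size
        self._data: "OrderedDict[Hashable, V]" = OrderedDict()

    def get_or_create(self, key: Hashable, factory: Callable[[], V]) -> V:
        if key in self._data:
            self._data.move_to_end(key)  # mark as most recently used
            return self._data[key]
        value = factory()
        self._data[key] = value
        if len(self._data) > self._max_size:
            self._data.popitem(last=False)  # drop the least recently used entry
        return value

    def __len__(self) -> int:
        return len(self._data)


def make_resampler(orig_sr: int, target_sr: int) -> dict:
    """Stand-in for AudioResampler(target_sr=...); a real resampler would go here."""
    return {"orig_sr": orig_sr, "target_sr": target_sr}


cache: BoundedCache[dict] = BoundedCache(max_size=4)
for orig_sr in (8000, 16000, 22050, 44100, 48000, 96000):
    cache.get_or_create((orig_sr, 24000), lambda sr=orig_sr: make_resampler(sr, 24000))
assert len(cache) == 4  # never grows past the cap, whatever rates clients send
```

Whitelisting accepted sample rates before the lookup is the alternative named in the comment; it bounds the key space instead of the cache size.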

Signed-off-by: JuanPZuluaga <juanz9312@gmal.com>
@JuanPZuluaga
Contributor Author

Tagging @linyueqian, as this should be aligned with #3535

@linyueqian
Collaborator

linyueqian commented May 20, 2026

Verification (H20-3e + L20X, c=128 / 512 reqs each)

Ran A/B against pr3688^ (6d37e77f) on two single-GPU configs, both stages co-located on one device, identical deploy yaml across both branches.

H20-3e (vllm bench serve --backend openai-audio-speech on seed-tts en, task_type=Base, stage 0 max_num_seqs=128, gpu_memory_utilization=0.4):

| metric | base | pr3688 | delta |
|---|---|---|---|
| completed / failed | 512 / 0 | 512 / 0 | |
| wall (s) | 67.07 | 66.88 | -0.3% |
| mean TTFT (ms) | 3823 | 3634 | -4.9% |
| median TTFT (ms) | 3799 | 3269 | -13.9% |
| P99 TTFT (ms) | 11167 | 11586 | +3.7% |
| P99 E2EL (ms) | 25902 | 25447 | -1.8% |
| P99 audio_rtf | 5.364 | 5.153 | -3.9% |

L20X (same deploy, Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice, fixed prompt, voice=vivian, async client at concurrency 128):

| metric | base | pr3688 | delta |
|---|---|---|---|
| completed / failed | 512 / 0 | 512 / 0 | |
| wall (s) | 67.32 | 64.11 | -4.8% |
| throughput (req/s) | 7.61 | 7.99 | +5.0% |
| lat mean (s) | 16.18 | 15.32 | -5.3% |
| lat P90 (s) | 19.28 | 17.72 | -8.1% |
| lat P99 (s) | 21.57 | 19.94 | -7.6% |
| warmup #0 (s) | 57.24 | 6.58 | -88% (cold-start amortization) |

Server logs grepped for RuntimeError|Missing tts_pad_embed|Missing hidden_states: 0 across both branches on both hosts. Audio outputs from base and PR sound equivalent (subjective listen on the saved wavs).

[Low] The defensive fallback is wired into preprocess() but preprocess_decode_batch at L797 and L828 still hard-raises on the same Missing tts_pad_embed / Missing hidden_states['last'] condition. Did not reproduce eviction in any of these runs, but worth a follow-up if the batched concurrent path can hit it under different scheduling pressure.
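
The difference between the two behaviors can be sketched with plain dicts in place of the per-request state. Everything here is hypothetical: `get_or_raise`, `get_or_fallback`, and the idea of a model-level cached pad embedding illustrate a defensive fallback in general and are not the code in preprocess() or preprocess_decode_batch.

```python
from typing import Any, Dict, Optional

MISSING_MSG = "Missing tts_pad_embed in additional_information; prefill must run first."


def get_or_raise(info: Dict[str, Any], key: str = "tts_pad_embed") -> Any:
    """Strict lookup: an evicted entry kills the request."""
    if key not in info:
        raise RuntimeError(MISSING_MSG)
    return info[key]


def get_or_fallback(
    info: Dict[str, Any],
    cached_default: Optional[Any],
    key: str = "tts_pad_embed",
) -> Any:
    """Defensive lookup: use a cached copy if the per-request entry is gone."""
    value = info.get(key)
    if value is not None:
        return value
    if cached_default is None:
        # Nothing to fall back to, so raising is still the right outcome.
        raise RuntimeError(MISSING_MSG)
    return cached_default


model_cache = [0.0, 0.0]  # stand-in for a cached pad embedding (assumption)
evicted_info: Dict[str, Any] = {}  # per-request state already cleaned up
assert get_or_fallback(evicted_info, model_cache) is model_cache
```

A batched path that keeps the strict lookup fails the whole step when one request's state has been evicted, which is why the follow-up is suggested.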

LGTM otherwise. Nice cleanup of the per-step allocations and the for i in range(Q-1) embedding loop.

Comment on lines +260 to +269
    @lru_cache(maxsize=8)
    def _cached_mel_filter_bank(sampling_rate: int, n_fft: int, n_mels: int, fmin: int, fmax: int | None) -> torch.Tensor:
        return mel_filter_bank(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)


    @lru_cache(maxsize=8)
    def _cached_hann_window(win_size: int) -> torch.Tensor:
        return torch.hann_window(win_size)
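
A standard-library sketch of what the decorator buys: the window is computed once per distinct size and later calls return the same object. The function `cached_hann_window` below is a hypothetical pure-Python stand-in that uses the periodic Hann formula, not the torch-based helper in the diff.

```python
import math
from functools import lru_cache
from typing import Tuple


@lru_cache(maxsize=8)
def cached_hann_window(win_size: int) -> Tuple[float, ...]:
    """Periodic Hann window, computed once per distinct win_size."""
    return tuple(
        0.5 * (1.0 - math.cos(2.0 * math.pi * n / win_size))
        for n in range(win_size)
    )


first = cached_hann_window(1024)
second = cached_hann_window(1024)
assert first is second  # the second call is a cache hit, nothing is recomputed
print(cached_hann_window.cache_info())
```

Returning an immutable tuple sidesteps a caveat that applies when caching mutable objects such as tensors: every caller receives the same object, so an in-place modification by one caller would be visible to all others.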

Contributor Author

Should we add some kind of global cache for this kind of object? Other systems also compute mel_filter_bank or do resampling. @linyueqian

Collaborator

Makes sense to me. This should be a common module.

Contributor Author

OK, I'll work on another PR that does this for all the audio models.

@linyueqian
Collaborator

linyueqian commented May 26, 2026

Fresh A/B on 2x L20X with the bundled qwen3_tts_high_concurrency.yaml against main cad25fa2 and your c2d5bb9a. 26 cells, headline numbers below.

default_voice (1.7B-CustomVoice, seed-tts EN bucket text<=50)

| c | main TTFP p99 (ms) | PR TTFP p99 (ms) | Δ |
|---|---|---|---|
| 1 | 49 | 53 | flat |
| 4 | 178 | 118 | -34% |
| 16 | 384 | 358 | -7% |
| 48 | 1577 | 1448 | -8% |
| 64 | 2742 | 2039 | -26% |
| 128 | 5111 | 5094 | flat |

voice_clone (0.6B-Base, same bucket)

| c | main TTFP p99 (ms) | PR TTFP p99 (ms) | Δ |
|---|---|---|---|
| 1 | 173 | 174 | flat |
| 16 | 2386 | 1163 | -51% |
| 64 | 3473 | 3431 | flat |
| 128 | 8651 | 8301 | -4% |

c=4/c=64 default and c=16 voice_clone reproduce the win on this hardware class (different from your single-GPU H100 baseline). No regression at high c. Approving.

One follow-up though: the wins flatten at c>=64 voice_clone and c>=96 default, which makes sense since your PR targets per-prefill CPU-side allocation (mel/hann/resampler) and the AR decode step dominates once concurrency saturates the GPU. Would you have appetite for a second pass at the decode-bound regime, e.g. fused MTP kernel, fuller decode CUDA graph capture, or trimming D2H syncs that remain in talker_mtp? @Sy0307 might want to weigh in too given the overlap with #3662.

@linyueqian linyueqian left a comment

See bench numbers above. Approving.

@JuanPZuluaga
Contributor Author

@linyueqian I could work on this in a follow-up PR.

@Sy0307
Collaborator

Sy0307 commented May 27, 2026

Additional Qwen3-TTS validation

I ran an additional A/B validation.

Setup:

  • baseline: origin/main
  • PR: c2d5bb9a
  • hardware: 2-GPU CUDA setup
  • deploy config: bundled qwen3_tts_high_concurrency.yaml
  • dataset: SeedTTS EN text<=50 bucket
    • 256 rows, oversampled to 512 requests per cell
  • request count: 512 requests per concurrency cell
  • metrics: audio_ttfp, e2el, audio_rtf, throughput, completed/failed

default_voice: Qwen3-TTS-12Hz-1.7B-CustomVoice

| c | main p99 TTFP | PR p99 TTFP | Δ TTFP | main p99 E2EL | PR p99 E2EL | Δ E2EL | main p99 RTF | PR p99 RTF | Δ RTF |
|---|---|---|---|---|---|---|---|---|---|
| 1 | 72.42 | 65.43 | -9.7% | 1011.00 | 1041.22 | +3.0% | 0.239 | 0.219 | -8.2% |
| 4 | 120.60 | 120.63 | flat | 1248.69 | 1194.63 | -4.3% | 0.264 | 0.269 | +1.9% |
| 16 | 1624.69 | 303.35 | -81.3% | 4127.57 | 1981.29 | -52.0% | 1.359 | 0.405 | -70.2% |
| 48 | 1641.59 | 1585.77 | -3.4% | 3103.53 | 3061.71 | -1.3% | 1.238 | 1.210 | -2.2% |
| 64 | 2321.00 | 2444.97 | +5.3% | 3272.17 | 3234.67 | -1.1% | 1.688 | 1.771 | +4.9% |
| 128 | 6328.08 | 6544.10 | +3.4% | 6703.53 | 6914.36 | +3.1% | 3.757 | 3.852 | +2.5% |

For default_voice, the clearest win is at c=16. The PR also slightly improves E2EL at c=4/48/64. At c=128, this run shows a small regression across p99 TTFP/E2EL/RTF.

voice_clone: Qwen3-TTS-12Hz-0.6B-Base

| c | main p99 TTFP | PR p99 TTFP | Δ TTFP | main p99 E2EL | PR p99 E2EL | Δ E2EL | main p99 RTF | PR p99 RTF | Δ RTF |
|---|---|---|---|---|---|---|---|---|---|
| 1 | 322.32 | 278.98 | -13.4% | 918.61 | 876.11 | -4.6% | 0.273 | 0.263 | -3.7% |
| 16 | 3340.98 | 590.73 | -82.3% | 4276.78 | 1848.90 | -56.8% | 1.729 | 0.618 | -64.2% |
| 64 | 3808.56 | 3649.81 | -4.2% | 6250.63 | 5469.54 | -12.5% | 3.043 | 2.890 | -5.0% |
| 128 | 10798.29 | 9496.98 | -12.1% | 11283.71 | 10134.39 | -10.2% | 6.231 | 4.992 | -19.9% |

For voice_clone, the PR shows consistent improvement across all tested concurrency levels. The largest gain is at c=16, and the higher-concurrency cases (c=64/128) also improve.

Mean / median observations

This does not look like a p99-only improvement. In the voice_clone path, mean and median generally improve as well:

  • voice_clone c=16
    • mean E2EL: -8.6%
    • median E2EL: -2.7%
    • p99 E2EL: -56.8%
  • voice_clone c=64
    • mean E2EL: -4.9%
    • median E2EL: -3.9%
    • p99 E2EL: -12.5%
  • voice_clone c=128
    • mean E2EL: -7.3%
    • median E2EL: -5.6%
    • p99 E2EL: -10.2%

For default_voice, the improvement is more concentrated in the mid-concurrency regime, especially c=16.

Conclusion

LGTM from my side. The results are directionally consistent with the earlier validation: this PR helps the Qwen3-TTS hot path, especially for voice_clone and for the mid-concurrency regime.

One non-blocking caveat: the benefit is workload- and saturation-regime-dependent. In this run, default_voice c=128 regresses slightly in p99 TTFP/E2EL/RTF, so I would avoid describing this as a uniform high-concurrency win. It is still a net positive overall, especially for voice_clone. But we can work on it in follow-up PRs.

Thanks for the nice work ;)

@linyueqian linyueqian enabled auto-merge (squash) May 27, 2026 18:08
@linyueqian linyueqian added the labels ready (label to trigger buildkite CI) and tts-test (label to trigger buildkite tts models test in nightly CI) May 28, 2026
@linyueqian linyueqian merged commit 0f3264a into vllm-project:main May 28, 2026
8 checks passed
@JuanPZuluaga JuanPZuluaga deleted the qwen3tts-talker branch May 28, 2026 06:43
zengchuang-hw pushed a commit to zengchuang-hw/vllm-omni that referenced this pull request Jun 1, 2026
…icted state (vllm-project#3688)

Signed-off-by: JuanPZuluaga <juanz9312@gmal.com>
Co-authored-by: JuanPZuluaga <juanz9312@gmal.com>
Co-authored-by: Yueqian Lin <70319226+linyueqian@users.noreply.github.com>
86MaxCao pushed a commit to 86MaxCao/vllm-omni that referenced this pull request Jun 4, 2026
…icted state (vllm-project#3688)

Signed-off-by: JuanPZuluaga <juanz9312@gmal.com>
Co-authored-by: JuanPZuluaga <juanz9312@gmal.com>
Co-authored-by: Yueqian Lin <70319226+linyueqian@users.noreply.github.com>