
[Bugfix] bound transfer save queue and tighten orchestrator idle backoff#3690

Open
JuanPZuluaga wants to merge 1 commit into
vllm-project:mainfrom
JuanPZuluaga:qwen3tts-c128-backpressure

Conversation

@JuanPZuluaga
Contributor

@JuanPZuluaga JuanPZuluaga commented May 18, 2026


Purpose

Under high concurrency, the OmniTransferAdapterBase save path is unbounded: OmniChunkTransferAdapter.save_async appends to _pending_save_reqs faster than save_loop can drain it, so the deque (and the per-task GPU references it holds) grows as concurrency increases. Paired fix: the orchestrator's idle poll path sleeps 1 ms between empty passes, which inflates TTFA at every batch boundary even when work is queued elsewhere.

This PR caps in-flight saves with a Semaphore(512) (acquired in save_async, released in save_loop's finally) and drops the idle sleep to 0.1 ms. It is a surgical change: only the transfer adapter and the orchestrator idle path are touched.
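For readers who want the shape of the cap without opening the diff, here is a minimal sketch of the acquire-in-producer / release-in-finally pattern described above. It is not the code in this PR: `BoundedSaveQueue`, `drain_one`, and `send` are hypothetical names used only for illustration, and the real adapter has more state than this.

```python
import threading
from collections import deque


class BoundedSaveQueue:
    """Toy stand-in for a save path (hypothetical class, not vllm-omni code)."""

    def __init__(self, max_in_flight: int = 512) -> None:
        self._pending = deque()
        self._lock = threading.Lock()
        # One slot per task that is queued or currently being sent.
        self._slots = threading.Semaphore(max_in_flight)

    def save_async(self, task) -> None:
        # Producer side: blocks once max_in_flight slots are taken.
        self._slots.acquire()
        with self._lock:
            self._pending.append(task)

    def drain_one(self, send) -> bool:
        # Consumer side: one iteration of a save loop.
        with self._lock:
            if not self._pending:
                return False
            task = self._pending.popleft()
        try:
            send(task)
        finally:
            # Release even if send() raises, so a failed send cannot leak a slot.
            self._slots.release()
        return True
```

Note that `save_async` in this sketch blocks the caller when the queue is full, which is the behaviour the rest of the thread discusses.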

Test Plan

Test Result

I used bench_tts_continuity.py from @linyueqian for the underrun rate.

Comparison with qwen3_tts_high_concurrency.yaml

I used https://github.com/vllm-project/vllm-omni/blob/main/vllm_omni/deploy/qwen3_tts_high_concurrency.yaml but modified the YAML so that both stages run on the same GPU (GPU 0), as I don't have access to two GPUs.

| c | main UR_rate | PR UR_rate | main TTFT_p99 (ms) | PR TTFT_p99 (ms) | main E2EL_p99 (ms) | PR E2EL_p99 (ms) | main RTF_p99 | PR RTF_p99 |
|---|---|---|---|---|---|---|---|---|
| 1 | 0.000 | 0.000 | 51 | 52 | 995 | 977 | 0.166 | 0.165 |
| 4 | 0.000 | 0.000 | 109 | 111 | 1123 | 1129 | 0.203 | 0.196 |
| 8 | 0.000 | 0.000 | 141 | 114 (-19%) | 1522 | 1458 (-4%) | 0.243 | 0.250 |
| 16 | 0.000 | 0.000 | 166 | 190 (noise) | 2230 | 1953 (-12%) | 0.319 | 0.329 |
| 32 | 0.000 | 0.000 | 358 | 441 (noise) | 3267 | 3412 | 0.553 | 0.538 |
| 48 | 0.000 | 0.000 | 538 | 488 (-9%) | 4824 | 4410 | 0.799 | 0.933 |
| 64 | 0.109 | 0.039 (-64%) | 452 | 1081 | 5514 | 6050 | 1.193 | 1.093 |
| 96 | 0.979 | 0.995 (noise) | 1901 | 674 (-65%) | 8371 | 8369 | 1.860 | 1.881 |
| 128 | 0.762 | 0.781 (noise) | 8919 | 8914 | 16525 | 17105 | 3.391 | 3.321 (-2%) |

also:

  • c=64 (SLO boundary, headline win): underrun_rate -64% (0.109 → 0.039), UR_p99 -59% (461 → 188 ms). This is exactly where the unbounded save queue used to backlog.
  • c=96: TTFT_p99 -65% (1901 → 674 ms).
  • c=8: TTFT_p99 -19% (141 → 114 ms), E2EL_p99 -4%.
  • c=16: E2EL_p99 -12% (2230 → 1953 ms).
  • c=48: TTFT_p99 -9% (538 → 488 ms).
  • c=128: RTF_p99 -2% (3.391 → 3.321).
  • c=1 / c=4: unchanged (already optimal).
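For anyone re-deriving the percentages in the bullets above, this is the relative-change arithmetic, as a small sketch. The helper name `pct_change` is made up for illustration; the input values are copied from the c=64, c=96 and c=8 bullets.

```python
def pct_change(baseline: float, candidate: float) -> float:
    """Relative change of candidate vs. baseline, in percent (negative = lower)."""
    if baseline == 0:
        raise ValueError("baseline must be non-zero")
    return (candidate - baseline) / baseline * 100.0


print(round(pct_change(0.109, 0.039)))  # underrun rate at c=64 -> -64
print(round(pct_change(1901, 674)))     # TTFT_p99 at c=96 -> -65
print(round(pct_change(141, 114)))      # TTFT_p99 at c=8 -> -19
```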


@hsliuustc0106
Collaborator

please also provide the status before this commit so that we can see the improvement

@JuanPZuluaga
Contributor Author

JuanPZuluaga commented May 18, 2026

@hsliuustc0106 thanks for the comments, I updated the PR description; the benefits come at high concurrency.

Tagging @linyueqian, as this should be aligned with #3535

(btw, thanks @Sy0307 for the amazing improvement in performance with: #3662)

@linyueqian
Collaborator

Re-ran an A/B on a 2-GPU H20 split with the harness from #3618 (Qwen3-TTS-Base voice_clone, qwen3_tts_high_concurrency.yaml, 128 prompts at c=64 and 256 at c=128). Only Juan's 4-file commit differed between sides.

| metric | c=64 Δ | c=128 Δ |
|---|---|---|
| p99 TTFT | −4.5% | +7.9% |
| p99 E2EL | −19.9% | +4.4% |
| mean RTF | −17.8% | +4.7% |
| p99 RTF | −48.0% | +9.2% |
| p99 audio_underrun | −18.7% | −22.6% |
| total token tput | +17.7% | −6.6% |

Two things to flag:

  1. c=64 is a clear win across the board (not in your original table since you only had 1-GPU access). p99 RTF nearly halved, throughput up ~18%. Worth adding to the PR description.
  2. At c=128 only the p99 underrun improvement reproduces (−22.6%, matches your underrun-rate claim within noise). The TTFT/E2EL/RTF wins from your table don't show up on the 2-GPU split, and aggregate throughput actually drops ~7%.

My guess on the c=128 throughput dip: the global Semaphore(512) occasionally blocks the Stage 0 producer thread when the consumer is briefly slow, and on the 2-GPU split that costs Stage 0 useful work per step. Doesn't show up on single-GPU because Stage 0 and Stage 1 already contend for the same SM time. The right shape of fix is what RFC #3535 WS-4 envisions (per-stream credits so Stage 0 skips backpressured streams in the LM step rather than block the whole producer), but a smaller intermediate step is to make the cap a DeployConfig knob and let high-concurrency 2-GPU deploys set it higher than 512.
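To make the per-stream credit idea concrete, here is a rough single-threaded sketch of a producer that skips backpressured streams instead of blocking. This is not the RFC's design or anything in vllm-omni: `StreamCredits`, `schedulable`, and `credits_per_stream` are hypothetical names, and a real version would need locking and integration with the scheduler.

```python
from collections import defaultdict


class StreamCredits:
    """Hypothetical per-stream credit counter (illustration only, not thread-safe)."""

    def __init__(self, credits_per_stream: int) -> None:
        self._limit = credits_per_stream
        self._in_flight = defaultdict(int)

    def try_acquire(self, stream_id: str) -> bool:
        # Non-blocking: a stream without credits is skipped, the producer is not stalled.
        if self._in_flight[stream_id] >= self._limit:
            return False
        self._in_flight[stream_id] += 1
        return True

    def release(self, stream_id: str) -> None:
        # Called when the consumer has finished with one chunk of this stream.
        if self._in_flight[stream_id] > 0:
            self._in_flight[stream_id] -= 1


def schedulable(streams: list[str], credits: StreamCredits) -> list[str]:
    """Return the streams that may produce a chunk in this step."""
    return [s for s in streams if credits.try_acquire(s)]
```

The difference from a global semaphore is that one slow stream only exhausts its own credits, so other streams keep getting scheduled.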

Net take: queue cap is the right idea and the underrun improvement is real and reproduces. Would be good to surface the c=128 throughput tradeoff in the PR description, and if you want the merge to be unconditional improvement, expose the cap as a config knob so the high-c regime can opt out.

Happy to share the raw JSONs if useful.

@linyueqian
Collaborator

On making 512 adaptive: I don't think it's worth doing on a single global counter since one slow stream would still drag the cap for everyone, so I'd suggest just exposing it as a DeployConfig knob here and revisiting it as per-stream credits in a separate PR.
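A minimal sketch of what such a knob could look like, assuming the cap is a plain integer with `None` meaning "no cap". `TransferSaveConfig`, `max_pending_saves`, and `make_save_semaphore` are hypothetical names; I have not checked how DeployConfig is actually structured, so treat this as the shape of the idea only.

```python
import threading
from dataclasses import dataclass
from typing import Optional


@dataclass
class TransferSaveConfig:
    """Hypothetical config fragment; the field name is an assumption."""

    # None disables the cap; a positive integer bounds in-flight saves.
    max_pending_saves: Optional[int] = 512

    def __post_init__(self) -> None:
        if self.max_pending_saves is not None and self.max_pending_saves <= 0:
            raise ValueError("max_pending_saves must be positive or None")


def make_save_semaphore(cfg: TransferSaveConfig) -> Optional[threading.Semaphore]:
    """Build the limiter, or return None when the deployment opts out."""
    if cfg.max_pending_saves is None:
        return None
    return threading.Semaphore(cfg.max_pending_saves)
```

Callers would then skip the acquire/release pair entirely when the returned value is `None`.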

@linyueqian
Collaborator

Re-ran on a different box (H200-hsliu, single-GPU slim deploy, GPU 3) to see if the c=128 picture is hardware-sensitive. Same bench-base vs bench-pr3690 setup, only Juan's 4-file commit differs.

| metric | H20 c=64 | H200 c=64 | H20 c=128 | H200 c=128 |
|---|---|---|---|---|
| p99 TTFT | −4.5% | −34.0% | +7.9% | −4.0% |
| p99 E2EL | −19.9% | −11.0% | +4.4% | +3.4% |
| p99 RTF | −48.0% | −21.7% | +9.2% | +1.3% |
| p99 audio_underrun | −18.7% | −10.6% | −22.6% | +7.7% |
| token throughput | +17.7% | +9.2% | −6.6% | −2.1% |

c=64 is a robust win on both boxes. c=128 is the one that doesn't generalize: on H200 the single positive signal from H20 (p99 underrun −22.6%) flips to +7.7%, with throughput, E2EL, and mean underrun also slightly worse. That's consistent with the slim H200 deploy already being in a vocoder-saturated regime where there's nothing useful for the global queue cap to do.

Reinforces what I think the right move is: keep the PR small but expose the cap as a DeployConfig knob so high-c users on saturated profiles can opt out of the slight regression.

Raw JSONs available if you want them.

@Sy0307
Collaborator

Sy0307 commented May 20, 2026

I don't object to bounding the save queue, but this backpressure point is too coarse: it turns downstream consumption pressure into a global block on the scheduler / producer path.

save_async() now synchronously blocks on:

self._save_semaphore.acquire()

This is not a wait inside a background worker; it is a wait on the producer path. If the real bottleneck is the downstream stage, vocoder, or connector, blocking the upstream producer does not make the downstream side faster. It only moves the waiting from the _pending_save_reqs queue into the scheduler thread. That may reduce queue growth, but it can also break pipeline overlap and prevent Stage 0 from doing useful work that could have been prepared ahead of time.

We previously tested a similar save/transfer-side throttling mechanism on Fish S2 / FastSR. It did not show a clear win, and in some configurations it slightly regressed, around ~5%. My read is that this is the same failure mode: when the downstream side is already close to saturation, global backpressure does not reduce actual compute; it only shrinks the producer-consumer parallelism window. It can also introduce head-of-line blocking: one slow stream or slow consumer can make the global semaphore throttle all following streams, instead of throttling only itself.

So I don't think this cap should be hard-coded to 512. This value is counted by task, not by payload bytes or by stream, so its meaning changes across models, hardware, and stage outputs. At minimum, it should be exposed as a DeployConfig knob so each deployment can raise it, lower it, or disable it. Longer term, the right shape is probably per-stream credits rather than a global producer block.

Separately, acquire() is currently an unbounded wait: no timeout, and no stop_event check. Under normal operation, as long as save_loop keeps draining, this is not a deadlock. But if the save loop stops, connector put() hangs, or teardown/stop happens while the producer is already blocked, this wait has no explicit exit semantics. Please make it a timeout / stop-aware acquire, and add a concurrency test covering a full semaphore, release after send exceptions, and the guarantee that the producer does not block forever when the consumer no longer drains.
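Something along these lines is what a timeout / stop-aware acquire could look like, as a sketch only. `acquire_or_stop`, `poll_s`, and `timeout_s` are hypothetical names; the only library calls are `threading.Semaphore.acquire(timeout=...)`, `threading.Event.is_set()`, and `time.monotonic()`.

```python
import threading
import time
from typing import Optional


def acquire_or_stop(
    sem: threading.Semaphore,
    stop_event: threading.Event,
    poll_s: float = 0.05,
    timeout_s: Optional[float] = None,
) -> bool:
    """Acquire sem, but give up when stop_event is set or timeout_s elapses."""
    deadline = None if timeout_s is None else time.monotonic() + timeout_s
    while not stop_event.is_set():
        wait = poll_s
        if deadline is not None:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return False  # timed out without getting a slot
            wait = min(poll_s, remaining)
        # Short bounded wait, so the stop flag is re-checked regularly.
        if sem.acquire(timeout=wait):
            return True
    return False  # stop requested while waiting
```

The caller has to handle a `False` return explicitly (drop, retry, or surface an error); that policy is outside this sketch.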

@linyueqian linyueqian added the ready (label to trigger buildkite CI) and tts-test (label to trigger buildkite tts models test in nightly CI) labels May 26, 2026
@linyueqian
Collaborator

@JuanPZuluaga any thoughts on @Sy0307's 5-20 comment? Two things line up with my c=128 H20/H200 numbers and I think both should land before merge:

  • _save_semaphore.acquire() blocks the producer with no timeout and no stop_event check. If save_loop exits early or teardown fires mid-block, the scheduler thread has no way out.
  • 512 is right for at most one hardware/yaml combo. Same code on a 2-GPU H20 split vs a single-GPU H200 slim deploy hits very different consumer pressure, so the cap should be configurable via DeployConfig rather than hard-coded.

What's your read?

@JuanPZuluaga
Contributor Author

thanks for the comments! @Sy0307, I'm working on making it configurable. @linyueqian

@JuanPZuluaga JuanPZuluaga force-pushed the qwen3tts-c128-backpressure branch from 47f259a to 0245ea7 Compare May 26, 2026 14:51
@JuanPZuluaga
Contributor Author

I computed the numbers again @linyueqian

| c | main UR_rate | PR UR_rate | main TTFT_p99 (ms) | PR TTFT_p99 (ms) | main E2EL_p99 (ms) | PR E2EL_p99 (ms) | main RTF_p99 | PR RTF_p99 |
|---|---|---|---|---|---|---|---|---|
| 1 | 0.000 | 0.000 | 54 | 55 | 1063 | 1064 | 0.167 | 0.167 |
| 4 | 0.000 | 0.000 | 114 | 118 | 1421 | 1474 | 0.196 | 0.202 |
| 8 | 0.000 | 0.000 | 132 | 115 (-13%) | 1482 | 1511 | 0.240 | 0.255 |
| 16 | 0.000 | 0.000 | 164 | 196 (noise) | 2067 | 2100 | 0.319 | 0.322 |
| 32 | 0.000 | 0.000 | 581 | 372 (-36%) | 3908 | 3547 (-9%) | 0.541 | 0.612 |
| 48 | 0.000 | 0.000 | 567 | 698 (noise) | 4859 | 4742 | 0.915 | 0.826 (-10%) |
| 64 | 0.078 | 0.125 | 512 | 509 | 5841 | 5573 (-5%) | 1.127 | 1.165 |
| 96 | 0.974 | 0.984 | 1739 | 2322 (noise) | 8469 | 8932 | 1.848 | 1.879 |
| 128 | 0.781 | 0.777 | 8689 | 8620 | 16630 | 16621 | 3.421 | 3.170 (-7%) |

@JuanPZuluaga JuanPZuluaga force-pushed the qwen3tts-c128-backpressure branch 2 times, most recently from c4e7529 to dd38710 Compare May 28, 2026 12:35
@JuanPZuluaga
Contributor Author

@linyueqian I updated the PR body with the latest evaluation.

Collaborator

@linyueqian linyueqian left a comment


lgtm @Sy0307 ptal as well

@linyueqian linyueqian added this to the v0.22.0 milestone May 28, 2026
@linyueqian linyueqian removed the tts-test (label to trigger buildkite tts models test in nightly CI) label May 28, 2026
@linyueqian
Collaborator

@Sy0307 here's an 8× H20 141GB sweep against your c=64 req=128 finding. Short answer: the regression doesn't reproduce as a hard hit on either GPU config, and the PR's claimed c=64 wins also don't appear. Net looks approximately neutral on this hardware.

Method: vllm bench serve --omni --backend openai-audio-speech --dataset-name seed-tts (en subset), Qwen3-TTS-12Hz-1.7B-Base voice_clone, c=1..128, PR head dd38710d vs origin/main 8935439. Both GPU configs tested.

Dual-card (qwen3_tts_high_concurrency.yaml as shipped, Stage 0 → GPU 0, Stage 1 → GPU 1)

| c | E2EL_p99 Δ | TTFP_p99 Δ | RTF_p99 Δ |
|---|---|---|---|
| 8 | +8.8% 🔴 | -2.1% 🟢 | +1.9% 🟡 |
| 16 | +1.5% 🟡 | +12.1% 🔴 | +2.9% 🟡 |
| 32 | +0.7% | -1.7% 🟢 | -14.5% 🟢 |
| 48 | +3.5% 🟡 | +2.2% 🟡 | +0.6% |
| 64 | +1.1% 🟡 | +2.3% 🟡 | +4.0% 🟡 |
| 96 | -0.9% | -4.2% 🟢 | +0.4% |
| 128 | +3.7% 🟡 | +5.5% 🔴 | -2.5% 🟢 |

Single-card (both stages co-located on one GPU, matching the PR description)

| c | E2EL_p99 Δ | TTFP_p99 Δ | RTF_p99 Δ |
|---|---|---|---|
| 8 | -5.4% 🟢 | -4.9% 🟢 | -3.2% 🟢 |
| 16 | -4.2% 🟢 | -0.9% | +7.6% 🔴 |
| 32 | -0.4% | -7.2% 🟢 | -4.3% 🟢 |
| 48 | +1.1% 🟡 | -3.9% 🟢 | -4.5% 🟢 |
| 64 | +5.8% 🔴 | -0.8% | -1.7% 🟢 |
| 96 | +1.9% 🟡 | +1.6% 🟡 | +1.5% 🟡 |
| 128 | +0.8% | -1.2% 🟢 | -4.3% 🟢 |

Δ legend: 🟢 ≤−1%, neutral within ±1%, 🟡 +1..+5%, 🔴 ≥+5%. Lower is better for all three.
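In case it helps when re-running cells, the legend above can be written as a small bucketing function. This is a sketch: `classify_delta` is a made-up helper, and the handling of the exact boundaries (a delta of exactly +1.0% is treated as 🟡 here) is an assumption, since the legend does not pin them down.

```python
def classify_delta(delta_pct: float) -> str:
    """Map a percent delta to the legend's bucket (lower is better)."""
    if delta_pct <= -1.0:
        return "🟢"
    if delta_pct < 1.0:
        return "neutral"
    if delta_pct < 5.0:
        # Assumption: exactly +1.0 falls here rather than in "neutral".
        return "🟡"
    return "🔴"
```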

c=64 dual-card lands at +1.1% E2EL_p99 (within run-to-run noise). c=64 single-card sits at +5.8% E2EL_p99 with RTF and TTFP both flat, so if that's a real signal it's in tail total latency, not streaming keep-up. c=64 is also the most statistically stable cell in each table (n=256), so it's the best comparison point.

Two caveats worth flagging:

  1. UR_rate not measured. vllm bench serve doesn't emit underrun rate, so the PR description's -64% UR_rate at c=64 headline is not verified here. The continuity bench that does compute it can't talk to Qwen3-TTS-Base without a per-request ref_audio (instant HTTP 400 from the voice_clone validator), so verifying that claim needs the script extended to attach a fixed reference audio.
  2. Low-c cells used 64–256 prompts vs the PR description's 512, so treat the big greens/reds at c=4 and c=8 as noisy.

Net: no dual-card blocker that I can see, but the c=64 single-card +5.8% E2EL is worth a second look. Happy to re-run any cell at num_prompts=512 if it helps tighten percentiles.

@linyueqian linyueqian removed this from the v0.22.0 milestone May 28, 2026
Collaborator

@linyueqian linyueqian left a comment


Needs further investigation.

@JuanPZuluaga JuanPZuluaga force-pushed the qwen3tts-c128-backpressure branch from dd38710 to 2aa7f23 Compare May 29, 2026 04:49
Signed-off-by: JuanPZuluaga <juanz9312@gmal.com>
@JuanPZuluaga JuanPZuluaga force-pushed the qwen3tts-c128-backpressure branch from 2aa7f23 to 1c1f13c Compare June 1, 2026 06:14
@JuanPZuluaga
Contributor Author

@linyueqian thanks for the detailed test. I'll keep working on this and see how we can improve things. Thanks!
