Skip to content

[Perf][Fish Speech] Enable CUDA Graph capture for Fast AR code predictor#2520

Merged
linyueqian merged 6 commits into
vllm-project:mainfrom
Sy0307:fish-speech-fast-ar-cudagraph
Apr 13, 2026
Merged

[Perf][Fish Speech] Enable CUDA Graph capture for Fast AR code predictor#2520
linyueqian merged 6 commits into
vllm-project:mainfrom
Sy0307:fish-speech-fast-ar-cudagraph

Conversation

@Sy0307
Copy link
Copy Markdown
Contributor

@Sy0307 Sy0307 commented Apr 6, 2026

Purpose

Enable CUDAGraphWrapper for Fish Speech S2 Pro's Fast AR (residual codebook predictor), reducing inference latency by 52.7% on H20.

Fish Speech uses a Dual AR architecture: Slow AR (Qwen3-4B, piecewise CUDA Graph) + Fast AR (4 layers, 9-step AR loop). The Fast AR ran entirely in eager mode — profiling showed it accounts for 63% of steady-state decode time (~13ms/step x 73 steps).

Key changes (scoped to gpu_ar_model_runner + fish_speech files, gpu_model_runner.py untouched):

  • Opt-in talker_mtp_graph_safe attribute for TTS models to enable CUDAGraphWrapper
  • _capture_talker_mtp_graphs() for explicit warmup+capture after capture_model() (_dummy_run has no decode requests, so talker_mtp misses the normal capture window)
  • Replace semantic_mask.any() (host-device sync) with torch.where
  • Disable torch.compile inside Fast AR when outer CUDA Graph is active (compile guards don't re-execute during graph replay)
  • try-except fallback to eager on capture failure

Only affects models with talker_mtp_graph_safe = True. Qwen3-Omni/Qwen3-TTS unaffected.

Test Plan

  • Fish Speech S2 Pro offline inference — audio generated successfully
  • A/B benchmark: 5+5 runs, all with audio output verified
  • UTMOS audio quality: no regression (4.516 -> 4.493)
  • Qwen3-Omni offline inference — verify no regression
  • Qwen3-TTS offline inference — verify no regression

Test Result

Benchmark (H20, single request, 5 runs, both stages completed):

Baseline + CUDAGraphWrapper Improvement
Mean 3375.6 ms 1598.3 ms -52.7%
Stdev 60.6 ms 2.3 ms

Audio quality (UTMOS):

Baseline Optimized
Mean 4.516 4.493
Verdict PASS PASS

Related PR: #2515

cc @linyueqian @zwhzzz0821

@Sy0307 Sy0307 requested a review from hsliuustc0106 as a code owner April 6, 2026 13:25
@chatgpt-codex-connector
Copy link
Copy Markdown

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.
Credits must be used to enable repository wide code reviews.

@Sy0307 Sy0307 force-pushed the fish-speech-fast-ar-cudagraph branch from 4546791 to 33900d1 Compare April 6, 2026 13:27
@linyueqian
Copy link
Copy Markdown
Collaborator

[Bug] Startup crash on Fish Speech S2 Pro - AttributeError: has_talker_mtp

I tried running this PR on H20 (8x H20-3e, vllm 0.19.0, torch 2.10.0+cu128) and hit an immediate crash in the worker subprocess:

(Worker pid=2404860) File "vllm_omni/worker/gpu_ar_model_runner.py", line 86, in __init__
(Worker pid=2404860)     self.has_talker_mtp
(Worker pid=2404860) AttributeError: 'GPUARModelRunner' object has no attribute 'has_talker_mtp'

Root cause: has_talker_mtp is set in OmniGPUModelRunner.load_model() (gpu_model_runner.py:96), which runs later in the lifecycle. But the new CUDAGraphWrapper wrapping code in GPUARModelRunner.__init__ accesses it right after super().__init__() - before load_model() is ever called.

Suggested fix: move the wrapping logic from __init__ to load_model() (after super().load_model()).


Two other issues from code review:

  1. Fallback path loses both optimizations: when graph capture fails in _capture_talker_mtp_graphs, the except block restores eager talker_mtp but _disable_compile_for_graph is still True on fast_ar. The fallback ends up running without CUDA Graph AND without torch.compile. Should reset fast_ar._disable_compile_for_graph = False and fast_ar._compile_attempted = False on failure.

  2. Manual inference_mode context management: _capture_talker_mtp_graphs manually calls inference_mode.__enter__() / __exit__() instead of a with statement. Consider using with torch.inference_mode(): for safety.

The semantic_mask.any() to torch.where refactor looks correct.

Could not verify the 52.7% speedup claim due to the startup crash. Happy to re-test once the __init__ vs load_model ordering is fixed.

@Sy0307 Sy0307 force-pushed the fish-speech-fast-ar-cudagraph branch 2 times, most recently from 761abf1 to d061c52 Compare April 6, 2026 19:57
@linyueqian
Copy link
Copy Markdown
Collaborator

Verified on H20 (single GPU, single request, "The quick brown fox jumps over the lazy dog."):

Baseline (main) + CUDAGraphWrapper Improvement
Mean (runs 2-5) 2088.6 ms 1109.5 ms -46.9%
Stdev ~30 ms ~6 ms

Audio quality sounds similar. Speedup is real and significant.

One bug: load_model() was inserted after the existing @torch.inference_mode() decorator at L141, which steals it from execute_model (L229). Fix: add @torch.inference_mode() back before execute_model and remove it from load_model.

Also getattr(self.model, "talker_mtp") in load_model and the exception handler can just be self.model.talker_mtp since has_talker_mtp is already checked.

Comment thread vllm_omni/worker/gpu_ar_model_runner.py
Enable CUDAGraphWrapper for Fish Speech S2 Pro's Fast AR via opt-in
talker_mtp_graph_safe attribute.

- Wrap talker_mtp in CUDAGraphWrapper in GPUARModelRunner.load_model
  (not __init__, since has_talker_mtp is set during load_model)
- Add _capture_talker_mtp_graphs() for explicit warmup+capture after
  capture_model() completes; capture largest bsz first to pre-allocate
  Fast AR internal buffers at max size (avoids buffer reallocation
  invalidating previously captured graphs)
- Replace semantic_mask.any() with torch.where (graph-safe)
- Disable torch.compile inside Fast AR when outer graph is active
- Fallback to eager on capture failure with compile state reset

Only affects models with talker_mtp_graph_safe = True.
gpu_model_runner.py is untouched.

Benchmark (H20, Fish Speech S2 Pro, vllm 0.19.0):
  Baseline: 2048ms -> Optimized: 955ms (-53.4%)

Signed-off-by: Sy03 <1370724210@qq.com>
@Sy0307 Sy0307 force-pushed the fish-speech-fast-ar-cudagraph branch from d061c52 to af8568e Compare April 7, 2026 06:55
@gcanlin
Copy link
Copy Markdown
Collaborator

gcanlin commented Apr 8, 2026

cc @ZeldaHuang

@linyueqian
Copy link
Copy Markdown
Collaborator

Tested this on H20-3e (143GB) with the same profiling setup used for the sglang-omni comparison in #2515. Nice improvement over baseline!

One finding: combining your model runner changes with a torch.compile(dynamic=False, epilogue_fusion=False) + full-buffer forward in FishSpeechFastAR gives an additional ~15% per-step speedup on top of the outer CUDAGraphWrapper:

Metric Baseline PR #2520 #2520 + torch.compile + full-buffer
Per-step Fast AR 73 ms 56 ms 48 ms
E2E 1800 ms 1270 ms 1180 ms
RTF 0.48 0.38 0.33

The key difference: the current PR disables torch.compile inside Fast AR when the outer graph is active (_disable_compile_for_graph = True), which means the 4-layer transformer runs eager inside the graph. If we instead keep torch.compile(dynamic=False, epilogue_fusion=False) active and change the forward to always process the full [padded_bsz, max_seq, hidden] buffer (matching the Qwen3 TTS CodePredictor pattern), the compiled kernels fuse better and we get ~48ms/step vs 56ms.

Happy to share the fish_speech_fast_ar.py diff if you want to integrate it — the model runner changes in this PR are the right approach regardless.

@linyueqian
Copy link
Copy Markdown
Collaborator

Suggestion for an additional ~15% per-step speedup:

Instead of disabling torch.compile when the outer graph is active, keep it enabled with fixed-shape forward. The key changes in fish_speech_fast_ar.py:

  1. Don't disable compile -- remove the _disable_compile_for_graph path. Instead use:
self._compiled_model_fwd = torch.compile(
    self.model.forward,
    dynamic=False,
    options={"epilogue_fusion": False},
)
  1. Full-buffer forward -- always forward [padded_bsz, max_seq, hidden] instead of [:seq_len]:
# In _ensure_buffers: allocate for max_cudagraph_capture_size
max_bsz = max(
    self._vllm_config.scheduler_config.max_num_seqs,
    self._vllm_config.compilation_config.max_cudagraph_capture_size,
    1,
)
self._embed_buf = torch.zeros(max_bsz, max_seq, self._fast_dim, ...)

# In forward: pad batch, zero buffer, forward full shape
padded_bsz = self._padded_bsz(bsz)
embed_buf[:padded_bsz].zero_()
# ... fill positions 0 and 1 ...

# Each step: fixed-shape forward, then index the right position
for step in range(1, num_cb):
    hidden_out = model_fwd(embed_buf[:padded_bsz, :max_seq, :], pos_ids)
    logits = self.fast_output(self.fast_norm(hidden_out[:bsz, step, :]))
    # ... sampling ...

This matches the Qwen3 TTS CodePredictor pattern exactly. The fixed shape lets torch.compile fuse kernels optimally inside the CUDA graph, giving 48ms/step vs 56ms/step on H20.

The full diff is on branch linyueqian/vllm-omni:exp/fish-speech-fast-ar-cudagraph if you want to cherry-pick the fish_speech_fast_ar.py changes.

Comment thread vllm_omni/worker/gpu_ar_model_runner.py Outdated
Comment thread vllm_omni/worker/gpu_ar_model_runner.py
@Sy0307
Copy link
Copy Markdown
Contributor Author

Sy0307 commented Apr 8, 2026

After merging @linyueqian 's suggestion, I start a new benchmark:

┌─────────────┬────────────┬─────────────────┬─────────────────┬─────────────┐
│ Concurrency │   Metric   │    Baseline     │    Optimized    │ Improvement │
├─────────────┼────────────┼─────────────────┼─────────────────┼─────────────┤
│     c=4     │ E2E        │ 3864 ms         │ 2378 ms         │ -38%        │
│             │ Latency    │                 │                 │             │
├─────────────┼────────────┼─────────────────┼─────────────────┼─────────────┤
│     c=4     │ Throughput │ 11.9            │ 17.6            │ +48%        │
│             │            │ audio_sec/s     │ audio_sec/s     │             │
├─────────────┼────────────┼─────────────────┼─────────────────┼─────────────┤
│    c=10     │ E2E        │ 7166 ms         │ 4964 ms         │ -31%        │
│             │ Latency    │                 │                 │             │
├─────────────┼────────────┼─────────────────┼─────────────────┼─────────────┤
│    c=10     │ Throughput │ 5.86            │ 9.41            │ +61%        │
│             │            │ audio_sec/s     │ audio_sec/s     │             │
└─────────────┴────────────┴─────────────────┴─────────────────┴─────────────┘

Sy0307 added 2 commits April 9, 2026 00:05
…Graph

- Extend CUDAGraphWrapper wrap condition with talker_mtp_graph_safe opt-in
- Enable torch.compile(dynamic=True, epilogue_fusion=False) inside graph
- Use compiled forward for all batch sizes in graph mode
- Replace semantic_mask.any() with torch.where for graph compatibility
- Add clamp(max=codebook_size-1) for codebook index safety
- Clean fallback state reset (_compiled_model_fwd=None)

Signed-off-by: Sy03 <1370724210@qq.com>
Comment thread vllm_omni/worker/gpu_ar_model_runner.py Outdated
Sy0307 added 3 commits April 9, 2026 19:57
Signed-off-by: Sy03 <1370724210@qq.com>
…t fallback

Signed-off-by: Sy03 <1370724210@qq.com>
Signed-off-by: Sy03 <1370724210@qq.com>
@linyueqian linyueqian added the ready label to trigger buildkite CI label Apr 13, 2026
Copy link
Copy Markdown
Collaborator

@linyueqian linyueqian left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@linyueqian linyueqian enabled auto-merge (squash) April 13, 2026 04:20
@linyueqian linyueqian merged commit cb4d13a into vllm-project:main Apr 13, 2026
8 checks passed
daixinning pushed a commit to daixinning/vllm-omni that referenced this pull request Apr 13, 2026
lengrongfu pushed a commit to lengrongfu/vllm-omni that referenced this pull request May 1, 2026
clodaghwalsh17 pushed a commit to clodaghwalsh17/nm-vllm-omni-ent that referenced this pull request May 12, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready label to trigger buildkite CI

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants