[Perf][Fish Speech] Enable CUDA Graph capture for Fast AR code predictor#2520
|
[Bug] Startup crash on Fish Speech S2 Pro
I tried running this PR on H20 (8x H20-3e, vllm 0.19.0, torch 2.10.0+cu128) and hit an immediate crash in the worker subprocess:
Root cause:
Suggested fix: move the wrapping logic from
Two other issues from code review:
The
Could not verify the 52.7% speedup claim due to the startup crash. Happy to re-test once the |
|
Verified on H20 (single GPU, single request, "The quick brown fox jumps over the lazy dog."):
Audio quality sounds similar. Speedup is real and significant. One bug: Also |
Enable CUDAGraphWrapper for Fish Speech S2 Pro's Fast AR via opt-in talker_mtp_graph_safe attribute.
- Wrap talker_mtp in CUDAGraphWrapper in GPUARModelRunner.load_model (not __init__, since has_talker_mtp is set during load_model)
- Add _capture_talker_mtp_graphs() for explicit warmup+capture after capture_model() completes; capture largest bsz first to pre-allocate Fast AR internal buffers at max size (avoids buffer reallocation invalidating previously captured graphs)
- Replace semantic_mask.any() with torch.where (graph-safe)
- Disable torch.compile inside Fast AR when outer graph is active
- Fallback to eager on capture failure with compile state reset
Only affects models with talker_mtp_graph_safe = True. gpu_model_runner.py is untouched.
Benchmark (H20, Fish Speech S2 Pro, vllm 0.19.0): Baseline: 2048ms -> Optimized: 955ms (-53.4%)
Signed-off-by: Sy03 <1370724210@qq.com>
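The "capture largest bsz first" point in the commit message above can be illustrated with a toy model. This is a minimal sketch in plain Python, not the PR's code: `GrowOnlyBuffer` and `capture_all` are hypothetical names, and "generation" stands in for the identity of the underlying allocation that a captured graph would have recorded.

```python
class GrowOnlyBuffer:
    """Toy stand-in for Fast AR's internal buffers (hypothetical, not the PR's class)."""

    def __init__(self):
        self.capacity = 0
        self.generation = 0  # bumped every time the storage is reallocated

    def ensure(self, bsz):
        # Reallocate only when the requested batch size exceeds current capacity.
        if bsz > self.capacity:
            self.capacity = bsz
            self.generation += 1
        return self.generation


def capture_all(sizes):
    """Simulate capturing one graph per batch size, in the given order.

    Returns the batch sizes whose recorded buffer is still the live one.
    """
    buf = GrowOnlyBuffer()
    captured = {}
    for bsz in sizes:
        captured[bsz] = buf.ensure(bsz)
    # A capture stays valid only if the allocation it recorded still exists.
    return sorted(b for b, gen in captured.items() if gen == buf.generation)


capture_sizes = [1, 2, 4, 8]  # illustrative sizes, not taken from the PR
valid_ascending = capture_all(capture_sizes)
valid_descending = capture_all(sorted(capture_sizes, reverse=True))
print(valid_ascending)
print(valid_descending)
```

In this toy model, ascending order reallocates on every capture, so only the final graph still points at live storage. Descending order allocates once at the maximum size and every later capture reuses it.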
|
cc @ZeldaHuang |
|
Tested this on H20-3e (143GB) with the same profiling setup used for the sglang-omni comparison in #2515. Nice improvement over baseline! One finding: combining your model runner changes with a
The key difference: the current PR disables Happy to share the |
|
Suggestion for an additional ~15% per-step speedup: Instead of disabling
self._compiled_model_fwd = torch.compile(
    self.model.forward,
    dynamic=False,
    options={"epilogue_fusion": False},
)

# In _ensure_buffers: allocate for max_cudagraph_capture_size
max_bsz = max(
    self._vllm_config.scheduler_config.max_num_seqs,
    self._vllm_config.compilation_config.max_cudagraph_capture_size,
    1,
)
self._embed_buf = torch.zeros(max_bsz, max_seq, self._fast_dim, ...)

# In forward: pad batch, zero buffer, forward full shape
padded_bsz = self._padded_bsz(bsz)
embed_buf[:padded_bsz].zero_()
# ... fill positions 0 and 1 ...

# Each step: fixed-shape forward, then index the right position
for step in range(1, num_cb):
    hidden_out = model_fwd(embed_buf[:padded_bsz, :max_seq, :], pos_ids)
    logits = self.fast_output(self.fast_norm(hidden_out[:bsz, step, :]))
    # ... sampling ...

This matches the Qwen3 TTS CodePredictor pattern exactly. The fixed shape lets The full diff is on branch |
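The suggestion calls `self._padded_bsz(bsz)` without showing its body. One plausible reading, offered here only as an assumption, is that it rounds the live batch size up to the nearest batch size for which a graph was captured, so the forward pass always sees one of a fixed set of shapes. The function name `padded_bsz` and the `capture_sizes` list below are hypothetical.

```python
import bisect


def padded_bsz(bsz, capture_sizes):
    """Round bsz up to the smallest captured batch size that can hold it.

    Hypothetical helper: the real _padded_bsz is not shown in the thread.
    """
    sizes = sorted(capture_sizes)
    idx = bisect.bisect_left(sizes, bsz)
    if idx == len(sizes):
        raise ValueError(
            f"batch size {bsz} exceeds largest captured size {sizes[-1]}"
        )
    return sizes[idx]


capture_sizes = [1, 2, 4, 8]  # illustrative, not taken from the PR
padded = [padded_bsz(b, capture_sizes) for b in (1, 3, 4, 5)]
print(padded)
```

Under this reading, rows between `bsz` and the padded size are padding, which is why the suggested loop slices `hidden_out[:bsz, step, :]` before computing logits.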
|
After merging @linyueqian's suggestion, I ran a new benchmark:
…Graph
- Extend CUDAGraphWrapper wrap condition with talker_mtp_graph_safe opt-in
- Enable torch.compile(dynamic=True, epilogue_fusion=False) inside graph
- Use compiled forward for all batch sizes in graph mode
- Replace semantic_mask.any() with torch.where for graph compatibility
- Add clamp(max=codebook_size-1) for codebook index safety
- Clean fallback state reset (_compiled_model_fwd=None)
Signed-off-by: Sy03 <1370724210@qq.com>
Signed-off-by: Sy03 <1370724210@qq.com>
…t fallback Signed-off-by: Sy03 <1370724210@qq.com>
Signed-off-by: Sy03 <1370724210@qq.com>
…tor (vllm-project#2520) Signed-off-by: Sy03 <1370724210@qq.com>
…tor (vllm-project#2520) Signed-off-by: Sy03 <1370724210@qq.com>
…tor (vllm-project#2520) Signed-off-by: Sy03 <1370724210@qq.com>
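Two of the changes named in the commit messages above, replacing `semantic_mask.any()` with `torch.where` and adding `clamp(max=codebook_size-1)`, follow a general pattern: remove Python-level branches that depend on tensor values. The sketch below shows that pattern on CPU tensors. It is not the PR's code; the function names, tensor contents, and `codebook_size = 1024` are assumptions chosen for illustration.

```python
import torch


def select_eager(mask, a, b):
    # Data-dependent branch: evaluating mask.any() in an `if` copies a scalar
    # to the host, which forces a sync on CUDA and cannot be captured in a graph.
    if mask.any():
        out = b.clone()
        out[mask] = a[mask]
        return out
    return b


def select_graph_safe(mask, a, b):
    # Same result with no Python branch: always run one elementwise select.
    return torch.where(mask, a, b)


mask = torch.tensor([True, False, True])
a = torch.tensor([10, 20, 30])
b = torch.tensor([0, 0, 0])
eager_out = select_eager(mask, a, b)
safe_out = select_graph_safe(mask, a, b)

# Clamp keeps indices inside the codebook without checking them on the host.
codebook_size = 1024  # hypothetical value
codes = torch.tensor([3, 1023, 5000])
safe_codes = codes.clamp(max=codebook_size - 1)
```

The `torch.where` form does the same amount of work whether or not any mask element is set, which is the trade-off for having a fixed sequence of kernels to replay.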
Purpose
Enable CUDAGraphWrapper for Fish Speech S2 Pro's Fast AR (residual codebook predictor), reducing inference latency by 52.7% on H20.
Fish Speech uses a Dual AR architecture: Slow AR (Qwen3-4B, piecewise CUDA Graph) + Fast AR (4 layers, 9-step AR loop). The Fast AR ran entirely in eager mode — profiling showed it accounts for 63% of steady-state decode time (~13ms/step x 73 steps).
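The profiling figures in the paragraph above can be read as a back-of-envelope estimate. The sketch assumes that ~13 ms is the Fast AR cost per decode step, that 73 is the number of decode steps in the profiled request, and that the 63% share refers to the same window. None of these interpretations is confirmed in the thread.

```python
FAST_AR_MS_PER_STEP = 13.0  # "~13ms/step" from the PR description
DECODE_STEPS = 73           # "73 steps" from the PR description
FAST_AR_SHARE = 0.63        # "63% of steady-state decode time"

# Total time spent in Fast AR across the request, under the assumptions above.
fast_ar_ms = FAST_AR_MS_PER_STEP * DECODE_STEPS

# Steady-state decode time implied by the 63% share.
implied_decode_ms = fast_ar_ms / FAST_AR_SHARE

print(f"Fast AR total: {fast_ar_ms:.0f} ms")
print(f"Implied steady-state decode: {implied_decode_ms:.0f} ms")
```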
Key changes (scoped to gpu_ar_model_runner + fish_speech files, gpu_model_runner.py untouched):
- Add talker_mtp_graph_safe attribute for TTS models to enable CUDAGraphWrapper
- Add _capture_talker_mtp_graphs() for explicit warmup+capture after capture_model() (_dummy_run has no decode requests, so talker_mtp misses the normal capture window)
- Replace semantic_mask.any() (host-device sync) with torch.where
- Disable torch.compile inside Fast AR when outer CUDA Graph is active (compile guards don't re-execute during graph replay)
Only affects models with talker_mtp_graph_safe = True. Qwen3-Omni/Qwen3-TTS unaffected.
Test Plan
Test Result
Benchmark (H20, single request, 5 runs, both stages completed):
Audio quality (UTMOS):
Related PR: #2515
cc @linyueqian @zwhzzz0821