[Perf] Fish Speech S2 Pro: CUDA graph acceleration for Fast AR codebook decode #2579
Closed
linyueqian wants to merge 1 commit into vllm-project:main from
Conversation
Switch Fish Speech's Fast AR from variable-length re-prefill to a fixed-shape full-buffer forward with CUDA graph capture and replay, following the same pattern as Qwen3 TTS's CodePredictor.

Key changes:
- Always forward the full [padded_bsz, max_seq, hidden] buffer (zero-padded future positions) instead of slicing to the growing seq_len
- torch.compile with epilogue_fusion=False, dynamic=False
- Capture CUDA graphs per power-of-2 batch-size bucket
- Replay the graph each codebook step and index the relevant position
- Sampling (top_k/top_p/multinomial) stays outside the graph
- Defer compile + graph capture to the first forward() to avoid OOM during model loading (before KV cache allocation)

Benchmark on H20-3e (143GB):
- Per-step Fast AR time: 73 ms -> 50 ms (31% reduction)
- E2E latency: 1800 ms -> 1253 ms (30% reduction)
- RTF: 0.48 -> 0.35 (27% improvement)
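The fixed-shape trick can be sketched independently of the model code. The buffer is allocated once at its maximum shape, each decode step writes one position in place, and the full buffer is forwarded every step so the tensor shape never changes, which is the precondition for replaying a captured CUDA graph. A minimal NumPy sketch (shapes and names are illustrative, not the PR's actual code):

```python
import numpy as np

# Fixed-shape decode buffer: always [padded_bsz, max_seq, hidden].
# Future positions stay zero-padded, so every step forwards the same
# shape -- the precondition for replaying a captured CUDA graph.
padded_bsz, max_seq, hidden = 4, 8, 16
buf = np.zeros((padded_bsz, max_seq, hidden), dtype=np.float32)

real_bsz = 3  # actual batch; rows real_bsz..padded_bsz stay zero

def write_step(buf, step, embedding, real_bsz):
    """Write one decode step in place; the buffer shape never changes."""
    buf[:real_bsz, step, :] = embedding
    return buf  # forward the FULL buffer, not a growing slice buf[:, :step+1]

write_step(buf, 0, np.ones((real_bsz, hidden), dtype=np.float32), real_bsz)
```

Compare the variable-length approach, where step `t` forwards a `[bsz, t+1, hidden]` slice: the changing shape would force a recapture (or a fallback to eager) on every step.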
Purpose
Reduce Fish Speech S2 Pro's per-step decode latency by enabling CUDA graph capture and replay for the Fast AR (residual codebook predictor). This follows the same pattern already used by Qwen3 TTS's CodePredictor.
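At a high level, capture-and-replay means recording the decoder's fixed-shape forward once per batch-size bucket and then replaying the recorded graph on every codebook step. A toy sketch of the dispatch logic with pure-Python stand-ins (class and method names are hypothetical; the real PR captures with torch.compile plus CUDA graphs, and defers capture to the first forward() to avoid load-time OOM):

```python
class FastARGraphRunner:
    """Sketch: one 'graph' per power-of-2 batch bucket, captured lazily
    on the first forward() call (i.e. after model loading and KV-cache
    allocation), then reused for every subsequent step."""

    def __init__(self, buckets=(1, 2, 4, 8)):
        self.buckets = buckets
        self._graphs = None      # nothing captured at load time
        self.capture_count = 0   # for illustration: capture happens once

    def _capture(self):
        # Stand-in for the expensive one-time work: torch.compile and
        # per-bucket CUDA graph capture in the real implementation.
        self.capture_count += 1
        return {b: f"graph_bs{b}" for b in self.buckets}

    def forward(self, bsz: int):
        if self._graphs is None:     # deferred capture, first call only
            self._graphs = self._capture()
        # Round the runtime batch up to the nearest captured bucket.
        bucket = next(b for b in self.buckets if b >= bsz)
        return self._graphs[bucket]  # replay the graph for this bucket
```

Per-bucket capture trades a bounded amount of extra memory (one graph per bucket, plus padding waste up to 2x) for never recapturing at runtime.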
Key Changes
`fish_speech_fast_ar.py` — Switch from variable-length re-prefill to fixed-shape full-buffer forward:
- Always forward the full `[padded_bsz, max_seq, hidden]` embedding buffer (zero-padded future positions) instead of slicing to the growing `seq_len`
- `torch.compile` with `epilogue_fusion=False`, `dynamic=False`
- Sampling (`top_k`/`top_p`/`multinomial`) stays outside the graph
- Defer compile + graph capture to the first `forward()` to avoid OOM during model loading (before KV cache allocation)
- Set `self.talker = self.fast_ar` so `OmniGPUModelRunner` wraps `talker_mtp` in `CUDAGraphWrapper`

`fish_speech_slow_ar.py` — Make `talker_mtp` CUDA-graph-safe:
- Replace the `if semantic_mask.any():` branch with branchless `torch.where` (eliminates a host-device sync during graph capture)
- Set `self.talker = self.fast_ar` to trigger the outer `CUDAGraphWrapper` wrapping
- Add the `self.talker_mtp_graph_safe = True` flag
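The branchless rewrite matters because `mask.any()` must copy the result to the host to decide which branch to take, and that host-device sync is not allowed during CUDA graph capture; `where` evaluates both operands on-device with no sync. A minimal sketch of the equivalence, written in NumPy (whose `where` mirrors `torch.where` broadcasting semantics; function names are illustrative):

```python
import numpy as np

def apply_semantic_branchy(hidden, semantic_mask, semantic_embed):
    """Branchy form: .any() forces a device->host sync under CUDA,
    which is illegal while a CUDA graph is being captured."""
    if semantic_mask.any():
        hidden = hidden.copy()
        hidden[semantic_mask] = semantic_embed[semantic_mask]
    return hidden

def apply_semantic_branchless(hidden, semantic_mask, semantic_embed):
    """Branchless form: where() selects per position on-device,
    so it is safe to capture. Mask [bsz, seq] broadcasts over hidden."""
    return np.where(semantic_mask[..., None], semantic_embed, hidden)
```

Both return the same tensor for any mask, including the all-False case the `if` was short-circuiting, so the rewrite changes performance characteristics but not results.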
Test Plan

Tested on NVIDIA H20-3e (143GB) with `fishaudio/s2-pro`, single GPU, `enforce_eager` for Stage 0, using the benchmark config from #2515.

```
CUDA_VISIBLE_DEVICES=0 python -m vllm_omni.entrypoints.cli.main serve \
    "fishaudio/s2-pro" --omni --host 127.0.0.1 --port 8091 \
    --stage-configs-path benchmarks/fish-speech/config/vllm_omni/fish_speech_s2_pro.yaml \
    --trust-remote-code --enforce-eager
```

Client-side profiling with streaming PCM requests at concurrency=1.
Test Result
Server logs confirm CUDA graph capture: