
[Perf] VoxCPM2: Speedup by manual CUDA Graph capture for scaffold/residual forward #2803

Merged
hsliuustc0106 merged 2 commits into vllm-project:main from Sy0307:perf/voxcpm2-streaming-vae on Apr 15, 2026

Conversation

@Sy0307 (Contributor) commented Apr 14, 2026

WAITING FOR #2758 MERGE.

Purpose

Follow-up to #2690 and #2758. Adds manual CUDA Graph capture/replay for the 28-layer scaffold (MiniCPM4PagedForVoxCPM2) and 8-layer residual model forwards during decode steps, eliminating per-step kernel launch overhead.

Key changes

CUDA Graph capture/replay (voxcpm2_talker.py); a minimal sketch of the pattern follows this list:

  • _CapturedGraph dataclass holding static input/output buffers and the CUDA graph
  • _capture_graph(): 3 warmup runs + graph capture under patched ForwardContext (scheduler_metadata nullified to avoid shape mismatches across batch sizes)
  • _replay_graph(): copy inputs into static buffers, replay, clone output
  • forward() dispatch: pure-decode batches → graph replay; mixed prefill+decode → eager fallback
  • Private memory pool via torch.cuda.graph_pool_handle() to isolate from reduce-overhead cudagraph_trees
  • Lazy capture on first decode step after torch.compile warmup completes
  • When CUDA Graph is enabled, scaffold/residual skip torch.compile (graph capture already eliminates kernel launch overhead)
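
For illustration, a minimal sketch of the static-buffer capture/replay pattern, assuming generic names (`capture_graph`, `replay`, `_POOL`); the real `_capture_graph`/`_replay_graph` in voxcpm2_talker.py additionally patch the ForwardContext and handle paged-attention metadata:

```python
# Minimal sketch of manual CUDA Graph capture/replay with static buffers.
# Names here are illustrative, not the exact helpers from voxcpm2_talker.py.
from dataclasses import dataclass

import torch


@dataclass
class _CapturedGraph:
    graph: torch.cuda.CUDAGraph
    static_input: torch.Tensor   # fixed-address buffer the graph reads from
    static_output: torch.Tensor  # fixed-address buffer the graph writes to


# Private pool so these graphs do not share memory with torch.compile's
# reduce-overhead cudagraph trees.
_POOL = torch.cuda.graph_pool_handle()


def capture_graph(model, example_input: torch.Tensor,
                  num_warmup: int = 3) -> _CapturedGraph:
    static_input = example_input.clone()
    # Warm up on a side stream so lazy allocations and autotuning happen
    # before capture instead of being recorded into the graph.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        for _ in range(num_warmup):
            model(static_input)
    torch.cuda.current_stream().wait_stream(side)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=_POOL):
        static_output = model(static_input)
    return _CapturedGraph(graph, static_input, static_output)


def replay(captured: _CapturedGraph, new_input: torch.Tensor) -> torch.Tensor:
    # Copy into the fixed-address buffer, replay the recorded kernels in a
    # single launch, then clone so the caller's result cannot be overwritten
    # by the next replay.
    captured.static_input.copy_(new_input)
    captured.graph.replay()
    return captured.static_output.clone()
```

The forward() dispatch then reduces to: if every request in the batch is in decode and a graph exists for that batch size, call `replay()`; otherwise fall back to the eager path.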

precompute_fused_qkv() (minicpm4_paged.py):

  • Materializes fused QKV weights before graph capture to prevent a lazy torch.cat allocation inside the captured graph (sketched below)
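
A hedged sketch of the pre-fusion idea; the actual `precompute_fused_qkv()` in minicpm4_paged.py operates on the paged MiniCPM4 attention layers, so the module below is only illustrative:

```python
# Hedged sketch of pre-fusing QKV weights before graph capture; the real
# precompute_fused_qkv() in minicpm4_paged.py may differ in detail.
import torch
import torch.nn as nn


class FusedQKVAttention(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self._fused_qkv_weight = None

    def precompute_fused_qkv(self) -> None:
        # Concatenate the three weights once, eagerly, before capture.
        # If this cat ran lazily on the first forward inside capture, the
        # allocation would be recorded into the captured graph.
        self._fused_qkv_weight = torch.cat(
            [self.q_proj.weight, self.k_proj.weight, self.v_proj.weight], dim=0
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert self._fused_qkv_weight is not None, "call precompute_fused_qkv() first"
        # One GEMM for Q, K, V instead of three separate projections.
        qkv = x @ self._fused_qkv_weight.T
        return qkv  # callers split into q, k, v along the last dim
```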

Test Plan

  • Single-request E2E generation: verified complete audio output (WAV >100KB, not truncated)
  • Sequential RTF benchmark across 6 prompts (short/medium/long × en/zh)
  • Concurrent throughput benchmark (c=1, 2, 4)
  • Verified eager fallback for mixed prefill+decode batches
  • ASR quality verification on generated audio samples

Test Result

Hardware: NVIDIA H20 (98GB), CUDA 13.0

RTF Comparison (3 runs, OpenAI speech API)

| Prompt    | Chars | Baseline #2758 (avg / min) | CUDA Graph (avg / min) | Improvement |
|-----------|-------|----------------------------|------------------------|-------------|
| short_en  | 49    | 0.139 / 0.134              | 0.112 / 0.106          | -19% avg    |
| medium_en | 169   | 0.134 / 0.132              | 0.105 / 0.103          | -22% avg    |
| long_en   | 412   | 0.133 / 0.133              | 0.103 / 0.102          | -23% avg    |
| short_zh  | 15    | 0.139 / 0.134              | 0.110 / 0.106          | -21% avg    |
| medium_zh | 55    | 0.133 / 0.132              | 0.104 / 0.103          | -22% avg    |
| long_zh   | 145   | 0.132 / 0.132              | 0.103 / 0.103          | -22% avg    |
| Average   |       | 0.135 / 0.133              | 0.106 / 0.104          | -21% avg    |

Concurrent Throughput (medium_en, 2 runs)

| Concurrency | Baseline   | CUDA Graph | Improvement |
|-------------|------------|------------|-------------|
| c=1         | 0.81 req/s | 0.99 req/s | +22%        |
| c=2         | 0.88 req/s | 1.23 req/s | +40%        |
| c=4         | 1.25 req/s | 1.42 req/s | +14%        |

cc @linyueqian @hsliuustc0106

@linyueqian (Collaborator)

Impressive! I will take a look later today.

@linyueqian added the `ready` (trigger buildkite CI) label on Apr 14, 2026
@hsliuustc0106 (Collaborator)

Quick blocker scan - no issues found. A few notes:

  1. The Gradio demo adds 600 LOC which seems outside the scope of a perf PR. Consider splitting into a separate PR.

  2. The PR description mentions "WAITING FOR [Perf] VoxCPM2: streaming VAE + compile optimization (45% RTF reduction) #2758" but mergeStateStatus shows BLOCKED - clarify dependencies.

  3. Pre-fusing QKV weights is good, but verify this doesn't increase memory footprint for models that don't use CUDA Graph.

  4. Benchmark results are solid, but would be helpful to see the delta between this PR and baseline ([Perf]: Speedup VoxCPM2 TTS performance and Support PagedAttention #2690) to isolate the CUDA Graph contribution.

  5. The sliding window VAE decode change is significant - verify this doesn't change audio quality.

@hsliuustc0106 (Collaborator) commented Apr 15, 2026

Amazing results from an L20 (48 GB):

| Group  | Mean RTF | Median RTF | P99 RTF | Mean E2E (ms) | Mean Audio (s) |
|--------|----------|------------|---------|---------------|----------------|
| short  | 0.082    | 0.082      | 0.083   | 543.2         | 6.592          |
| medium | 0.079    | 0.079      | 0.079   | 1393.6        | 17.728         |
| long   | 0.078    | 0.078      | 0.078   | 2309.1        | 29.696         |
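
(RTF here appears to be mean E2E latency divided by mean audio duration, e.g. 0.5432 s / 6.592 s ≈ 0.082 for the short group.)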

Comparing #2803 against the earlier steady-state baseline, the improvement is roughly:

| Group  | Before RTF | After RTF | RTF Reduction | Speedup |
|--------|------------|-----------|---------------|---------|
| short  | 0.144      | 0.082     | 43.1% lower   | 1.76x   |
| medium | 0.141      | 0.079     | 44.0% lower   | 1.78x   |
| long   | 0.141      | 0.078     | 44.7% lower   | 1.81x   |

Commit message:

Capture and replay CUDA Graphs for the 28-layer scaffold and 8-layer
residual model forwards during decode steps, eliminating per-step
kernel launch overhead. Reduces average RTF from 0.135 to 0.106 (-21%)
and improves concurrent throughput by 14-40%.

Signed-off-by: Sy03 <1370724210@qq.com>
@Sy0307 force-pushed the perf/voxcpm2-streaming-vae branch from 4131a72 to 8cfdf66 on April 15, 2026 at 03:37
@Sy0307 Sy0307 marked this pull request as ready for review April 15, 2026 03:48
@Sy0307 Sy0307 requested a review from hsliuustc0106 as a code owner April 15, 2026 03:48

@hsliuustc0106 hsliuustc0106 merged commit 38d5f2d into vllm-project:main Apr 15, 2026
6 of 8 checks passed
y123456y78 pushed a commit to y123456y78/vllm-omni that referenced this pull request Apr 15, 2026
@hsliuustc0106 mentioned this pull request Apr 17, 2026
Sy0307 added a commit to Sy0307/vllm-omni that referenced this pull request Apr 18, 2026
Root cause of the c>=2 vectorized_gather OOB / illegal memory access
reported on VoxCPM2 after PR vllm-project#2803: preprocess() used _pending_requests
(per-step prefix, cleared each forward) as if it were the full active
batch. When a new prefill was scheduled after cached decode requests in
the same batch, the decode requests were wrongly classified as stale and
removed from _active_states; the next forward then recreated empty
states, silently skipped residual_model for them, and desynchronized
attn metadata -- producing NaN on the prefill slice and eventually
crashing inside the CFM/LocDiT kernels.

State cleanup is now driven solely by on_requests_finished ->
_flush_deferred_cleanup at the end of forward().  A one-shot leak
warning in _get_or_create_state guards against a future regression.

Signed-off-by: Sy03 <1370724210@qq.com>
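
For illustration, the deferred-cleanup flow this commit describes might look like the following sketch; the class layout and method bodies are assumptions, not the actual vllm-omni code:

```python
# Illustrative sketch of deferred state cleanup driven by finish
# notifications rather than per-step batch composition.
import logging

logger = logging.getLogger(__name__)


class TalkerStateTracker:
    def __init__(self) -> None:
        self._active_states: dict[str, dict] = {}
        self._finished: set[str] = set()
        self._leak_warned = False

    def on_requests_finished(self, request_ids: list[str]) -> None:
        # Only the scheduler's finish notification marks states for removal;
        # batch membership in a given step is never used to infer staleness.
        self._finished.update(request_ids)

    def _flush_deferred_cleanup(self) -> None:
        for rid in self._finished:
            self._active_states.pop(rid, None)
        self._finished.clear()

    def _get_or_create_state(self, rid: str) -> dict:
        state = self._active_states.get(rid)
        if state is None:
            state = {}  # placeholder for per-request decode state
            self._active_states[rid] = state
            if len(self._active_states) > 1024 and not self._leak_warned:
                # One-shot leak warning: active states should stay bounded
                # by the scheduler's maximum number of in-flight requests.
                logger.warning("possible state leak in _active_states")
                self._leak_warned = True
        return state

    def forward_step(self, batch_request_ids: list[str]) -> None:
        states = [self._get_or_create_state(r) for r in batch_request_ids]
        # ... run scaffold/residual forwards against `states` here ...
        self._flush_deferred_cleanup()  # cleanup only at the end of forward()
```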
lvliang-intel pushed a commit to lvliang-intel/vllm-omni that referenced this pull request Apr 20, 2026
lengrongfu pushed a commit to lengrongfu/vllm-omni that referenced this pull request May 1, 2026
clodaghwalsh17 pushed a commit to clodaghwalsh17/nm-vllm-omni-ent that referenced this pull request May 12, 2026