[Perf]: Speedup VoxCPM2 TTS performance and Support PagedAttention #2690

Merged
hsliuustc0106 merged 12 commits into vllm-project:main from Sy0307:perf/voxcpm2-cuda-graph-optimize
Apr 13, 2026
Conversation

Contributor

@Sy0307 Sy0307 commented Apr 10, 2026

Summary

Performance optimization for VoxCPM2 TTS native AR pipeline, built on #2658.

Achieves 4.2x speedup (RTF 0.946 → 0.224) on a single H20 GPU with no audio quality loss.

Key changes:

  • PagedAttention for base_lm (28 layers) and residual_lm (8 layers), replacing native StaticKVCache
  • Selective torch.compile on MLP + o_proj per layer; RMSNorm/RoPE kept eager for precision
  • CFM pre-allocated buffers, fused QKV projection, batch-level residual_model forward
  • Online serving support in serving_speech.py
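The selective-compile idea in the second bullet can be sketched as follows, assuming a MiniCPM/Llama-style layer layout (`layers`, `mlp`, `self_attn.o_proj` are assumptions about attribute names; the PR's actual code may differ):

```python
import torch
import torch.nn as nn


def selectively_compile(model: nn.Module) -> nn.Module:
    """Compile only the MLP and o_proj of each decoder layer.

    RMSNorm and RoPE stay eager to avoid precision drift; attribute
    names here follow the common MiniCPM/Llama layout and are
    illustrative, not the PR's exact code.
    """
    for layer in model.layers:
        layer.mlp = torch.compile(layer.mlp, fullgraph=True)
        layer.self_attn.o_proj = torch.compile(layer.self_attn.o_proj, fullgraph=True)
    return model
```

The compiled wrappers are only traced on first call, so this adds no startup cost until the first forward pass.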

Architecture (per AR step):

MiniCPM4PagedForVoxCPM2 (scaffold, 28 layers, PagedAttention + fp32 RoPE)
→ FSQ → MiniCPM4PagedResidualLM (8 layers, PagedAttention, no RoPE)
→ LocDiT (CFM 10-step) → feat_encoder → AudioVAE → 48kHz waveform

Performance (single H20, enforce_eager=true):

| Metric | #2658 baseline | This PR |
| --- | --- | --- |
| RTF (best) | 0.946 | 0.224 |

Test Plan

  • E2E regression: 3 EN + 5 ZH sentences, Whisper large-v3 ASR 8/8 correct
  • RTF stable at 0.22-0.27 after warmup
  • 4 concurrent requests: all produce valid audio

hsliuustc0106

This comment was marked as low quality.

Collaborator

@hsliuustc0106 hsliuustc0106 left a comment


A few blocking issues before this can move forward:

DCO check failing — please fix your commit sign-offs.

No online serving support. The PR only has offline inference and no online serving path. Per contribution guidelines, new models must support both modes.

Title mismatch. Prefix is [Perf] but this is a full new model integration (1901 LOC, 12 files). Should be [Model].

External runtime dependency. The model requires pip install voxcpm or a source checkout. What is the long-term plan? Any breaking change in the voxcpm package breaks vllm-omni silently.

Scaffold weight cleanup correctness. _free_scaffold_weights() replaces params with empty tensors after first prefill. If a second request triggers a new prefill, the scaffold forward returns zeros — the engine still tries to do KV allocation. The _prefill_completed flag is reset in preprocess() for new requests, but after scaffold weights are freed, the MiniCPM forward will crash or produce garbage for new prefills.

_sliding_vae_decode is dead code. It just calls _full_vae_decode(). Either implement or remove it along with the _ENABLE_SLIDING_VAE flag.

12 env vars for tuning is excessive. Consider consolidating into a config object or at least documenting them centrally.
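One way to consolidate, sketched with hypothetical variable names (the actual env vars in voxcpm2_import_utils.py may differ):

```python
import os
from dataclasses import dataclass


@dataclass(frozen=True)
class VoxCPM2TuningConfig:
    """Hypothetical consolidation of the tuning env vars into one object.

    Field and variable names below are illustrative, not the actual
    names used in voxcpm2_import_utils.py.
    """
    cfm_steps: int = 10
    enable_compile: bool = True

    @classmethod
    def from_env(cls) -> "VoxCPM2TuningConfig":
        # Each env var is read in exactly one place, with its default
        # visible next to it.
        return cls(
            cfm_steps=int(os.environ.get("VOXCPM2_CFM_STEPS", cls.cfm_steps)),
            enable_compile=os.environ.get("VOXCPM2_ENABLE_COMPILE", "1") == "1",
        )
```

A single `from_env()` call site also gives one natural place to document every knob.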

@Sy0307 Sy0307 force-pushed the perf/voxcpm2-cuda-graph-optimize branch from 175db41 to 23fdd1e on April 11, 2026 02:14
@hsliuustc0106
Collaborator

This is a full model addition (12 files, ~1900 LOC) rather than just a perf optimization. The [Perf] prefix is misleading — it adds VoxCPM2 as a new model with examples, tests, stage configs, and model registration.

Key issues:

  1. Wrong prefix — should be [Model] not [Perf]. The perf optimizations (CFM pre-allocation, torch.compile, scaffold skip) are part of the model implementation, not a standalone perf PR.
  2. No online serving config — only offline example provided. Is there an OpenAI-compatible serving endpoint?
  3. DCO check — please verify all commits are signed off.
  4. 12 environment variables in voxcpm2_import_utils.py — consider consolidating or documenting why each is needed. That many env vars is an operational burden for deployment.

@Sy0307 Sy0307 force-pushed the perf/voxcpm2-cuda-graph-optimize branch from 23fdd1e to 4ddc808 on April 12, 2026 17:42
@Sy0307 Sy0307 changed the title from "[Perf] [VoxCPM2]: Speedup inference via CFM pre-allocation, torch.compile, and scaffold skip" to "[Perf] VoxCPM2: 4.2x inference speedup via PagedAttention + selective compile" on Apr 12, 2026
@Sy0307 Sy0307 force-pushed the perf/voxcpm2-cuda-graph-optimize branch from 4ddc808 to 3424e17 on April 12, 2026 17:46
@Sy0307 Sy0307 changed the title from "[Perf] VoxCPM2: 4.2x inference speedup via PagedAttention + selective compile" to "[Perf]: Speedup VoxCPM2 TTS performance and Support PagedAttention" on Apr 12, 2026
@Sy0307 Sy0307 force-pushed the perf/voxcpm2-cuda-graph-optimize branch from 5e70e0e to 8989df3 on April 13, 2026 03:46
@Sy0307 Sy0307 marked this pull request as ready for review April 13, 2026 03:51

@Sy0307
Contributor Author

Sy0307 commented Apr 13, 2026

@hsliuustc0106's review suggestions have all been addressed.

PTAL @linyueqian.

Collaborator

@linyueqian linyueqian left a comment


Tested on H20 (single GPU, enforce_eager=true, vllm 0.19.0)

Single request + streaming: looks good. RTF ~0.21 on warm runs, audio quality is correct, streaming produces 33 incremental chunks with monotonically growing audio. torch.compile targets all apply cleanly.

Concurrent batching has two bugs.


Bug 1: Stop logic fails for batched requests

With 2 concurrent requests, both produce ~58s audio for sentences that should be ~4s. The stop head never triggers, so requests run to _max_decode_steps.

Repro:

```python
engine.generate([
    {"prompt": "The quick brown fox jumps over the lazy dog."},
    {"prompt": "VoxCPM2 uses PagedAttention for language model inference."},
])
# Each request outputs ~58s audio instead of ~4s
# Sequential runs produce correct ~4s audio
```

I think the issue is in compute_logits() -- _results_queue maps (req_id, stop_logits) pairs by list index to batch position, but the queue ordering may not match the scheduler's token ordering when multiple requests are in flight. The stop signal never reaches the right request.
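A hypothetical shape of such a fix is to key stop logits by request id and align them to the scheduler's batch order explicitly; all names here are illustrative, not the PR's actual code:

```python
class StopSignalRouter:
    """Route per-request stop logits by request id, not list position.

    Illustrative sketch: `_results_queue`-style index matching breaks
    when the queue order diverges from the scheduler's batch order.
    """

    def __init__(self):
        self._stop_logits = {}  # req_id -> latest stop logit

    def push(self, req_id, stop_logit):
        self._stop_logits[req_id] = stop_logit

    def pop_for_batch(self, batch_req_ids):
        # Align to the scheduler's ordering explicitly; missing entries
        # come back as None instead of being attributed to the wrong request.
        return [self._stop_logits.pop(rid, None) for rid in batch_req_ids]
```

With this shape, a stop signal can only ever reach the request that produced it, regardless of batch composition.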


Bug 2: 4-batch prefill crashes with shape mismatch

```
RuntimeError: The size of tensor a (11) must match the size of tensor b (18)
  at voxcpm2_talker.py:499
  enc_outputs = tts.fsq_layer(enc_out) * feat_mask.unsqueeze(-1) + ...
```

feat_mask is computed in preprocess() via _build_prefill_inputs() (which may expand the sequence with prompt cache tokens), but req_hidden is sliced from scaffold_hidden using the vllm-visible token count. These lengths diverge when 4 requests are batched.
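A sketch of a length-alignment guard for the divergence described above (function name and shapes are illustrative):

```python
import torch


def align_scaffold_to_mask(scaffold_hidden: torch.Tensor,
                           feat_mask: torch.Tensor) -> torch.Tensor:
    """Zero-pad (or truncate) scaffold hidden states to the mask length.

    Illustrative of the mismatch above: feat_mask may be longer than the
    vllm-visible token slice when prompt-cache tokens expand the sequence.
    """
    tgt = feat_mask.shape[0]
    cur = scaffold_hidden.shape[0]
    if cur < tgt:
        # Extra mask positions get zero hidden states.
        pad = scaffold_hidden.new_zeros(tgt - cur, scaffold_hidden.shape[1])
        scaffold_hidden = torch.cat([scaffold_hidden, pad], dim=0)
    return scaffold_hidden[:tgt]
```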


Test matrix:

| Feature | Result |
| --- | --- |
| Single request (warm) | RTF 0.213 |
| Streaming | 33 incremental audio chunks, monotonically growing |
| 2 concurrent | Runs but stop logic broken (58s audio for 4s sentences) |
| 4 concurrent | Crash in _prepare_residual_prefill |

Single-request perf is solid. Suggest fixing concurrent batching before merge, or setting max_batch_size: 1 in the stage config and deferring batching to a follow-up PR.

@Sy0307 Sy0307 force-pushed the perf/voxcpm2-cuda-graph-optimize branch from 74f519a to 239d772 on April 13, 2026 07:48
@Sy0307
Contributor Author

Sy0307 commented Apr 13, 2026

> Tested on H20 (single GPU, enforce_eager=true, vllm 0.19.0)
>
> Single request + streaming: looks good. RTF ~0.21 on warm runs, audio quality is correct, streaming produces 33 incremental chunks with monotonically growing audio. torch.compile targets all apply cleanly.
>
> Concurrent batching has two bugs.

Both bugs have been fixed. Thanks.

Sy0307 added 9 commits April 13, 2026 16:21
Performance optimization for VoxCPM2 TTS native AR pipeline.
Achieves ~3x speedup (RTF 0.95→0.29) on H20 with no quality loss.

- CFM pre-allocated buffers (eliminate ~60 allocs/step in Euler solver)
- torch.compile on LocDiT estimator, feat_encoder, AudioVAE decode
- Per-request state management for max_batch_size > 1
- Scaffold skip for decode steps (use native base_lm directly)
- gpu_memory_utilization 0.9→0.3 (VoxCPM2 uses less KV cache)

Signed-off-by: Sy03 <1370724210@qq.com>
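The buffer pre-allocation mentioned in the commit above can be sketched like this; the class, shapes, and Euler-step form are illustrative, not the PR's actual CFM solver:

```python
import torch


class EulerSolverBuffers:
    """Pre-allocate Euler-solver work tensors once and reuse them.

    Illustrative sketch of the CFM optimization: in-place updates into
    fixed buffers replace the per-step allocations in the hot loop.
    """

    def __init__(self, shape, device="cpu", dtype=torch.float32):
        self.x = torch.empty(shape, device=device, dtype=dtype)  # state
        self.v = torch.empty(shape, device=device, dtype=dtype)  # velocity

    def euler_step(self, estimator, t, dt):
        # v = f(x, t); x += dt * v, all in-place into the preallocated buffers.
        self.v.copy_(estimator(self.x, t))
        self.x.add_(self.v, alpha=dt)
        return self.x
```

Reusing `x` and `v` across all 10 Euler steps avoids fresh tensor allocations inside the solver loop.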
Replace native StaticKVCache with vllm PagedAttention for both base_lm
(28 layers) and residual_lm (8 layers). Further 1.3x speedup on top of
commit A, achieving RTF 0.224 on H20 with UTMOS 4.029 (no quality loss).

Key changes:
- MiniCPM4PagedForVoxCPM2: PagedAttention base_lm with fp32 RoPE/RMSNorm
- MiniCPM4PagedResidualLM: 8-layer no-RoPE residual LM with PagedAttention
- Selective torch.compile: MLP + o_proj compiled, RMSNorm/RoPE kept eager
- Fused QKV projection (lazy-concat for checkpoint compatibility)
- Batch-level residual_model forward for concurrent request safety
- on_requests_finished hook for proper state cleanup
- Online serving support via serving_speech.py
- Dead code cleanup: 1238→818 lines in talker.py

Signed-off-by: Sy03 <1370724210@qq.com>
Signed-off-by: Sy03 <1370724210@qq.com>
- Remove test_flash_attn_precision.py (tests removed flash_attn path)
- Remove benchmark.py (convenience script, not core)
- Revert omni.py sort change (not a VoxCPM2-specific issue)

Signed-off-by: Sy03 <1370724210@qq.com>
…rectly

- scale_emb should only apply when use_mup=True (native uses 1.0 otherwise)
- Use vllm input_ids directly instead of decode→re-tokenize round trip
- Remove debug stop logit logging

Signed-off-by: Sy03 <1370724210@qq.com>
- Extract _resolve_lm_cfg() to deduplicate dict→namespace conversion
- Remove dead _RMSNorm class from minicpm4_hf_compat.py
- Replace zeros embed fallback with RuntimeError on missing prefill state

Signed-off-by: Sy03 <1370724210@qq.com>
- Record span_len in _pending_requests, use it for scaffold_hidden slicing
  instead of embeds.shape[0] (which may exceed vllm's allocated token count
  in voice clone / continuation mode)
- Pad scaffold_hidden in _prepare_residual_prefill when TTS mask length
  exceeds scaffold length (extra positions are zero-padded)

Signed-off-by: Sy03 <1370724210@qq.com>
… check)

- gpu_ar_model_runner: write None to mm_payload when idx >= len(v)
  instead of silently skipping the key (affects all AR models)
- Defer on_requests_finished cleanup to after forward() completes,
  preventing premature state deletion for in-flight requests
- Use numbers.Integral in _is_raw_audio to handle numpy int sample rates

Signed-off-by: Sy03 <1370724210@qq.com>
Collaborator

@linyueqian linyueqian left a comment


Re-tested on H20 (latest commit 2ba6714)

Both batching bugs from my previous review still reproduce. I traced the root cause:

All requests get req_id="default" because gpu_model_runner.py never passes request_id to model.preprocess(). This means every concurrent request shares the same _RequestState -- the second preprocess() call overwrites the first's prefill_masks, and stop signals get mixed across requests.

I pushed a one-line fix in #2746 (req_infos["request_id"] = req_id before the preprocess call). With that fix:

| Test | Before #2746 | After #2746 |
| --- | --- | --- |
| Single request | PASS (RTF ~0.21) | PASS |
| 2 concurrent | FAIL (57s + 58s audio) | PASS (2.72s + 5.28s) |
| 4 concurrent | FAIL (shape crash) | FAIL (msgspec crash -- different bug) |

The 4-concurrent failure with #2746 applied is a separate orchestrator-level bug: msgspec.ValidationError: cannot unpack non-iterable NoneType object when requests finish at different times and mm_payload contains None audio entries. This isn't VoxCPM2-specific.

Suggestion: Once #2746 is merged, rebase this PR on top of it. Consider setting max_batch_size: 2 in the stage config until the orchestrator msgspec bug is resolved separately.

@Sy0307 Sy0307 force-pushed the perf/voxcpm2-cuda-graph-optimize branch from 727a236 to beb6530 on April 13, 2026 16:56
When batching multiple concurrent requests, some requests may not have
audio output on every decode step. The None values in mm_payload cause
msgspec deserialization to fail because OmniEngineCoreOutput.pooling_output
is typed as dict[str, torch.Tensor] which cannot hold None values.

Only set mm_payload[k] when the element is non-None.

Signed-off-by: Sy03 <1370724210@qq.com>
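The guard described in this commit message can be sketched as follows; the function name and payload structure are illustrative, not the actual gpu_ar_model_runner code:

```python
def build_mm_payloads(outputs: dict) -> list:
    """Per-request payload dicts from columnar model outputs.

    Keys whose element is None are simply omitted for that request,
    so a dict[str, Tensor]-typed field is never asked to hold None
    (the msgspec failure mode described above). Illustrative sketch.
    """
    n = max(len(v) for v in outputs.values())
    payloads = [dict() for _ in range(n)]
    for k, v in outputs.items():
        for idx in range(n):
            # Only set mm_payload[k] when the element exists and is non-None.
            if idx < len(v) and v[idx] is not None:
                payloads[idx][k] = v[idx]
    return payloads
```

Omitting the key (rather than storing None) keeps the serialized type honest: a request with no audio this step just has no `audio` entry.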
@Sy0307 Sy0307 force-pushed the perf/voxcpm2-cuda-graph-optimize branch from beb6530 to b6fdab7 on April 13, 2026 16:58
Collaborator

@linyueqian linyueqian left a comment


Re-tested with latest commits (727a236) + #2746 on H20

All tests pass with gpu_memory_utilization: 0.9:

| Test | Result |
| --- | --- |
| Single request | PASS (RTF ~2.0) |
| 2 concurrent | PASS (2.72s + 5.28s) |
| 4 concurrent | PASS (2.72s + 3.36s + 3.36s + 4.00s) |

The msgspec fix in 3b45efb works. The remaining blocker is that this PR depends on #2746 (request_id passthrough in gpu_model_runner.py). Without it, all concurrent requests share a single "default" state and batching is completely broken.

Action items before merge:

  1. Merge #2746 first (or include the one-line fix in this PR)
  2. Change gpu_memory_utilization from 0.3 to 0.9 in voxcpm2.yaml -- at 0.3, 4 concurrent requests OOM on H20 (140GB)

@linyueqian linyueqian added the `ready` label (to trigger buildkite CI) on Apr 13, 2026
Collaborator

@linyueqian linyueqian left a comment


LGTM

- Pass request_id to model.preprocess() so models with per-request
  state (e.g. VoxCPM2) can distinguish concurrent requests. Without
  this, all requests fall back to "default" and share one state,
  breaking batched stop logic and causing prefill shape mismatches.
- Increase gpu_memory_utilization from 0.3 to 0.9: CFM side computation
  (LocDiT + AudioVAE) runs outside vllm's memory budget, so 0.3 OOMs
  with 4 concurrent requests on H20.

Tested on H20 (bs=4, enforce_eager):
  single: PASS | 2-concurrent: PASS | 4-concurrent: PASS

Signed-off-by: Yueqian Lin <linyueqian@outlook.com>
@linyueqian linyueqian force-pushed the perf/voxcpm2-cuda-graph-optimize branch from 97c91a8 to b441f00 on April 13, 2026 17:20
@linyueqian linyueqian enabled auto-merge (squash) April 13, 2026 17:21
@hsliuustc0106 hsliuustc0106 disabled auto-merge April 13, 2026 21:16
@hsliuustc0106 hsliuustc0106 merged commit 14f7910 into vllm-project:main Apr 13, 2026
7 of 8 checks passed
linyueqian added a commit to linyueqian/vllm-omni that referenced this pull request Apr 14, 2026
PR vllm-project#2690 compiled `layer.mlp` and `layer.self_attn.o_proj` separately
(2 compiled regions per layer, fullgraph=True). Profiling showed 1,737
per-layer compiled-region dispatches on a long prompt at ~530 us CPU
self-time each (~925 ms of pure Dynamo dispatch overhead).

Wrap `Model.forward` in a single `torch.compile(fullgraph=False)` so
Dynamo traces the full 28-layer loop once. Graph breaks at
PagedAttention produce sub-graphs that are memoised after the first
step, collapsing per-step Python dispatch from 28+ calls to a handful.
Same treatment for the 8-layer residual model.

Benchmarked on H20: RTF dropped from 0.197 to 0.126 (36%) on the
long prompt, matching or beating nanovllm-voxcpm on short prompts.

Signed-off-by: Yueqian Lin <linyueqian@outlook.com>
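The single-region compile described above can be sketched as follows; `compile_full_forward` is an illustrative helper, not the actual commit's code:

```python
import torch
import torch.nn as nn


def compile_full_forward(model: nn.Module) -> nn.Module:
    """Wrap the whole forward in one torch.compile region.

    With fullgraph=False, graph breaks (e.g. at PagedAttention kernels)
    split the trace into sub-graphs that are memoised after the first
    step, instead of paying 2 compiled-region dispatches per layer.
    Illustrative sketch of the approach in the commit above.
    """
    model.forward = torch.compile(model.forward, fullgraph=False)
    return model
```

The trade-off versus per-submodule compilation is fewer Dynamo dispatch boundaries per step at the cost of a larger first-step trace.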
Celeste-jq pushed a commit to IsleOfDawnlight/vllm-omni-voxcpm that referenced this pull request Apr 14, 2026
…llm-project#2690)

Signed-off-by: Sy03 <1370724210@qq.com>
Signed-off-by: Yueqian Lin <linyueqian@outlook.com>
Co-authored-by: Yueqian Lin <linyueqian@outlook.com>
Co-authored-by: Yueqian Lin <70319226+linyueqian@users.noreply.github.com>
alex-jw-brooks pushed a commit to alex-jw-brooks/vllm-omni that referenced this pull request Apr 14, 2026
…llm-project#2690)

Signed-off-by: Sy03 <1370724210@qq.com>
Signed-off-by: Yueqian Lin <linyueqian@outlook.com>
Co-authored-by: Yueqian Lin <linyueqian@outlook.com>
Co-authored-by: Yueqian Lin <70319226+linyueqian@users.noreply.github.com>