
[Perf] VoxCPM2: streaming VAE + compile optimization (45% RTF reduction)#2758

Merged
hsliuustc0106 merged 6 commits into vllm-project:main from linyueqian:perf/voxcpm2-streaming-vae
Apr 15, 2026

Conversation

@linyueqian
Collaborator

@linyueqian linyueqian commented Apr 14, 2026

Summary

Four changes for VoxCPM2 performance and streaming:

  1. Sliding-window VAE decode — replaces the O(N^2) accumulate-and-re-decode pattern with per-step streaming: each VAE call takes [decode_pad (12 frames) + new_patch (4 frames)] and slices out the new audio region using exact decoder_chunk_size alignment. Matches the nanovllm-voxcpm reference implementation.

  2. Eliminate GPU->CPU syncs in CFM diffusion — the Euler integration loop called .item() on 0-dim GPU tensors t/dt up to 4x per diffusion step (4 × 10 timesteps × ~60 decode steps ≈ 2,400 syncs per long prompt). Replaced with on-device .copy_() broadcasts.

  3. Compile whole Model.forward instead of per-submodule — PR [Perf]: Speedup VoxCPM2 TTS performance and Support PagedAttention #2690 compiled layer.mlp + layer.self_attn.o_proj separately (56 Dynamo dispatches per step). Wrapping Model.forward in torch.compile(fullgraph=False) lets Dynamo memoise the full 28-layer loop. Biggest single win (~36%).

  4. Streaming Gradio demo — AudioWorklet-based gapless streaming player (adapted from Qwen3-TTS demo) with live TTFP/RTF metrics. Supports all 3 VoxCPM2 modes: Voice Design, Controllable Cloning, and Ultimate Cloning.

Benchmark results (H20 GPU, openbmb/VoxCPM2)

| Prompt | Before (main) | After | nanovllm-voxcpm | Status |
| --- | --- | --- | --- | --- |
| short (11 words) | 0.211 | 0.126 | 0.138 | beats nanovllm |
| medium (35 words) | 0.211 | 0.127 | 0.119 | tied (7% gap) |
| long (102 words) | 0.232 | 0.131 | 0.109 | 17% gap |

Net: 40-44% RTF reduction. Audio quality verified by listening. Streaming playback verified via Gradio demo (gapless, no boundary artifacts).

Files changed

  • vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py — streaming VAE decode, CFM sync fix
  • vllm_omni/model_executor/models/voxcpm2/minicpm4_paged.py — whole-model compile strategy
  • examples/online_serving/voxcpm2/gradio_demo.py — streaming Gradio demo (NEW)

Test plan

  • Offline inference: end2end.py produces valid audio
  • RTF benchmark: 3 prompt lengths x 5 runs each on H20
  • Streaming playback: Gradio demo with AudioWorklet player, no glitches
  • E2E tests: test_voxcpm2.py (zero-shot + voice clone)
  • Concurrent batching: verify max_batch_size=4 still works

Replace the O(N^2) accumulate-and-re-decode loop in _collect_audio with a
nanovllm-style sliding-window stream: each VAE decode takes only the
trailing pad frames plus the newly-generated latents, and we slice out
just the new audio region. Total VAE work drops from O(N^2) to O(N) over
a full generation.
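A minimal sketch of the sliding-window decode described here (the constants, the per-frame sample count, and the `vae_decode` callable are illustrative assumptions, not the actual voxcpm2_talker.py API):

```python
import torch

# Illustrative constants; the real values come from the VAE config.
DECODE_PAD_FRAMES = 12   # trailing context frames kept between steps
PATCH_SIZE = 4           # new latent frames produced per AR step
SAMPLES_PER_FRAME = 480  # audio samples per latent frame (assumed)

def streaming_decode(latents: torch.Tensor, vae_decode) -> torch.Tensor:
    """Decode latents patch-by-patch with a bounded window: O(N) total work."""
    chunks = []
    for start in range(0, latents.shape[0], PATCH_SIZE):
        patch = latents[start:start + PATCH_SIZE]
        pad = latents[max(0, start - DECODE_PAD_FRAMES):start]
        window = torch.cat([pad, patch], dim=0)   # at most pad + patch frames
        audio = vae_decode(window)                # decode only the small window
        new_region = pad.shape[0] * SAMPLES_PER_FRAME
        chunks.append(audio[new_region:])         # slice out just the new audio
    return torch.cat(chunks)
```

Because every call sees at most `DECODE_PAD_FRAMES + PATCH_SIZE` frames, total decode work is linear in the number of generated frames instead of quadratic.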

Signed-off-by: Yueqian Lin <linyueqian@outlook.com>
The inner Euler integration in _optimized_solve_euler called .item()
on the 0-dim GPU tensors t and dt up to 4 times per diffusion step,
forcing a GPU->CPU sync every time. With n_timesteps=10 and ~4 syncs
per step that is ~40 syncs per AR decode step; profiling counted
~4k aten::_local_scalar_dense calls over a long generation.

Broadcast the 0-dim tensors directly via .copy_() instead, keeping
the work on-device. Also gate the one-shot prefill norm log behind
an isEnabledFor(DEBUG) check so it no longer syncs on every request.

Signed-off-by: Yueqian Lin <linyueqian@outlook.com>
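A minimal sketch of the sync-free pattern under assumed names (`euler_steps`, `velocity_fn`); the real `_optimized_solve_euler` integrates the CFM ODE with more state, but the `.copy_()` idea is the same:

```python
import torch

# Keep t and dt as preallocated 0-dim tensors on the same device as x and
# update them with .copy_(), so the loop never calls .item() and never
# forces a blocking GPU->CPU sync.
def euler_steps(x: torch.Tensor, velocity_fn, timesteps: torch.Tensor) -> torch.Tensor:
    t = torch.zeros((), device=x.device)
    dt = torch.zeros((), device=x.device)
    for i in range(timesteps.shape[0] - 1):
        t.copy_(timesteps[i])                     # on-device broadcast
        dt.copy_(timesteps[i + 1] - timesteps[i])
        x = x + dt * velocity_fn(x, t)            # t/dt stay tensors throughout
    return x
```

The `.item()` variant would read both t and dt back to the CPU on every step; with `.copy_()` the values never leave the device.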

PR vllm-project#2690 compiled `layer.mlp` and `layer.self_attn.o_proj` separately
(2 compiled regions per layer, fullgraph=True). Profiling showed 1,737
per-layer compiled-region dispatches on a long prompt at ~530 us CPU
self-time each (~925 ms of pure Dynamo dispatch overhead).

Wrap `Model.forward` in a single `torch.compile(fullgraph=False)` so
Dynamo traces the full 28-layer loop once. Graph breaks at
PagedAttention produce sub-graphs that are memoised after the first
step, collapsing per-step Python dispatch from 28+ calls to a handful.
Same treatment for the 8-layer residual model.

Benchmarked on H20: RTF dropped from 0.197 to 0.126 (36%) on the
long prompt, matching or beating nanovllm-voxcpm on short prompts.

Signed-off-by: Yueqian Lin <linyueqian@outlook.com>
@linyueqian linyueqian force-pushed the perf/voxcpm2-streaming-vae branch from 6925202 to 62d5e12 on April 14, 2026 01:32
@linyueqian
Collaborator Author

@JuanPZuluaga @Sy0307 ptal

@linyueqian linyueqian force-pushed the perf/voxcpm2-streaming-vae branch from 8bf003a to 5cec79a on April 14, 2026 02:11
Also fix streaming: return delta audio chunks (not cumulative) from
_collect_audio, and return None on steps without a VAE decode. The
output processor accumulates deltas into a list; the speech streaming
layer yields each new entry as a separate PCM chunk to the client.
Previously, returning cumulative audio caused the client to replay the
full audio from the start on every VAE decode interval.

Signed-off-by: Yueqian Lin <linyueqian@outlook.com>
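A minimal sketch of this delta-chunk contract (a hypothetical helper; the real logic lives in `_collect_audio` and the output processor):

```python
import torch

# Each decode step yields only the newly decoded audio, or None when no VAE
# decode ran that step, so the client never replays earlier audio.
def stream_pcm(step_outputs):
    """Yield each new delta chunk exactly once, flattened to 1-D PCM."""
    for delta in step_outputs:
        if delta is None:           # step without a VAE decode: nothing to send
            continue
        yield delta.reshape(-1)     # the delta only, never the cumulative waveform
```

Returning cumulative audio instead would make every yielded chunk start from sample zero, which is exactly the replay-from-start bug described above.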
@linyueqian linyueqian force-pushed the perf/voxcpm2-streaming-vae branch from 5cec79a to 82e3dc1 on April 14, 2026 02:19
@linyueqian linyueqian added the ready label (triggers buildkite CI) on Apr 14, 2026
@gesla2024

gesla2024 commented Apr 14, 2026

The issue with streaming has been resolved, but there is a new problem: the input text is in Simplified Chinese, but the generated audio either speaks Cantonese or just produces noise. If streaming output is not enabled, the generated audio is always 16k in size, and the file contains nothing.

Environment: Ubuntu 24.04
GPU: A100 80G
Memory: 320G
Storage: 5T
vllm-omni: 0.19.0rc2.dev116+gede1c93c3
vLLM: 0.19.0

This is the command I use to start the server.

CUDA_VISIBLE_DEVICES=0 vllm-omni serve /home/VoxCPM/models/VoxCPM2 --stage-configs-path /home/www/vllm-omni/vllm_omni/model_executor/stage_configs/voxcpm2.yaml --omni --port 8071 --trust-remote-code --enforce-eager --gpu-memory-utilization 0.8

This is the request I tested after the server started

```
root@AS-4124GS-TNR:/home/VoxCPM# curl -X POST http://localhost:8071/v1/audio/speech \
  -H "Content-Type: application/json" \
  -d '{"model": "/home/VoxCPM/models/VoxCPM2", "input": "Hello, this is VoxCPM2.", "voice": "default"}' \
  --output output.wav
# 15,404 bytes received

root@AS-4124GS-TNR:/home/VoxCPM# curl -X POST http://localhost:8071/v1/audio/speech \
  -H "Content-Type: application/json" \
  -d '{"model": "/home/VoxCPM/models/VoxCPM2", "input": "你好这是VoxCPM2.", "voice": "default"}' \
  --output output2.wav
# 15,404 bytes received
```

These are the generated audio files

output2.wav
output1.wav

This is a program I use for testing

```python
from __future__ import annotations

import base64
import os

import httpx

DEFAULT_API_BASE = "http://localhost:8071"
DEFAULT_API_KEY = "sk-empty"


def encode_audio_to_base64(audio_path: str) -> str:
    """Encode a local audio file to a base64 data URL."""
    if not os.path.exists(audio_path):
        raise FileNotFoundError(f"Audio file not found: {audio_path}")

    ext = audio_path.lower().rsplit(".", 1)[-1]
    mime = {
        "wav": "audio/wav",
        "mp3": "audio/mpeg",
        "flac": "audio/flac",
        "ogg": "audio/ogg",
    }.get(ext, "audio/wav")

    with open(audio_path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode("utf-8")
    return f"data:{mime};base64,{b64}"


def main() -> None:
    text = "你好,这是一个voxcpm2 测试程序 在vllm-omni 0.19 中测试的。"
    payload: dict = {
        "model": "/home/VoxCPM/models/VoxCPM2",
        "input": text,
        "voice": "default",
        "response_format": "wav",
    }
    ref_audio = None

    if ref_audio is not None:
        ref = ref_audio
        if ref.startswith(("http://", "https://", "data:")):
            payload["ref_audio"] = ref
        else:
            payload["ref_audio"] = encode_audio_to_base64(ref)

    url = f"{DEFAULT_API_BASE}/v1/audio/speech"
    print(f"POST {url}")
    print(f"  text: {text}")
    if ref_audio is not None:
        print(f"  ref_audio: {ref_audio[:80]}...")

    with httpx.Client(timeout=300) as client:
        resp = client.post(
            url,
            json=payload,
            headers={"Authorization": f"Bearer {DEFAULT_API_KEY}"},
        )

    if resp.status_code != 200:
        print(f"Error {resp.status_code}: {resp.text[:500]}")
        return

    with open("output.wav", "wb") as f:
        f.write(resp.content)
    print(f"Saved: output.wav ({len(resp.content):,} bytes)")


if __name__ == "__main__":
    main()
```

@Sy0307
Contributor

Sy0307 commented Apr 14, 2026

Thanks for the detailed report! This is a confirmed bug introduced by the sliding-window VAE change (ff7b5af). @gesla2024

Root cause: The streaming VAE refactor switched _collect_audio from returning cumulative audio to per-step delta chunks. However, _consolidate_multimodal_tensors in output_processor.py skips concatenation for the audio key (continue), so the chunks are never joined. In non-streaming mode, serving_speech.py then only keeps the last chunk (~0.16s), producing the ~16KB empty WAV you observed.

Fix: One-line change in output_processor.py — replace continue with reshape(-1) + torch.cat so all delta chunks are concatenated at consolidation time. Streaming is unaffected since consolidation only runs on finished=True.
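A sketch of what that fix does at consolidation time (the dict shape and key name are illustrative of output_processor.py, not copied from it):

```python
import torch

# Join the per-step audio delta chunks into one waveform instead of skipping
# the 'audio' key. reshape(-1) tolerates inconsistent leading dims.
def consolidate(multimodal: dict) -> dict:
    out = {}
    for key, value in multimodal.items():
        if key == "audio" and isinstance(value, list):
            out[key] = torch.cat([v.reshape(-1) for v in value])
        else:
            out[key] = value
    return out
```

With the old `continue`, the 'audio' value stayed a list and the non-streaming server kept only the last entry, which is why the WAV was ~16 KB.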

Verified on H20:

  • "Hello, this is VoxCPM2." → 337KB / 3.52s (was 16KB)
  • "你好这是VoxCPM2." → 230KB / 2.40s (was 16KB)
  • Streaming: 28 chunks, TTFB 5.6s, working correctly

cc @linyueqian

@hsliuustc0106
Collaborator

BLOCKER scan:

  • Correctness: PASS
  • Reliability/Safety: PASS
  • Breaking Changes: PASS (performance optimization, user-facing behavior unchanged)
  • Test Coverage: PASS (PR body includes benchmark data, tests updated)
  • Documentation: PASS (gradio demo added)
  • Security: PASS

OVERALL: NO BLOCKERS

VERDICT: COMMENT

Excellent work! The 40-44% RTF reduction is impressive. A few minor notes:

  1. The streaming VAE pattern (sliding window with decode_pad) matches the reference implementation well.
  2. The GPU sync elimination (removing .item() calls) is a good micro-optimization.
  3. Consider adding a note about the decode_pad_frames=12 magic number - it would be helpful if this were documented or derived from VAE receptive field size.

Test plan shows e2e tests are pending - would be good to verify these before merge.

@hsliuustc0106
Collaborator

With @Sy0307's suggested one-line change, I benchmarked 5 requests per group on an L20 48G:

| Group | Mean RTF | Median RTF | P99 RTF | Mean E2E (ms) | Mean Audio (s) |
| --- | --- | --- | --- | --- | --- |
| short | 0.139 | 0.139 | 0.141 | 969.7 | 6.976 |
| medium | 0.137 | 0.137 | 0.138 | 2330.3 | 16.960 |
| long | 0.138 | 0.138 | 0.138 | 4175.1 | 30.208 |

The streaming VAE change (ff7b5af) switched _collect_audio to return
per-step delta chunks instead of cumulative audio. Offline consumers
then received only the last chunk (~0.16s) because
_consolidate_multimodal_tensors in engine/output_processor.py skipped
concatenation for the 'audio' key, and the non-streaming speech server
kept only the last list entry (~16 KB WAV, empty-sounding output).

Fix at the structural root: have the consolidation step concatenate
audio delta chunks into the full waveform (flatten each to 1-D first
to tolerate inconsistent leading dims). Consolidation only runs on
finished=True so streaming is unaffected.

Offline extract_audio helpers add a defensive torch.cat fallback for
mid-stream list snapshots; normal completed requests now see a single
consolidated tensor.

Reported by @gesla2024 in vllm-project#2758, root-caused by @Sy0307.

Signed-off-by: Yueqian Lin <linyueqian@outlook.com>
After the output_processor consolidation fix, multimodal_output["audio"]
is a Tensor rather than a list, so `dict.get("audio") or dict.get(...)`
raises "Boolean value of Tensor with more than one value is ambiguous".
Use explicit None-checks instead of `or` short-circuiting.

Signed-off-by: Yueqian Lin <linyueqian@outlook.com>
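A minimal sketch of the truthiness fix (the fallback key name is illustrative):

```python
import torch

# `dict.get("audio") or dict.get("audio_chunks")` calls bool() on the tensor
# and raises for multi-element tensors, so use explicit None checks instead.
def get_audio(multimodal_output: dict):
    audio = multimodal_output.get("audio")
    if audio is None:                      # explicit check, never bool(Tensor)
        audio = multimodal_output.get("audio_chunks")
    return audio
```

This only bit after the consolidation fix because the value changed from a list (whose truthiness is well defined) to a Tensor (whose truthiness is ambiguous beyond one element).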
@hsliuustc0106 hsliuustc0106 merged commit 02e5dc7 into vllm-project:main Apr 15, 2026
8 checks passed
@gesla2024

@Sy0307 Thank you, I will download the updated repository and test it.

@gesla2024

I just pulled the latest code to test, and the noise issue has been resolved. However, the generated output, whether streamed, voice-cloned, or non-streamed, is not normal Mandarin. The test program I used is the gradio_demo.py program from this example: https://docs.vllm.ai/projects/vllm-omni/en/latest/user_guide/examples/online_serving/voxcpm2/#example-materials. Below are the recordings of my tests.

Test environment:
Operating System: Ubuntu 24.04
vllm-omni: 0.19.0rc2.dev130+ga782ae478
GPU: A100 80G

3.mp4
2.mp4
1.mp4

@linyueqian
Collaborator Author

Thanks so much for testing, @gesla2024. Let's move our discussion over to PR #2803.

@gesla2024

> Thanks so much for testing, @gesla2024. Let's move our discussion over to PR #2803.

Thank you, I will pull this branch and test it now.

@gesla2024

gesla2024 commented Apr 15, 2026

After pulling #2803 and testing, I see the problem is still the same.

```
(APIServer pid=355286) INFO:     120.121.111.42:11988 - "POST /v1/audio/speech HTTP/1.1" 200 OK
(APIServer pid=355286) INFO 04-15 15:59:06 [orchestrator.py:670] [Orchestrator] _handle_add_request: stage=0 req=speech-8b0befe35357a61e prompt_type=OmniEngineCoreRequest original_prompt_type=dict final_stage=0 num_sampling_params=1
(APIServer pid=355286) INFO 04-15 15:59:06 [stage_engine_core_client.py:170] [StageEngineCoreClient] Stage-0 adding request: speech-8b0befe35357a61e
(APIServer pid=355286) INFO 04-15 15:59:07 [omni_base.py:162] [Summary] {}
(APIServer pid=355286) INFO 04-15 16:03:33 [serving_speech.py:1508] TTS speech request speech-87acc4b0ed334ced: text='你好,这是一个voxcpm2 测试程序 在vllm-omni 0.19 中测试的。', model=voxcpm2
(APIServer pid=355286) INFO 04-15 16:03:33 [orchestrator.py:670] [Orchestrator] _handle_add_request: stage=0 req=speech-87acc4b0ed334ced prompt_type=OmniEngineCoreRequest original_prompt_type=dict final_stage=0 num_sampling_params=1
(APIServer pid=355286) INFO 04-15 16:03:33 [stage_engine_core_client.py:170] [StageEngineCoreClient] Stage-0 adding request: speech-87acc4b0ed334ced
(APIServer pid=355286) INFO 04-15 16:03:59 [omni_base.py:162] [Summary] {}
```

This was generated after configuring the cloned voice, without enabling streaming output.
output.wav

This was generated without configuring the cloned voice, and streaming output was not enabled.
output.wav

Whether I use gradio_demo.py or openai_speech_client.py, the result is the same: the TTS audio generated for Chinese content is not right. One more detail: with streaming output enabled, the audio in gradio_demo.py plays twice.

4.mp4

I directly used the test Python program from https://docs.vllm.ai/projects/vllm-omni/en/latest/user_guide/examples/online_serving/voxcpm2/#example-materials, and then wrote a simple script myself to test as well, but Chinese text generation was incorrect in both cases. Calling the model directly (outside vllm-omni) generates it without problems.

CUDA:12.6
torch:2.10.0

After updating the main branch of vllm-omni, pull the branch #2803

The command to run the vllm service is

```
CUDA_VISIBLE_DEVICES=3 vllm-omni serve /home/VoxCPM/models/VoxCPM2 \
  --stage-configs-path /home/www/vllm-omni/vllm_omni/model_executor/stage_configs/voxcpm2.yaml \
  --omni \
  --port 8071 \
  --trust-remote-code \
  --enforce-eager \
  --gpu-memory-utilization 0.8
```

Using the command in the vllm-omni documentation at https://docs.vllm.ai/projects/vllm-omni/en/latest/user_guide/examples/online_serving/voxcpm2/#start-the-server, a warning appears and it crashes immediately without detailed logs.

```
CUDA_VISIBLE_DEVICES=3 python -m vllm_omni.entrypoints.openai.api_server \
  --model /home/VoxCPM/models/VoxCPM2 \
  --stage-configs-path /home/www/vllm-omni/vllm_omni/model_executor/stage_configs/voxcpm2.yaml \
  --host 0.0.0.0 --port 8071
```

```
:128: RuntimeWarning: 'vllm_omni.entrypoints.openai.api_server' found in sys.modules after import of package 'vllm_omni.entrypoints.openai', but prior to execution of 'vllm_omni.entrypoints.openai.api_server'; this may result in unpredictable behaviour
```

This is the content returned at runtime

(image attachment)

@Sy0307
Contributor

Sy0307 commented Apr 15, 2026

Can not reproduce your issue @gesla2024 . Also cc @hsliuustc0106 @linyueqian

@gesla2024

> Can not reproduce your issue @gesla2024. Also cc @hsliuustc0106 @linyueqian

```python
from __future__ import annotations

import base64
import os

import httpx

DEFAULT_API_BASE = "http://localhost:8071"
DEFAULT_API_KEY = "sk-empty"

# Raw string so the backslashes are not treated as escape sequences.
REFERENCE_AUDIO_PATH = r"D:\Users\Administrator\Downloads\200030.wav"
REFERENCE_AUDIO_BASE64 = ""


def encode_audio_to_base64(audio_path: str) -> str:
    """Encode a local audio file to a base64 data URL."""
    if not os.path.exists(audio_path):
        raise FileNotFoundError(f"Audio file not found: {audio_path}")

    ext = audio_path.lower().rsplit(".", 1)[-1]
    mime = {
        "wav": "audio/wav",
        "mp3": "audio/mpeg",
        "flac": "audio/flac",
        "ogg": "audio/ogg",
    }.get(ext, "audio/wav")

    with open(audio_path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode("utf-8")
    return f"data:{mime};base64,{b64}"


def main() -> None:
    text = "你好,这是一个voxcpm2 测试程序 在vllm-omni 0.19 中测试的。"
    payload: dict = {
        "model": "/home/VoxCPM/models/VoxCPM2",
        "input": text,
        "voice": "default",
        "response_format": "wav",
    }
    ref_audio = None

    if ref_audio is not None:
        ref = ref_audio
        if ref.startswith(("http://", "https://", "data:")):
            payload["ref_audio"] = ref
        else:
            payload["ref_audio"] = encode_audio_to_base64(ref)

    url = f"{DEFAULT_API_BASE}/v1/audio/speech"
    print(f"POST {url}")
    print(f"  text: {text}")
    if ref_audio is not None:
        print(f"  ref_audio: {ref_audio[:80]}...")

    with httpx.Client(timeout=300) as client:
        resp = client.post(
            url,
            json=payload,
            headers={"Authorization": f"Bearer {DEFAULT_API_KEY}"},
        )

    if resp.status_code != 200:
        print(f"Error {resp.status_code}: {resp.text[:500]}")
        return

    with open("output.wav", "wb") as f:
        f.write(resp.content)
    print(f"Saved: output.wav ({len(resp.content):,} bytes)")


if __name__ == "__main__":
    main()
```

This is my test code for generating audio, let's see if it helps you.

@Sy0307
Contributor

Sy0307 commented Apr 15, 2026

Sorry, it is my mistake: I forgot to remove a temporary tokenizer config (used for testing) on my remote test machine, so my Chinese tokenizer results were correct there, but the main branch's are not. Sorry for that again. @gesla2024

y123456y78 pushed a commit to y123456y78/vllm-omni that referenced this pull request Apr 15, 2026
[Perf] VoxCPM2: streaming VAE + compile optimization (45% RTF reduction) (vllm-project#2758)

Signed-off-by: Yueqian Lin <linyueqian@outlook.com>
@gesla2024

Okay, thank you. I will pull it again and test it.

lvliang-intel pushed a commit to lvliang-intel/vllm-omni that referenced this pull request Apr 20, 2026
[Perf] VoxCPM2: streaming VAE + compile optimization (45% RTF reduction) (vllm-project#2758)

Signed-off-by: Yueqian Lin <linyueqian@outlook.com>

Labels

ready label to trigger buildkite CI


4 participants