From ff7b5af6d50a910801977f06d715468d78204c2b Mon Sep 17 00:00:00 2001 From: Yueqian Lin Date: Mon, 13 Apr 2026 20:37:11 -0400 Subject: [PATCH 1/6] perf(voxcpm2): sliding-window VAE decode (O(N) streaming) Replace the O(N^2) accumulate-and-re-decode loop in _collect_audio with a nanovllm-style sliding-window stream: each VAE decode takes only the trailing pad frames plus the newly-generated latents, and we slice out just the new audio region. Total VAE work drops from O(N^2) to O(N) over a full generation. Signed-off-by: Yueqian Lin --- .../models/voxcpm2/voxcpm2_talker.py | 77 +++++++++++++++---- 1 file changed, 60 insertions(+), 17 deletions(-) diff --git a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py index 0898ca59ae4..4afc54548ed 100644 --- a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py +++ b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py @@ -19,7 +19,6 @@ import librosa import torch import torch.nn as nn -from einops import rearrange from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.models.utils import ( @@ -86,7 +85,13 @@ class _RequestState: curr_prefix_feat_cond: torch.Tensor | None = None last_audio_patch_gpu: torch.Tensor | None = None precomputed_stop_logits: torch.Tensor | None = None - accumulated_patches: list[torch.Tensor] = dataclasses.field(default_factory=list) + # Patches produced since the last VAE decode, shape (patch_size, feat_dim) each. + pending_latents: list[torch.Tensor] = dataclasses.field(default_factory=list) + # Rolling tail of previously-decoded latents used as VAE receptive-field context. + # Shape (n_pad_frames, feat_dim) on GPU. None before first decode. + decode_pad: torch.Tensor | None = None + # Audio chunks already emitted (CPU float32), concatenated for cumulative output. 
+ audio_chunks: list[torch.Tensor] = dataclasses.field(default_factory=list) decode_step_count: int = 0 request_start_time: float = 0.0 prefill_completed: bool = False @@ -321,6 +326,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._cfg_value = 2.0 self._cfg_cutoff_ratio = 1.0 self._vae_decode_interval = 5 + # Number of trailing latent frames to keep as VAE receptive-field context + # for sliding-window streaming decode. 12 matches the nanovllm reference + # implementation and covers the longest VAE decoder receptive field. + self._n_decode_pad_frames = 12 self._enable_torch_compile = True self._compile_vae = True self._max_decode_steps = 2000 @@ -720,25 +729,57 @@ def _finish_decode(self, state: _RequestState, meta: dict, res_out: torch.Tensor # -------------------- audio collection -------------------- def _collect_audio(self, state: _RequestState) -> torch.Tensor | None: + """Sliding-window VAE decode. + + Each decode step appends the newly-generated latent patch to + ``pending_latents``. Every ``_vae_decode_interval`` steps (or on the + final step) we feed ``[decode_pad, *pending_latents]`` through the VAE + once, slice out only the *new* audio region, append it to + ``audio_chunks`` and refresh ``decode_pad`` with the last + ``_n_decode_pad_frames`` latent frames. Total VAE work is O(N) in + the number of decode steps instead of O(N**2) for the full re-decode. + """ patch = state.last_audio_patch_gpu if patch is not None: state.last_audio_patch_gpu = None - state.accumulated_patches.append(patch.reshape(1, -1).float()) + # patch shape is (1, patch_size, feat_dim); keep on GPU as float32. 
+ state.pending_latents.append(patch.reshape(-1, self._feat_dim).to(torch.float32)) - if not state.accumulated_patches: - return None + if not state.pending_latents: + return state.last_decoded_audio + + pending_count = len(state.pending_latents) + if not (pending_count >= self._vae_decode_interval or state.is_stopping + or state.last_decoded_audio is None): + return state.last_decoded_audio + + self._perf.start("vae_decode") + new_latents = torch.cat(state.pending_latents, dim=0) # (T_new, D) + state.pending_latents = [] + + if state.decode_pad is not None: + vae_latents = torch.cat([state.decode_pad, new_latents], dim=0) + pad_frames = state.decode_pad.shape[0] + else: + vae_latents = new_latents + pad_frames = 0 + + feat = vae_latents.reshape(1, -1, self._feat_dim).transpose(1, 2).contiguous() + with torch.no_grad(): + audio = self.tts.audio_vae.decode(feat.to(self._device)).reshape(-1) + + # Slice out the newly-generated audio (everything after the pad region). + chunk_size = int(getattr(self.tts.audio_vae, "decode_chunk_size", audio.numel() // vae_latents.shape[0])) + new_audio = audio[pad_frames * chunk_size:].detach().cpu().float() + state.audio_chunks.append(new_audio) + + # Roll the pad buffer to the last N frames of the current input. + state.decode_pad = vae_latents[-self._n_decode_pad_frames:].detach() - n = len(state.accumulated_patches) - if n <= 1 or n % self._vae_decode_interval == 0 or state.is_stopping: - self._perf.start("vae_decode") - all_p = torch.cat(state.accumulated_patches, dim=0) - state.accumulated_patches = [all_p] - feat = rearrange(all_p.reshape(1, -1, self._feat_dim), "b t d -> b d t") - with torch.no_grad(): - audio = self.tts.audio_vae.decode(feat.to(self._device)).reshape(-1).cpu().float() - self._perf.stop("vae_decode") - state.last_decoded_audio = audio - return audio + # Cumulative audio preserves the existing "last element = complete audio" + # semantic relied on by tests, examples and the speech serving layer. 
+ state.last_decoded_audio = torch.cat(state.audio_chunks, dim=0) + self._perf.stop("vae_decode") return state.last_decoded_audio # -------------------- compute_logits -------------------- @@ -830,7 +871,9 @@ def preprocess( state = self._get_or_create_state(req_id) state.prefill_text = "" - state.accumulated_patches = [] + state.pending_latents = [] + state.decode_pad = None + state.audio_chunks = [] state.prefill_completed = False state.decode_step_count = 0 state.precomputed_stop_logits = None From 41b878e6ad12654ba5a90a75f6e19af80f523384 Mon Sep 17 00:00:00 2001 From: Yueqian Lin Date: Mon, 13 Apr 2026 20:42:50 -0400 Subject: [PATCH 2/6] perf(voxcpm2): avoid GPU->CPU syncs in CFM diffusion loop The inner Euler integration in _optimized_solve_euler called .item() on the 0-dim GPU tensors t and dt up to 4 times per diffusion step, forcing a GPU->CPU sync every time. With n_timesteps=10 and ~4 syncs per step that is ~40 syncs per AR decode step; profiling counted ~4k aten::_local_scalar_dense calls over a long generation. Broadcast the 0-dim tensors directly via .copy_() instead, keeping the work on-device. Also gate the one-shot prefill norm log behind an isEnabledFor(DEBUG) check so it no longer syncs on every request. 
Signed-off-by: Yueqian Lin --- .../models/voxcpm2/voxcpm2_talker.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py index 4afc54548ed..750e5d5457c 100644 --- a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py +++ b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py @@ -11,6 +11,7 @@ from __future__ import annotations import dataclasses +import logging import os import time from collections.abc import Iterable @@ -234,11 +235,11 @@ def _optimized_solve_euler( buffers.x_in[b : 2 * b].copy_(x) buffers.mu_in[:b].copy_(mu) buffers.mu_in[b : 2 * b].zero_() - buffers.t_in[:b].fill_(t.item()) - buffers.t_in[b : 2 * b].fill_(t.item()) + # Broadcast the 0-dim GPU scalar directly instead of + # ``.fill_(t.item())`` — ``.item()`` forces a GPU->CPU sync. + buffers.t_in[: 2 * b].copy_(t) if mean_mode: - buffers.dt_in[:b].fill_(dt.item()) - buffers.dt_in[b : 2 * b].fill_(dt.item()) + buffers.dt_in[: 2 * b].copy_(dt) else: buffers.dt_in.zero_() buffers.cond_in[:b].copy_(cond[:b]) @@ -268,9 +269,10 @@ def _optimized_solve_euler( else: buffers.x_in[:b].copy_(x) buffers.mu_in[:b].copy_(mu) - buffers.t_in[:b].fill_(t.item()) + # Broadcast the 0-dim GPU scalar; ``.fill_(t.item())`` would sync. + buffers.t_in[:b].copy_(t) if mean_mode: - buffers.dt_in[:b].fill_(dt.item()) + buffers.dt_in[:b].copy_(dt) else: buffers.dt_in[:b].zero_() buffers.cond_in[:b].copy_(cond[:b]) @@ -695,7 +697,9 @@ def _finish_prefill(self, state: _RequestState, meta: dict, res_out: torch.Tenso state.request_start_time = time.perf_counter() state.prefill_completed = True - logger.info("PREFILL[%s]: patch norm=%.4f", state.request_id, pred_feat.norm().item()) + if logger.isEnabledFor(logging.DEBUG): + # Only compute the norm (which forces a GPU->CPU sync) if we will log it. 
+ logger.debug("PREFILL[%s]: patch norm=%.4f", state.request_id, pred_feat.norm().item()) self._perf.reset() def _finish_decode(self, state: _RequestState, meta: dict, res_out: torch.Tensor, dev: Any): From 62d5e1234a26d3cd26835cea3c47974c9af2ec1d Mon Sep 17 00:00:00 2001 From: Yueqian Lin Date: Mon, 13 Apr 2026 21:28:51 -0400 Subject: [PATCH 3/6] perf(voxcpm2): compile whole Model.forward instead of per-submodule PR #2690 compiled `layer.mlp` and `layer.self_attn.o_proj` separately (2 compiled regions per layer, fullgraph=True). Profiling showed 1,737 per-layer compiled-region dispatches on a long prompt at ~530 us CPU self-time each (~925 ms of pure Dynamo dispatch overhead). Wrap `Model.forward` in a single `torch.compile(fullgraph=False)` so Dynamo traces the full 28-layer loop once. Graph breaks at PagedAttention produce sub-graphs that are memoised after the first step, collapsing per-step Python dispatch from 28+ calls to a handful. Same treatment for the 8-layer residual model. Benchmarked on H20: RTF dropped from 0.197 to 0.126 (36%) on the long prompt, matching or beating nanovllm-voxcpm on short prompts. 
Signed-off-by: Yueqian Lin --- .../models/voxcpm2/minicpm4_paged.py | 71 ++++++++----------- .../models/voxcpm2/voxcpm2_talker.py | 7 +- 2 files changed, 33 insertions(+), 45 deletions(-) diff --git a/vllm_omni/model_executor/models/voxcpm2/minicpm4_paged.py b/vllm_omni/model_executor/models/voxcpm2/minicpm4_paged.py index 7ea5bc229dc..40bacfff6c7 100644 --- a/vllm_omni/model_executor/models/voxcpm2/minicpm4_paged.py +++ b/vllm_omni/model_executor/models/voxcpm2/minicpm4_paged.py @@ -308,31 +308,28 @@ def forward( return hidden_states def compile_selective(self) -> list[str]: - """Compile MLP + o_proj; keep RMSNorm/RoPE eager for precision.""" - compiled: list[str] = [] - for i, layer in enumerate(self.layers): - if i in self._compiled_layers: - continue - try: - layer.mlp = torch.compile( - layer.mlp, - mode="default", - fullgraph=True, - ) - layer.self_attn.o_proj = torch.compile( - layer.self_attn.o_proj, - mode="default", - fullgraph=True, - ) - layer.self_attn._fused_qkv_weight = None - self._compiled_layers.add(i) - if i == 0: - compiled.append(f"layers.*.mlp (×{len(self.layers)})") - compiled.append(f"layers.*.self_attn.o_proj (×{len(self.layers)})") - except Exception as e: - logger.warning("compile_selective: layer %d failed: %s", i, e) - break - return compiled + """Compile the full model forward as one graph. + + Earlier versions compiled ``layer.mlp`` + ``layer.self_attn.o_proj`` + (PR #2690) and then the whole ``layer`` (perf/voxcpm2-streaming-vae). + Both still paid one Dynamo dispatch per layer per decode step. + V3 profiling showed 1,332 per-layer dispatches (~28 layers × ~47 + decode steps) costing ~726 ms of CPU self-time for a long prompt. + + Compiling ``forward`` at the model level lets Dynamo unroll the + 28-layer Python loop inside the graph. Graph breaks at + PagedAttention produce sub-graphs but Dynamo memoises the whole + trace once, so the per-step dispatch drops from 28 to just a few. 
+ """ + if self._compiled_layers: + return [] + # Null the fused-qkv caches so the compile sees the real weight layout. + for layer in self.layers: + layer.self_attn._fused_qkv_weight = None + self.forward = torch.compile(self.forward, mode="default", fullgraph=False) + # Mark every layer as compiled so idempotent callers don't double-wrap. + self._compiled_layers.update(range(len(self.layers))) + return ["forward (whole model)"] def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: """Load weights from native checkpoint (base_lm. prefix pre-stripped).""" @@ -415,22 +412,14 @@ def forward( return hidden_states def compile_selective(self) -> list[str]: - """Compile MLP + o_proj (same as base_lm).""" - compiled: list[str] = [] - for i, layer in enumerate(self.layers): - if i in self._compiled_layers: - continue - try: - layer.mlp = torch.compile(layer.mlp, mode="default", fullgraph=True) - layer.self_attn.o_proj = torch.compile(layer.self_attn.o_proj, mode="default", fullgraph=True) - layer.self_attn._fused_qkv_weight = None - self._compiled_layers.add(i) - if i == 0: - compiled.append(f"layers.*.mlp (×{len(self.layers)})") - compiled.append(f"layers.*.self_attn.o_proj (×{len(self.layers)})") - except Exception as e: - logger.warning("compile_selective: residual layer %d failed: %s", i, e) - return compiled + """Compile the full residual model forward as one graph (same strategy as base_lm).""" + if self._compiled_layers: + return [] + for layer in self.layers: + layer.self_attn._fused_qkv_weight = None + self.forward = torch.compile(self.forward, mode="default", fullgraph=False) + self._compiled_layers.update(range(len(self.layers))) + return ["forward (whole residual)"] def load_weights_from_native(self, native_residual_lm: nn.Module) -> int: """Load weights from native residual_lm. 
Returns param count.""" diff --git a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py index 750e5d5457c..22c3b9c88ff 100644 --- a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py +++ b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py @@ -753,8 +753,7 @@ def _collect_audio(self, state: _RequestState) -> torch.Tensor | None: return state.last_decoded_audio pending_count = len(state.pending_latents) - if not (pending_count >= self._vae_decode_interval or state.is_stopping - or state.last_decoded_audio is None): + if not (pending_count >= self._vae_decode_interval or state.is_stopping or state.last_decoded_audio is None): return state.last_decoded_audio self._perf.start("vae_decode") @@ -774,11 +773,11 @@ def _collect_audio(self, state: _RequestState) -> torch.Tensor | None: # Slice out the newly-generated audio (everything after the pad region). chunk_size = int(getattr(self.tts.audio_vae, "decode_chunk_size", audio.numel() // vae_latents.shape[0])) - new_audio = audio[pad_frames * chunk_size:].detach().cpu().float() + new_audio = audio[pad_frames * chunk_size :].detach().cpu().float() state.audio_chunks.append(new_audio) # Roll the pad buffer to the last N frames of the current input. - state.decode_pad = vae_latents[-self._n_decode_pad_frames:].detach() + state.decode_pad = vae_latents[-self._n_decode_pad_frames :].detach() # Cumulative audio preserves the existing "last element = complete audio" # semantic relied on by tests, examples and the speech serving layer. From 82e3dc1458cd6b045a10f6f985cc52c205f58a5e Mon Sep 17 00:00:00 2001 From: Yueqian Lin Date: Mon, 13 Apr 2026 21:47:15 -0400 Subject: [PATCH 4/6] feat(voxcpm2): add streaming Gradio demo with voice cloning Also fix streaming: return delta audio chunks (not cumulative) from _collect_audio, and return None on steps without a VAE decode. 
The output processor accumulates deltas into a list; the speech streaming layer yields each new entry as a separate PCM chunk to the client. Previously, returning cumulative audio caused the client to replay the full audio from the start on every VAE decode interval. Signed-off-by: Yueqian Lin --- .../online_serving/voxcpm2/gradio_demo.py | 602 ++++++++++++++++++ .../models/voxcpm2/voxcpm2_talker.py | 69 +- 2 files changed, 633 insertions(+), 38 deletions(-) create mode 100644 examples/online_serving/voxcpm2/gradio_demo.py diff --git a/examples/online_serving/voxcpm2/gradio_demo.py b/examples/online_serving/voxcpm2/gradio_demo.py new file mode 100644 index 00000000000..a33a2d9245f --- /dev/null +++ b/examples/online_serving/voxcpm2/gradio_demo.py @@ -0,0 +1,602 @@ +"""Gradio demo for VoxCPM2 TTS with gapless streaming audio playback. + +Uses a custom AudioWorklet-based player for gap-free streaming +(adapted from the Qwen3-TTS demo). Audio is streamed from the vLLM +server through a same-origin proxy and played via the Web Audio API's +AudioWorklet, which maintains a FIFO buffer queue and plays samples at +the audio clock rate. 
+ +Usage: + # Start the vLLM server first: + python -m vllm_omni.entrypoints.openai.api_server \ + --model openbmb/VoxCPM2 \ + --stage-configs-path vllm_omni/model_executor/stage_configs/voxcpm2.yaml \ + --host 0.0.0.0 --port 8000 + + # Then launch the demo: + python gradio_demo.py --api-base http://localhost:8000 +""" + +from __future__ import annotations + +import argparse +import base64 +import io +import json +import logging + +import gradio as gr +import httpx +import numpy as np +import soundfile as sf +from fastapi import FastAPI, Request +from fastapi.responses import Response, StreamingResponse + +logger = logging.getLogger(__name__) + +SAMPLE_RATE = 48000 + +# ── AudioWorklet processor (loaded in browser via Blob URL) ────────── +WORKLET_JS = r""" +class TTSPlaybackProcessor extends AudioWorkletProcessor { + constructor() { + super(); + this.queue = []; + this.buf = null; + this.pos = 0; + this.playing = false; + this.played = 0; + this.port.onmessage = (e) => { + if (e.data && e.data.type === 'clear') { + this.queue = []; this.buf = null; this.pos = 0; this.played = 0; + if (this.playing) { this.playing = false; this.port.postMessage({type:'stopped'}); } + return; + } + this.queue.push(e.data); + }; + } + process(inputs, outputs) { + const out = outputs[0][0]; + for (let i = 0; i < out.length; i++) { + if (!this.buf || this.pos >= this.buf.length) { + if (this.queue.length > 0) { + this.buf = this.queue.shift(); this.pos = 0; + } else { + for (let j = i; j < out.length; j++) out[j] = 0; + if (this.playing) { this.playing = false; this.port.postMessage({type:'stopped', played:this.played}); } + return true; + } + } + out[i] = this.buf[this.pos++] / 32768; + this.played++; + } + if (!this.playing) { this.playing = true; this.port.postMessage({type:'started'}); } + return true; + } +} +registerProcessor('tts-playback-processor', TTSPlaybackProcessor); +""" + +PLAYER_HTML = """ +
+
+
+ Ready + +
+ + + +
+""" + + +def _build_player_js() -> str: + return f""" + +""" + + +def _encode_audio(audio_data: tuple) -> str: + sr, audio_np = audio_data + if audio_np.dtype in (np.float32, np.float64): + audio_np = np.clip(audio_np, -1.0, 1.0) + audio_np = (audio_np * 32767).astype(np.int16) + elif audio_np.dtype != np.int16: + audio_np = audio_np.astype(np.int16) + buf = io.BytesIO() + sf.write(buf, audio_np, sr, format="WAV") + return f"data:audio/wav;base64,{base64.b64encode(buf.getvalue()).decode()}" + + +def create_app(api_base: str): + app = FastAPI() + _pending: dict[str, dict] = {} + + @app.post("/proxy/v1/audio/speech") + async def proxy_speech(request: Request): + body = await request.json() + req_id = body.get("_req_id") + if req_id and req_id in _pending: + body = _pending.pop(req_id) + logger.info("Proxy: %s", {k: (f"<{len(str(v))} chars>" if k == "ref_audio" else v) for k, v in body.items()}) + try: + client = httpx.AsyncClient(timeout=300) + resp = await client.send( + client.build_request( + "POST", + f"{api_base}/v1/audio/speech", + json=body, + headers={"Authorization": "Bearer EMPTY", "Content-Type": "application/json"}, + ), + stream=True, + ) + except Exception as exc: + logger.exception("Proxy connection error") + await client.aclose() + return Response(content=str(exc), status_code=502) + if resp.status_code != 200: + content = await resp.aread() + await resp.aclose() + await client.aclose() + return Response(content=content, status_code=resp.status_code) + + async def relay(): + try: + async for chunk in resp.aiter_bytes(): + yield chunk + finally: + await resp.aclose() + await client.aclose() + + return StreamingResponse(relay(), media_type="application/octet-stream") + + css = """ + #generate-btn button { width: 100%; } + #streaming-player { border: 1px solid var(--border-color-primary) !important; border-radius: var(--block-radius) !important; padding: var(--block-padding) !important; } + """ + theme = gr.themes.Default( + primary_hue=gr.themes.Color( 
+ c50="#f0f5ff", + c100="#dce6f9", + c200="#b8cef3", + c300="#8eb2eb", + c400="#6496e0", + c500="#4A90D9", + c600="#3a7bc8", + c700="#2d66b0", + c800="#1f4f8f", + c900="#163a6e", + c950="#0e2650", + ), + ) + + with gr.Blocks(title="VoxCPM2 TTS Demo") as demo: + gr.HTML(f""" +
+ vLLM-Omni +
+

VoxCPM2 Streaming Demo

+ + Served by vLLM-Omni + · {api_base} + · 48 kHz + +
+
+ """) + + gr.Markdown( + "**Three modes:** " + "**Voice Design** (control instruction only) · " + "**Controllable Cloning** (ref audio + optional style control) · " + "**Ultimate Cloning** (ref audio + transcript for audio continuation)" + ) + + with gr.Row(): + with gr.Column(scale=3): + text_input = gr.Textbox( + label="Target Text", + placeholder="Enter text to synthesize...", + lines=4, + ) + control_instruction = gr.Textbox( + label="Control Instruction (optional)", + placeholder="e.g. A warm young woman / Excited and fast-paced", + lines=2, + info="Describe voice style, emotion, pace. Works for both Voice Design and Controllable Cloning.", + ) + + with gr.Accordion("Voice Cloning", open=False): + ref_audio = gr.Audio( + label="Reference Audio (upload for cloning)", + type="numpy", + sources=["upload", "microphone"], + ) + ref_audio_url = gr.Textbox( + label="or Reference Audio URL", + placeholder="https://example.com/reference.wav", + ) + ultimate_clone = gr.Checkbox( + label="Ultimate Cloning Mode", + value=False, + info="Provide transcript of ref audio for audio continuation (disables control instruction)", + ) + prompt_text = gr.Textbox( + label="Reference Audio Transcript", + placeholder="Transcript of your reference audio (for ultimate cloning)", + lines=2, + visible=False, + ) + + with gr.Row(): + stream_checkbox = gr.Checkbox( + label="Stream (gapless)", + value=True, + info="AudioWorklet streaming", + ) + with gr.Row(): + generate_btn = gr.Button( + "Generate Speech", + variant="primary", + size="lg", + elem_id="generate-btn", + scale=3, + ) + reset_btn = gr.Button("Reset", variant="secondary", size="lg", scale=1) + + with gr.Column(scale=2): + player_html = gr.HTML( + value=PLAYER_HTML, + visible=True, + label="streaming player", + elem_id="streaming-player", + ) + audio_output = gr.Audio( + label="generated audio", + interactive=False, + autoplay=True, + visible=False, + ) + gr.Examples( + examples=[ + ["Hello, this is a VoxCPM2 demo running on 
vLLM-Omni.", ""], + [ + "I have a dream that my four little children will one day live in a nation " + "where they will not be judged by the color of their skin but by the content " + "of their character.", + "", + ], + [ + "I never asked you to stay. It's not like I care or anything. " + "But why does it still hurt so much now that you're gone?", + "A young girl with a soft, sweet voice. Speaks slowly with a melancholic tone.", + ], + ], + inputs=[text_input, control_instruction], + label="examples", + ) + gr.HTML(""" +
+ + vLLM-Omni + +
+ """) + + hidden_payload = gr.Textbox(visible=False, elem_id="tts-payload") + + def on_ultimate_toggle(checked): + return ( + gr.update(visible=checked), # prompt_text + gr.update(interactive=not checked), # control_instruction + ) + + ultimate_clone.change( + fn=on_ultimate_toggle, + inputs=[ultimate_clone], + outputs=[prompt_text, control_instruction], + ) + + def on_stream_change(stream: bool): + if stream: + return gr.update(visible=True), gr.update(visible=False) + return gr.update(visible=False), gr.update(visible=True) + + stream_checkbox.change( + fn=on_stream_change, + inputs=[stream_checkbox], + outputs=[player_html, audio_output], + ) + + def on_reset(): + return "", "", None, "", False, "", PLAYER_HTML + + reset_btn.click( + fn=on_reset, + outputs=[ + text_input, + control_instruction, + audio_output, + hidden_payload, + ultimate_clone, + prompt_text, + player_html, + ], + js="() => { if (window.ttsStop) window.ttsStop(); }", + ) + + def on_generate(stream_enabled, text, ctrl_instr, ref_a, ref_url, ult_clone, p_text): + import time as _time + + if not text or not text.strip(): + raise gr.Error("Please enter text to synthesize.") + + # VoxCPM2 uses "(instruction)text" format for control + ctrl = ctrl_instr.strip() if ctrl_instr and not ult_clone else "" + final_text = f"({ctrl}){text.strip()}" if ctrl else text.strip() + + payload: dict = { + "input": final_text, + "voice": "default", + "response_format": "pcm" if stream_enabled else "wav", + "stream": stream_enabled, + } + + # Reference audio for cloning + ref_url_s = ref_url.strip() if ref_url else "" + if ref_url_s: + payload["ref_audio"] = ref_url_s + elif ref_a is not None: + payload["ref_audio"] = _encode_audio(ref_a) + + # Ultimate cloning: prompt_audio + prompt_text for continuation + if ult_clone and p_text and p_text.strip(): + if ref_url_s: + payload["prompt_audio"] = ref_url_s + elif ref_a is not None: + payload["prompt_audio"] = payload.get("ref_audio", "") + payload["prompt_text"] = 
p_text.strip() + + if stream_enabled: + if ref_a is not None and not ref_url_s: + req_id = f"req-{int(_time.time() * 1000)}" + _pending[req_id] = payload + browser_payload = {"_req_id": req_id, "_nonce": int(_time.time() * 1000)} + return json.dumps(browser_payload), gr.update() + payload["_nonce"] = int(_time.time() * 1000) + return json.dumps(payload), gr.update() + else: + try: + with httpx.Client(timeout=300.0) as client: + resp = client.post( + f"{api_base}/v1/audio/speech", + json=payload, + headers={"Content-Type": "application/json", "Authorization": "Bearer EMPTY"}, + ) + except httpx.ConnectError: + raise gr.Error(f"Cannot connect to server at {api_base}.") + if resp.status_code != 200: + raise gr.Error(f"Server error ({resp.status_code}): {resp.text[:200]}") + audio_np, sr = sf.read(io.BytesIO(resp.content)) + if audio_np.ndim > 1: + audio_np = audio_np[:, 0] + return "", (sr, audio_np.astype(np.float32)) + + generate_btn.click( + fn=on_generate, + inputs=[ + stream_checkbox, + text_input, + control_instruction, + ref_audio, + ref_audio_url, + ultimate_clone, + prompt_text, + ], + outputs=[hidden_payload, audio_output], + ).then( + fn=lambda p: p, + inputs=[hidden_payload], + outputs=[hidden_payload], + js="(p) => { if (p && p.trim()) { const d = JSON.parse(p); delete d._nonce; window.ttsGenerate(d); } return p; }", + ) + + demo.queue() + + return gr.mount_gradio_app(app, demo, path="/", css=css, theme=theme, head=_build_player_js()) + + +def main(): + parser = argparse.ArgumentParser(description="VoxCPM2 streaming Gradio demo") + parser.add_argument("--api-base", default="http://localhost:8000", help="vLLM API server URL") + parser.add_argument("--host", default="0.0.0.0", help="Gradio server host") + parser.add_argument("--port", type=int, default=7860, help="Gradio server port") + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO) + print(f"Connecting to vLLM server at: {args.api_base}") + + import uvicorn + + 
uvicorn.run(create_app(args.api_base), host=args.host, port=args.port) + + +if __name__ == "__main__": + main() diff --git a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py index 22c3b9c88ff..94f06589046 100644 --- a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py +++ b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py @@ -86,8 +86,6 @@ class _RequestState: curr_prefix_feat_cond: torch.Tensor | None = None last_audio_patch_gpu: torch.Tensor | None = None precomputed_stop_logits: torch.Tensor | None = None - # Patches produced since the last VAE decode, shape (patch_size, feat_dim) each. - pending_latents: list[torch.Tensor] = dataclasses.field(default_factory=list) # Rolling tail of previously-decoded latents used as VAE receptive-field context. # Shape (n_pad_frames, feat_dim) on GPU. None before first decode. decode_pad: torch.Tensor | None = None @@ -327,7 +325,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._inference_timesteps = 10 self._cfg_value = 2.0 self._cfg_cutoff_ratio = 1.0 - self._vae_decode_interval = 5 # Number of trailing latent frames to keep as VAE receptive-field context # for sliding-window streaming decode. 12 matches the nanovllm reference # implementation and covers the longest VAE decoder receptive field. @@ -733,57 +730,54 @@ def _finish_decode(self, state: _RequestState, meta: dict, res_out: torch.Tensor # -------------------- audio collection -------------------- def _collect_audio(self, state: _RequestState) -> torch.Tensor | None: - """Sliding-window VAE decode. - - Each decode step appends the newly-generated latent patch to - ``pending_latents``. 
Every ``_vae_decode_interval`` steps (or on the - final step) we feed ``[decode_pad, *pending_latents]`` through the VAE - once, slice out only the *new* audio region, append it to - ``audio_chunks`` and refresh ``decode_pad`` with the last - ``_n_decode_pad_frames`` latent frames. Total VAE work is O(N) in - the number of decode steps instead of O(N**2) for the full re-decode. + """Per-step sliding-window VAE decode (nanovllm pattern). + + Each decode step feeds ``[decode_pad, new_patch]`` through the VAE + and slices out only the audio region corresponding to the new patch. + The pad buffer (last ``_n_decode_pad_frames`` latent frames) provides + the receptive-field context needed by the VAE's transposed convolutions, + eliminating boundary artifacts between chunks. + + Returns the delta audio chunk (not cumulative) so the output processor + can stream each chunk to the client independently. """ patch = state.last_audio_patch_gpu - if patch is not None: - state.last_audio_patch_gpu = None - # patch shape is (1, patch_size, feat_dim); keep on GPU as float32. 
- state.pending_latents.append(patch.reshape(-1, self._feat_dim).to(torch.float32)) - - if not state.pending_latents: - return state.last_decoded_audio + if patch is None: + return None + state.last_audio_patch_gpu = None - pending_count = len(state.pending_latents) - if not (pending_count >= self._vae_decode_interval or state.is_stopping or state.last_decoded_audio is None): - return state.last_decoded_audio + # patch shape: (patch_size, feat_dim) or (1, patch_size, feat_dim) + new_latent = patch.reshape(-1, self._feat_dim).to(torch.float32) + n_new = new_latent.shape[0] # = patch_size (typically 4) self._perf.start("vae_decode") - new_latents = torch.cat(state.pending_latents, dim=0) # (T_new, D) - state.pending_latents = [] + # Build VAE input: [pad_frames | new_latent] if state.decode_pad is not None: - vae_latents = torch.cat([state.decode_pad, new_latents], dim=0) + vae_input = torch.cat([state.decode_pad, new_latent], dim=0) pad_frames = state.decode_pad.shape[0] else: - vae_latents = new_latents + vae_input = new_latent pad_frames = 0 - feat = vae_latents.reshape(1, -1, self._feat_dim).transpose(1, 2).contiguous() + # VAE decode: (1, feat_dim, T_frames) -> (1, 1, T_samples) + feat = vae_input.unsqueeze(0).transpose(1, 2).contiguous() with torch.no_grad(): audio = self.tts.audio_vae.decode(feat.to(self._device)).reshape(-1) - # Slice out the newly-generated audio (everything after the pad region). - chunk_size = int(getattr(self.tts.audio_vae, "decode_chunk_size", audio.numel() // vae_latents.shape[0])) - new_audio = audio[pad_frames * chunk_size :].detach().cpu().float() - state.audio_chunks.append(new_audio) + # Slice out only the new audio (after the pad region). + # Each latent frame maps to decode_chunk_size audio samples.
+ dcs = int(getattr(self.tts.audio_vae, "decode_chunk_size", audio.numel() // vae_input.shape[0])) + new_audio = audio[pad_frames * dcs : (pad_frames + n_new) * dcs].detach().cpu().float() - # Roll the pad buffer to the last N frames of the current input. - state.decode_pad = vae_latents[-self._n_decode_pad_frames :].detach() + # Roll the pad buffer: keep last N latent frames as context for next step. + all_latents = vae_input # [pad + new] + state.decode_pad = all_latents[-self._n_decode_pad_frames :].detach() - # Cumulative audio preserves the existing "last element = complete audio" - # semantic relied on by tests, examples and the speech serving layer. - state.last_decoded_audio = torch.cat(state.audio_chunks, dim=0) + state.audio_chunks.append(new_audio) + state.last_decoded_audio = new_audio self._perf.stop("vae_decode") - return state.last_decoded_audio + return new_audio # -------------------- compute_logits -------------------- @@ -874,7 +868,6 @@ def preprocess( state = self._get_or_create_state(req_id) state.prefill_text = "" - state.pending_latents = [] state.decode_pad = None state.audio_chunks = [] state.prefill_completed = False From fa66e989afe7e37661bb47b5a2c714e5984ccc68 Mon Sep 17 00:00:00 2001 From: Yueqian Lin Date: Mon, 13 Apr 2026 22:52:28 -0400 Subject: [PATCH 5/6] fix(voxcpm2): concatenate delta audio chunks at output consolidation The streaming VAE change (ff7b5af6) switched _collect_audio to return per-step delta chunks instead of cumulative audio. Offline consumers then received only the last chunk (~0.16s) because _consolidate_multimodal_tensors in engine/output_processor.py skipped concatenation for the 'audio' key, and the non-streaming speech server kept only the last list entry (~16 KB WAV, empty-sounding output). Fix at the structural root: have the consolidation step concatenate audio delta chunks into the full waveform (flatten each to 1-D first to tolerate inconsistent leading dims). 
Consolidation only runs on finished=True so streaming is unaffected. Offline extract_audio helpers add a defensive torch.cat fallback for mid-stream list snapshots; normal completed requests now see a single consolidated tensor. Reported by @gesla2024 in #2758, root-caused by @Sy0307. Signed-off-by: Yueqian Lin --- examples/offline_inference/voxcpm2/end2end.py | 6 ++++-- tests/e2e/offline_inference/test_voxcpm2.py | 6 +++--- vllm_omni/engine/output_processor.py | 7 ++++--- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/examples/offline_inference/voxcpm2/end2end.py b/examples/offline_inference/voxcpm2/end2end.py index ce404bf962d..8f43c1bbd75 100644 --- a/examples/offline_inference/voxcpm2/end2end.py +++ b/examples/offline_inference/voxcpm2/end2end.py @@ -79,11 +79,13 @@ def extract_audio(multimodal_output: dict) -> torch.Tensor: raise ValueError(f"No audio key in multimodal_output: {list(multimodal_output.keys())}") if isinstance(audio, list): - # Take the last valid tensor (most complete audio) + # Defensive: usually the output processor consolidates into a single + # tensor at request completion, but concatenate here too in case the + # caller consumes intermediate (pre-consolidation) outputs. 
valid = [torch.as_tensor(a).float().cpu().reshape(-1) for a in audio if a is not None] if not valid: raise ValueError("Audio list is empty or all elements are None.") - return valid[-1] + return torch.cat(valid, dim=0) if len(valid) > 1 else valid[0] return torch.as_tensor(audio).float().cpu().reshape(-1) diff --git a/tests/e2e/offline_inference/test_voxcpm2.py b/tests/e2e/offline_inference/test_voxcpm2.py index 4e4f635d5c4..97e29b7167a 100644 --- a/tests/e2e/offline_inference/test_voxcpm2.py +++ b/tests/e2e/offline_inference/test_voxcpm2.py @@ -33,14 +33,14 @@ def _extract_audio(multimodal_output: dict) -> torch.Tensor: """Extract the final complete audio tensor from multimodal output.""" assert isinstance(multimodal_output, dict), f"Expected dict, got {type(multimodal_output)}" - # Output processor accumulates per-step full audio under "audio". + # Output processor accumulates per-step audio chunks under "audio". audio = multimodal_output.get("audio") or multimodal_output.get("model_outputs") assert audio is not None, f"No audio key, got {list(multimodal_output.keys())}" if isinstance(audio, list): - valid = [x for x in audio if isinstance(x, torch.Tensor) and x.numel() > 100] + valid = [torch.as_tensor(x).float().cpu().reshape(-1) for x in audio if x is not None] assert valid, "No valid audio tensors in output list" - audio = valid[-1] + audio = torch.cat(valid, dim=0) if len(valid) > 1 else valid[0] assert isinstance(audio, torch.Tensor), f"Expected Tensor, got {type(audio)}" return audio diff --git a/vllm_omni/engine/output_processor.py b/vllm_omni/engine/output_processor.py index 43d02e85b84..badd799fc94 100644 --- a/vllm_omni/engine/output_processor.py +++ b/vllm_omni/engine/output_processor.py @@ -118,9 +118,10 @@ def _consolidate_multimodal_tensors(self) -> None: if isinstance(v, list) and v and isinstance(v[0], torch.Tensor): try: if k == "audio": - # When the audio tensor shape is inconsistent, torch.cat will fail. 
- # We need to use torch.cat in -1 dimension. - continue + # Concatenate delta audio chunks (1-D) into the full waveform. + # Each entry is a per-step slice; flatten to -1 so chunks with + # inconsistent leading dims can still be joined on the sample axis. + self.mm_accumulated[k] = torch.cat([t.reshape(-1) for t in v], dim=0) elif k == "sr": # Sample rate is a constant scalar, keep last value. self.mm_accumulated[k] = v[-1] From d9fc07a0402917201fed254cbc3a93b81d24a8e2 Mon Sep 17 00:00:00 2001 From: Yueqian Lin Date: Tue, 14 Apr 2026 20:04:33 -0500 Subject: [PATCH 6/6] fix(voxcpm2): avoid tensor truthiness in audio extraction helpers After the output_processor consolidation fix, multimodal_output["audio"] is a Tensor rather than a list, so `dict.get("audio") or dict.get(...)` raises "Boolean value of Tensor with more than one value is ambiguous". Use explicit None-checks instead of `or` short-circuiting. Signed-off-by: Yueqian Lin --- examples/offline_inference/voxcpm2/end2end.py | 4 +++- tests/e2e/offline_inference/test_voxcpm2.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/offline_inference/voxcpm2/end2end.py b/examples/offline_inference/voxcpm2/end2end.py index 8f43c1bbd75..687e596018c 100644 --- a/examples/offline_inference/voxcpm2/end2end.py +++ b/examples/offline_inference/voxcpm2/end2end.py @@ -74,7 +74,9 @@ def extract_audio(multimodal_output: dict) -> torch.Tensor: The output processor concatenates per-step delta tensors under ``model_outputs``. Falls back to ``audio`` for backwards compat. 
""" - audio = multimodal_output.get("model_outputs") or multimodal_output.get("audio") + audio = multimodal_output.get("model_outputs") + if audio is None: + audio = multimodal_output.get("audio") if audio is None: raise ValueError(f"No audio key in multimodal_output: {list(multimodal_output.keys())}") diff --git a/tests/e2e/offline_inference/test_voxcpm2.py b/tests/e2e/offline_inference/test_voxcpm2.py index 97e29b7167a..6ec4630a45e 100644 --- a/tests/e2e/offline_inference/test_voxcpm2.py +++ b/tests/e2e/offline_inference/test_voxcpm2.py @@ -34,7 +34,9 @@ def _extract_audio(multimodal_output: dict) -> torch.Tensor: assert isinstance(multimodal_output, dict), f"Expected dict, got {type(multimodal_output)}" # Output processor accumulates per-step audio chunks under "audio". - audio = multimodal_output.get("audio") or multimodal_output.get("model_outputs") + audio = multimodal_output.get("audio") + if audio is None: + audio = multimodal_output.get("model_outputs") assert audio is not None, f"No audio key, got {list(multimodal_output.keys())}" if isinstance(audio, list):