diff --git a/examples/offline_inference/voxcpm2/end2end.py b/examples/offline_inference/voxcpm2/end2end.py
index ce404bf962d..687e596018c 100644
--- a/examples/offline_inference/voxcpm2/end2end.py
+++ b/examples/offline_inference/voxcpm2/end2end.py
@@ -74,16 +74,20 @@ def extract_audio(multimodal_output: dict) -> torch.Tensor:
The output processor concatenates per-step delta tensors under
``model_outputs``. Falls back to ``audio`` for backwards compat.
"""
- audio = multimodal_output.get("model_outputs") or multimodal_output.get("audio")
+ audio = multimodal_output.get("model_outputs")
+ if audio is None:
+ audio = multimodal_output.get("audio")
if audio is None:
raise ValueError(f"No audio key in multimodal_output: {list(multimodal_output.keys())}")
if isinstance(audio, list):
- # Take the last valid tensor (most complete audio)
+ # Defensive: usually the output processor consolidates into a single
+ # tensor at request completion, but concatenate here too in case the
+ # caller consumes intermediate (pre-consolidation) outputs.
valid = [torch.as_tensor(a).float().cpu().reshape(-1) for a in audio if a is not None]
if not valid:
raise ValueError("Audio list is empty or all elements are None.")
- return valid[-1]
+ return torch.cat(valid, dim=0) if len(valid) > 1 else valid[0]
return torch.as_tensor(audio).float().cpu().reshape(-1)
diff --git a/examples/online_serving/voxcpm2/gradio_demo.py b/examples/online_serving/voxcpm2/gradio_demo.py
new file mode 100644
index 00000000000..a33a2d9245f
--- /dev/null
+++ b/examples/online_serving/voxcpm2/gradio_demo.py
@@ -0,0 +1,602 @@
+"""Gradio demo for VoxCPM2 TTS with gapless streaming audio playback.
+
+Uses a custom AudioWorklet-based player for gap-free streaming
+(adapted from the Qwen3-TTS demo). Audio is streamed from the vLLM
+server through a same-origin proxy and played via the Web Audio API's
+AudioWorklet, which maintains a FIFO buffer queue and plays samples at
+the audio clock rate.
+
+Usage:
+ # Start the vLLM server first:
+ python -m vllm_omni.entrypoints.openai.api_server \
+ --model openbmb/VoxCPM2 \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/voxcpm2.yaml \
+ --host 0.0.0.0 --port 8000
+
+ # Then launch the demo:
+ python gradio_demo.py --api-base http://localhost:8000
+"""
+
+from __future__ import annotations
+
+import argparse
+import base64
+import io
+import json
+import logging
+
+import gradio as gr
+import httpx
+import numpy as np
+import soundfile as sf
+from fastapi import FastAPI, Request
+from fastapi.responses import Response, StreamingResponse
+
+logger = logging.getLogger(__name__)
+
+SAMPLE_RATE = 48000
+
+# ── AudioWorklet processor (loaded in browser via Blob URL) ──────────
+WORKLET_JS = r"""
+class TTSPlaybackProcessor extends AudioWorkletProcessor {
+ constructor() {
+ super();
+ this.queue = [];
+ this.buf = null;
+ this.pos = 0;
+ this.playing = false;
+ this.played = 0;
+ this.port.onmessage = (e) => {
+ if (e.data && e.data.type === 'clear') {
+ this.queue = []; this.buf = null; this.pos = 0; this.played = 0;
+ if (this.playing) { this.playing = false; this.port.postMessage({type:'stopped'}); }
+ return;
+ }
+ this.queue.push(e.data);
+ };
+ }
+ process(inputs, outputs) {
+ const out = outputs[0][0];
+ for (let i = 0; i < out.length; i++) {
+ if (!this.buf || this.pos >= this.buf.length) {
+ if (this.queue.length > 0) {
+ this.buf = this.queue.shift(); this.pos = 0;
+ } else {
+ for (let j = i; j < out.length; j++) out[j] = 0;
+ if (this.playing) { this.playing = false; this.port.postMessage({type:'stopped', played:this.played}); }
+ return true;
+ }
+ }
+ out[i] = this.buf[this.pos++] / 32768;
+ this.played++;
+ }
+ if (!this.playing) { this.playing = true; this.port.postMessage({type:'started'}); }
+ return true;
+ }
+}
+registerProcessor('tts-playback-processor', TTSPlaybackProcessor);
+"""
+
+PLAYER_HTML = """
+
+"""
+
+
+def _build_player_js() -> str:
+    """Return the markup that is passed as ``head=`` when the Gradio app is mounted."""
+    return f"""
+
+"""
+
+
+def _encode_audio(audio_data: tuple) -> str:
+    """Encode a ``(sample_rate, samples)`` pair as a base64 WAV data URI.
+
+    Floating-point samples of any width are clipped to [-1.0, 1.0] and scaled
+    to 16-bit PCM. Any other non-int16 dtype is saturated to the int16 range
+    before the cast, so out-of-range values clamp instead of wrapping around.
+
+    Args:
+        audio_data: ``(sample_rate, samples)`` tuple; ``samples`` is array-like.
+
+    Returns:
+        A ``data:audio/wav;base64,...`` string holding the encoded WAV bytes.
+    """
+    sr, audio_np = audio_data
+    audio_np = np.asarray(audio_np)
+    if np.issubdtype(audio_np.dtype, np.floating):
+        if audio_np.dtype == np.float16:
+            # 32767 is not representable in float16; scaling there rounds up
+            # to 32768 and overflows the int16 cast.
+            audio_np = audio_np.astype(np.float32)
+        audio_np = np.clip(audio_np, -1.0, 1.0)
+        audio_np = (audio_np * 32767).astype(np.int16)
+    elif audio_np.dtype != np.int16:
+        # A bare astype(int16) wraps on overflow (e.g. int32 input); saturate first.
+        pcm16 = np.iinfo(np.int16)
+        audio_np = np.clip(audio_np, pcm16.min, pcm16.max).astype(np.int16)
+    buf = io.BytesIO()
+    sf.write(buf, audio_np, sr, format="WAV")
+    return f"data:audio/wav;base64,{base64.b64encode(buf.getvalue()).decode()}"
+
+
+def create_app(api_base: str):
+ app = FastAPI()
+ _pending: dict[str, dict] = {}
+
+    @app.post("/proxy/v1/audio/speech")
+    async def proxy_speech(request: Request):
+        """Relay a browser speech request to the upstream ``/v1/audio/speech`` endpoint.
+
+        The browser posts either the full payload or only a ``_req_id`` that
+        refers to a payload staged server-side in ``_pending``. The upstream
+        response body is streamed back unchanged.
+
+        Returns:
+            A ``StreamingResponse`` on an upstream 200; the upstream status and
+            body otherwise; 400 for a non-object JSON body; 410 for an unknown
+            or already-consumed ``_req_id``; 502 when the upstream call fails.
+        """
+        body = await request.json()
+        if not isinstance(body, dict):
+            return Response(content="Request body must be a JSON object.", status_code=400)
+        req_id = body.get("_req_id")
+        if req_id:
+            staged = _pending.pop(req_id, None) if isinstance(req_id, str) else None
+            if staged is None:
+                # Forwarding the bare ``{"_req_id": ...}`` stub upstream would only
+                # yield a confusing validation error from the TTS server.
+                return Response(content="Unknown or expired request id.", status_code=410)
+            body = staged
+        # Both audio fields may carry a base64 data URI; log their size, not their content.
+        redacted = ("ref_audio", "prompt_audio")
+        logger.info("Proxy: %s", {k: (f"<{len(str(v))} chars>" if k in redacted else v) for k, v in body.items()})
+        # Create the client before the ``try`` so the error path can always close it.
+        client = httpx.AsyncClient(timeout=300)
+        try:
+            resp = await client.send(
+                client.build_request(
+                    "POST",
+                    f"{api_base}/v1/audio/speech",
+                    json=body,
+                    headers={"Authorization": "Bearer EMPTY", "Content-Type": "application/json"},
+                ),
+                stream=True,
+            )
+        except Exception as exc:
+            logger.exception("Proxy connection error")
+            await client.aclose()
+            return Response(content=str(exc), status_code=502)
+        if resp.status_code != 200:
+            try:
+                content = await resp.aread()
+            finally:
+                # Release the connection even if reading the error body fails.
+                await resp.aclose()
+                await client.aclose()
+            return Response(content=content, status_code=resp.status_code)
+
+        async def relay():
+            # Closing in ``finally`` also covers the browser disconnecting mid-stream.
+            try:
+                async for chunk in resp.aiter_bytes():
+                    yield chunk
+            finally:
+                await resp.aclose()
+                await client.aclose()
+
+        return StreamingResponse(relay(), media_type="application/octet-stream")
+
+ css = """
+ #generate-btn button { width: 100%; }
+ #streaming-player { border: 1px solid var(--border-color-primary) !important; border-radius: var(--block-radius) !important; padding: var(--block-padding) !important; }
+ """
+ theme = gr.themes.Default(
+ primary_hue=gr.themes.Color(
+ c50="#f0f5ff",
+ c100="#dce6f9",
+ c200="#b8cef3",
+ c300="#8eb2eb",
+ c400="#6496e0",
+ c500="#4A90D9",
+ c600="#3a7bc8",
+ c700="#2d66b0",
+ c800="#1f4f8f",
+ c900="#163a6e",
+ c950="#0e2650",
+ ),
+ )
+
+ with gr.Blocks(title="VoxCPM2 TTS Demo") as demo:
+ gr.HTML(f"""
+
+

+
+
VoxCPM2 Streaming Demo
+
+ Served by vLLM-Omni
+ · {api_base}
+ · 48 kHz
+
+
+
+ """)
+
+ gr.Markdown(
+ "**Three modes:** "
+ "**Voice Design** (control instruction only) · "
+ "**Controllable Cloning** (ref audio + optional style control) · "
+ "**Ultimate Cloning** (ref audio + transcript for audio continuation)"
+ )
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ text_input = gr.Textbox(
+ label="Target Text",
+ placeholder="Enter text to synthesize...",
+ lines=4,
+ )
+ control_instruction = gr.Textbox(
+ label="Control Instruction (optional)",
+ placeholder="e.g. A warm young woman / Excited and fast-paced",
+ lines=2,
+ info="Describe voice style, emotion, pace. Works for both Voice Design and Controllable Cloning.",
+ )
+
+ with gr.Accordion("Voice Cloning", open=False):
+ ref_audio = gr.Audio(
+ label="Reference Audio (upload for cloning)",
+ type="numpy",
+ sources=["upload", "microphone"],
+ )
+ ref_audio_url = gr.Textbox(
+ label="or Reference Audio URL",
+ placeholder="https://example.com/reference.wav",
+ )
+ ultimate_clone = gr.Checkbox(
+ label="Ultimate Cloning Mode",
+ value=False,
+ info="Provide transcript of ref audio for audio continuation (disables control instruction)",
+ )
+ prompt_text = gr.Textbox(
+ label="Reference Audio Transcript",
+ placeholder="Transcript of your reference audio (for ultimate cloning)",
+ lines=2,
+ visible=False,
+ )
+
+ with gr.Row():
+ stream_checkbox = gr.Checkbox(
+ label="Stream (gapless)",
+ value=True,
+ info="AudioWorklet streaming",
+ )
+ with gr.Row():
+ generate_btn = gr.Button(
+ "Generate Speech",
+ variant="primary",
+ size="lg",
+ elem_id="generate-btn",
+ scale=3,
+ )
+ reset_btn = gr.Button("Reset", variant="secondary", size="lg", scale=1)
+
+ with gr.Column(scale=2):
+ player_html = gr.HTML(
+ value=PLAYER_HTML,
+ visible=True,
+ label="streaming player",
+ elem_id="streaming-player",
+ )
+ audio_output = gr.Audio(
+ label="generated audio",
+ interactive=False,
+ autoplay=True,
+ visible=False,
+ )
+ gr.Examples(
+ examples=[
+ ["Hello, this is a VoxCPM2 demo running on vLLM-Omni.", ""],
+ [
+ "I have a dream that my four little children will one day live in a nation "
+ "where they will not be judged by the color of their skin but by the content "
+ "of their character.",
+ "",
+ ],
+ [
+ "I never asked you to stay. It's not like I care or anything. "
+ "But why does it still hurt so much now that you're gone?",
+ "A young girl with a soft, sweet voice. Speaks slowly with a melancholic tone.",
+ ],
+ ],
+ inputs=[text_input, control_instruction],
+ label="examples",
+ )
+ gr.HTML("""
+
+ """)
+
+ hidden_payload = gr.Textbox(visible=False, elem_id="tts-payload")
+
+        def on_ultimate_toggle(checked):
+            """Show the transcript box and lock the control instruction while the box is ticked."""
+            transcript_update = gr.update(visible=checked)  # prompt_text
+            instruction_update = gr.update(interactive=not checked)  # control_instruction
+            return (transcript_update, instruction_update)
+
+ ultimate_clone.change(
+ fn=on_ultimate_toggle,
+ inputs=[ultimate_clone],
+ outputs=[prompt_text, control_instruction],
+ )
+
+        def on_stream_change(stream: bool):
+            """Make exactly one of the streaming player and the plain audio widget visible."""
+            show_player = bool(stream)
+            # Order matches the outputs list: player_html, then audio_output.
+            return gr.update(visible=show_player), gr.update(visible=not show_player)
+
+ stream_checkbox.change(
+ fn=on_stream_change,
+ inputs=[stream_checkbox],
+ outputs=[player_html, audio_output],
+ )
+
+        def on_reset():
+            """Return the blank value for each component wired to the Reset button."""
+            cleared = ("", "", None, "", False, "", PLAYER_HTML)
+            return cleared
+
+ reset_btn.click(
+ fn=on_reset,
+ outputs=[
+ text_input,
+ control_instruction,
+ audio_output,
+ hidden_payload,
+ ultimate_clone,
+ prompt_text,
+ player_html,
+ ],
+ js="() => { if (window.ttsStop) window.ttsStop(); }",
+ )
+
+        def on_generate(stream_enabled, text, ctrl_instr, ref_a, ref_url, ult_clone, p_text):
+            """Build the speech request and either stage it for the browser or run it here.
+
+            In streaming mode the returned JSON is written to the hidden textbox
+            and handed to ``window.ttsGenerate`` by the chained JS callback. In
+            non-streaming mode the request is sent from this process and the
+            decoded audio is returned for the audio widget.
+
+            Returns:
+                ``(hidden_payload, audio_output)`` values for the two outputs.
+
+            Raises:
+                gr.Error: On empty text, a failed upstream request, or a non-200
+                    upstream response.
+            """
+            import time as _time
+            import uuid as _uuid
+
+            if not text or not text.strip():
+                raise gr.Error("Please enter text to synthesize.")
+
+            # VoxCPM2 uses "(instruction)text" format for control
+            ctrl = ctrl_instr.strip() if ctrl_instr and not ult_clone else ""
+            final_text = f"({ctrl}){text.strip()}" if ctrl else text.strip()
+
+            payload: dict = {
+                "input": final_text,
+                "voice": "default",
+                "response_format": "pcm" if stream_enabled else "wav",
+                "stream": stream_enabled,
+            }
+
+            # Reference audio for cloning; an explicit URL wins over an upload.
+            ref_url_s = ref_url.strip() if ref_url else ""
+            if ref_url_s:
+                payload["ref_audio"] = ref_url_s
+            elif ref_a is not None:
+                payload["ref_audio"] = _encode_audio(ref_a)
+
+            # Ultimate cloning: prompt_audio + prompt_text for continuation
+            if ult_clone and p_text and p_text.strip():
+                if "ref_audio" in payload:
+                    payload["prompt_audio"] = payload["ref_audio"]
+                payload["prompt_text"] = p_text.strip()
+
+            if stream_enabled:
+                nonce = int(_time.time() * 1000)
+                if ref_a is not None and not ref_url_s:
+                    # Uploaded audio stays server-side; the browser only gets an id
+                    # that the proxy route swaps back for the staged payload.
+                    # A random id avoids two requests in the same millisecond
+                    # overwriting each other's entry.
+                    req_id = f"req-{_uuid.uuid4().hex}"
+                    # Drop the oldest entries so payloads the browser never
+                    # fetches cannot accumulate without bound.
+                    max_pending = 64
+                    while len(_pending) >= max_pending:
+                        _pending.pop(next(iter(_pending)), None)
+                    _pending[req_id] = payload
+                    browser_payload = {"_req_id": req_id, "_nonce": nonce}
+                    return json.dumps(browser_payload), gr.update()
+                payload["_nonce"] = nonce
+                return json.dumps(payload), gr.update()
+
+            try:
+                with httpx.Client(timeout=300.0) as client:
+                    resp = client.post(
+                        f"{api_base}/v1/audio/speech",
+                        json=payload,
+                        headers={"Content-Type": "application/json", "Authorization": "Bearer EMPTY"},
+                    )
+            except httpx.ConnectError:
+                raise gr.Error(f"Cannot connect to server at {api_base}.")
+            except httpx.HTTPError as exc:
+                # Timeouts and other transport failures should reach the UI as a
+                # readable message rather than an unhandled traceback.
+                raise gr.Error(f"Request to {api_base} failed: {exc}") from exc
+            if resp.status_code != 200:
+                raise gr.Error(f"Server error ({resp.status_code}): {resp.text[:200]}")
+            audio_np, sr = sf.read(io.BytesIO(resp.content))
+            if audio_np.ndim > 1:
+                # Keep only the first channel of multi-channel audio.
+                audio_np = audio_np[:, 0]
+            return "", (sr, audio_np.astype(np.float32))
+
+ generate_btn.click(
+ fn=on_generate,
+ inputs=[
+ stream_checkbox,
+ text_input,
+ control_instruction,
+ ref_audio,
+ ref_audio_url,
+ ultimate_clone,
+ prompt_text,
+ ],
+ outputs=[hidden_payload, audio_output],
+ ).then(
+ fn=lambda p: p,
+ inputs=[hidden_payload],
+ outputs=[hidden_payload],
+ js="(p) => { if (p && p.trim()) { const d = JSON.parse(p); delete d._nonce; window.ttsGenerate(d); } return p; }",
+ )
+
+ demo.queue()
+
+ return gr.mount_gradio_app(app, demo, path="/", css=css, theme=theme, head=_build_player_js())
+
+
+def main():
+    """Parse the command-line options and serve the demo app with uvicorn."""
+    cli = argparse.ArgumentParser(description="VoxCPM2 streaming Gradio demo")
+    cli.add_argument("--api-base", default="http://localhost:8000", help="vLLM API server URL")
+    cli.add_argument("--host", default="0.0.0.0", help="Gradio server host")
+    cli.add_argument("--port", type=int, default=7860, help="Gradio server port")
+    opts = cli.parse_args()
+
+    logging.basicConfig(level=logging.INFO)
+    print(f"Connecting to vLLM server at: {opts.api_base}")
+
+    # Imported here, after argument parsing, exactly where the server is started.
+    import uvicorn
+
+    uvicorn.run(create_app(opts.api_base), host=opts.host, port=opts.port)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/e2e/offline_inference/test_voxcpm2.py b/tests/e2e/offline_inference/test_voxcpm2.py
index 4e4f635d5c4..6ec4630a45e 100644
--- a/tests/e2e/offline_inference/test_voxcpm2.py
+++ b/tests/e2e/offline_inference/test_voxcpm2.py
@@ -33,14 +33,16 @@ def _extract_audio(multimodal_output: dict) -> torch.Tensor:
"""Extract the final complete audio tensor from multimodal output."""
assert isinstance(multimodal_output, dict), f"Expected dict, got {type(multimodal_output)}"
- # Output processor accumulates per-step full audio under "audio".
- audio = multimodal_output.get("audio") or multimodal_output.get("model_outputs")
+ # Output processor accumulates per-step audio chunks under "audio".
+ audio = multimodal_output.get("audio")
+ if audio is None:
+ audio = multimodal_output.get("model_outputs")
assert audio is not None, f"No audio key, got {list(multimodal_output.keys())}"
if isinstance(audio, list):
- valid = [x for x in audio if isinstance(x, torch.Tensor) and x.numel() > 100]
+ valid = [torch.as_tensor(x).float().cpu().reshape(-1) for x in audio if x is not None]
assert valid, "No valid audio tensors in output list"
- audio = valid[-1]
+ audio = torch.cat(valid, dim=0) if len(valid) > 1 else valid[0]
assert isinstance(audio, torch.Tensor), f"Expected Tensor, got {type(audio)}"
return audio
diff --git a/vllm_omni/engine/output_processor.py b/vllm_omni/engine/output_processor.py
index 43d02e85b84..badd799fc94 100644
--- a/vllm_omni/engine/output_processor.py
+++ b/vllm_omni/engine/output_processor.py
@@ -118,9 +118,10 @@ def _consolidate_multimodal_tensors(self) -> None:
if isinstance(v, list) and v and isinstance(v[0], torch.Tensor):
try:
if k == "audio":
- # When the audio tensor shape is inconsistent, torch.cat will fail.
- # We need to use torch.cat in -1 dimension.
- continue
+ # Concatenate delta audio chunks (1-D) into the full waveform.
+ # Each entry is a per-step slice; flatten to -1 so chunks with
+ # inconsistent leading dims can still be joined on the sample axis.
+ self.mm_accumulated[k] = torch.cat([t.reshape(-1) for t in v], dim=0)
elif k == "sr":
# Sample rate is a constant scalar, keep last value.
self.mm_accumulated[k] = v[-1]
diff --git a/vllm_omni/model_executor/models/voxcpm2/minicpm4_paged.py b/vllm_omni/model_executor/models/voxcpm2/minicpm4_paged.py
index 7ea5bc229dc..40bacfff6c7 100644
--- a/vllm_omni/model_executor/models/voxcpm2/minicpm4_paged.py
+++ b/vllm_omni/model_executor/models/voxcpm2/minicpm4_paged.py
@@ -308,31 +308,40 @@ def forward(
         return hidden_states
     def compile_selective(self) -> list[str]:
-        """Compile MLP + o_proj; keep RMSNorm/RoPE eager for precision."""
-        compiled: list[str] = []
-        for i, layer in enumerate(self.layers):
-            if i in self._compiled_layers:
-                continue
-            try:
-                layer.mlp = torch.compile(
-                    layer.mlp,
-                    mode="default",
-                    fullgraph=True,
-                )
-                layer.self_attn.o_proj = torch.compile(
-                    layer.self_attn.o_proj,
-                    mode="default",
-                    fullgraph=True,
-                )
-                layer.self_attn._fused_qkv_weight = None
-                self._compiled_layers.add(i)
-                if i == 0:
-                    compiled.append(f"layers.*.mlp (×{len(self.layers)})")
-                    compiled.append(f"layers.*.self_attn.o_proj (×{len(self.layers)})")
-            except Exception as e:
-                logger.warning("compile_selective: layer %d failed: %s", i, e)
-                break
-        return compiled
+        """Wrap the whole model ``forward`` with ``torch.compile`` (graph breaks allowed).
+
+        Earlier versions compiled ``layer.mlp`` + ``layer.self_attn.o_proj``
+        (PR #2690) and then the whole ``layer`` (perf/voxcpm2-streaming-vae).
+        Both still paid one Dynamo dispatch per layer per decode step.
+        V3 profiling showed 1,332 per-layer dispatches (~28 layers × ~47
+        decode steps) costing ~726 ms of CPU self-time for a long prompt.
+
+        Compiling ``forward`` at the model level lets Dynamo unroll the
+        28-layer Python loop inside the graph. Graph breaks at
+        PagedAttention produce sub-graphs but Dynamo memoises the whole
+        trace once, so the per-step dispatch drops from 28 to just a few.
+
+        Best-effort, like the per-layer version this replaces: if
+        ``torch.compile`` raises, a warning is logged and ``forward`` stays eager.
+
+        Returns:
+            Descriptions of what was compiled; empty when already compiled
+            or when compilation could not be set up.
+        """
+        if self._compiled_layers:
+            return []
+        # Null the fused-qkv caches so the compile sees the real weight layout.
+        for layer in self.layers:
+            layer.self_attn._fused_qkv_weight = None
+        try:
+            self.forward = torch.compile(self.forward, mode="default", fullgraph=False)
+        except Exception as e:
+            # Wrapping can fail on platforms/interpreters torch.compile does not support.
+            logger.warning("compile_selective: whole-model compile failed: %s", e)
+            return []
+        # Mark every layer as compiled so idempotent callers don't double-wrap.
+        self._compiled_layers.update(range(len(self.layers)))
+        return ["forward (whole model)"]
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
"""Load weights from native checkpoint (base_lm. prefix pre-stripped)."""
@@ -415,22 +412,28 @@ def forward(
         return hidden_states
     def compile_selective(self) -> list[str]:
-        """Compile MLP + o_proj (same as base_lm)."""
-        compiled: list[str] = []
-        for i, layer in enumerate(self.layers):
-            if i in self._compiled_layers:
-                continue
-            try:
-                layer.mlp = torch.compile(layer.mlp, mode="default", fullgraph=True)
-                layer.self_attn.o_proj = torch.compile(layer.self_attn.o_proj, mode="default", fullgraph=True)
-                layer.self_attn._fused_qkv_weight = None
-                self._compiled_layers.add(i)
-                if i == 0:
-                    compiled.append(f"layers.*.mlp (×{len(self.layers)})")
-                    compiled.append(f"layers.*.self_attn.o_proj (×{len(self.layers)})")
-            except Exception as e:
-                logger.warning("compile_selective: residual layer %d failed: %s", i, e)
-        return compiled
+        """Wrap the whole residual-model ``forward`` with ``torch.compile`` (same strategy as base_lm).
+
+        Best-effort, like the per-layer version this replaces: if
+        ``torch.compile`` raises, a warning is logged and ``forward`` stays eager.
+
+        Returns:
+            Descriptions of what was compiled; empty when already compiled
+            or when compilation could not be set up.
+        """
+        if self._compiled_layers:
+            return []
+        # Clear the fused-qkv caches before wrapping ``forward``.
+        for layer in self.layers:
+            layer.self_attn._fused_qkv_weight = None
+        try:
+            self.forward = torch.compile(self.forward, mode="default", fullgraph=False)
+        except Exception as e:
+            logger.warning("compile_selective: residual whole-model compile failed: %s", e)
+            return []
+        # Record every layer so a repeat call does not wrap ``forward`` again.
+        self._compiled_layers.update(range(len(self.layers)))
+        return ["forward (whole residual)"]
def load_weights_from_native(self, native_residual_lm: nn.Module) -> int:
"""Load weights from native residual_lm. Returns param count."""
diff --git a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py
index 0898ca59ae4..94f06589046 100644
--- a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py
+++ b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py
@@ -11,6 +11,7 @@
from __future__ import annotations
import dataclasses
+import logging
import os
import time
from collections.abc import Iterable
@@ -19,7 +20,6 @@
import librosa
import torch
import torch.nn as nn
-from einops import rearrange
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.models.utils import (
@@ -86,7 +86,11 @@ class _RequestState:
curr_prefix_feat_cond: torch.Tensor | None = None
last_audio_patch_gpu: torch.Tensor | None = None
precomputed_stop_logits: torch.Tensor | None = None
- accumulated_patches: list[torch.Tensor] = dataclasses.field(default_factory=list)
+ # Rolling tail of previously-decoded latents used as VAE receptive-field context.
+ # Shape (n_pad_frames, feat_dim) on GPU. None before first decode.
+ decode_pad: torch.Tensor | None = None
+ # Audio chunks already emitted (CPU float32), concatenated for cumulative output.
+ audio_chunks: list[torch.Tensor] = dataclasses.field(default_factory=list)
decode_step_count: int = 0
request_start_time: float = 0.0
prefill_completed: bool = False
@@ -229,11 +233,11 @@ def _optimized_solve_euler(
buffers.x_in[b : 2 * b].copy_(x)
buffers.mu_in[:b].copy_(mu)
buffers.mu_in[b : 2 * b].zero_()
- buffers.t_in[:b].fill_(t.item())
- buffers.t_in[b : 2 * b].fill_(t.item())
+ # Broadcast the 0-dim GPU scalar directly instead of
+ # ``.fill_(t.item())`` — ``.item()`` forces a GPU->CPU sync.
+ buffers.t_in[: 2 * b].copy_(t)
if mean_mode:
- buffers.dt_in[:b].fill_(dt.item())
- buffers.dt_in[b : 2 * b].fill_(dt.item())
+ buffers.dt_in[: 2 * b].copy_(dt)
else:
buffers.dt_in.zero_()
buffers.cond_in[:b].copy_(cond[:b])
@@ -263,9 +267,10 @@ def _optimized_solve_euler(
else:
buffers.x_in[:b].copy_(x)
buffers.mu_in[:b].copy_(mu)
- buffers.t_in[:b].fill_(t.item())
+ # Broadcast the 0-dim GPU scalar; ``.fill_(t.item())`` would sync.
+ buffers.t_in[:b].copy_(t)
if mean_mode:
- buffers.dt_in[:b].fill_(dt.item())
+ buffers.dt_in[:b].copy_(dt)
else:
buffers.dt_in[:b].zero_()
buffers.cond_in[:b].copy_(cond[:b])
@@ -320,7 +325,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self._inference_timesteps = 10
self._cfg_value = 2.0
self._cfg_cutoff_ratio = 1.0
- self._vae_decode_interval = 5
+ # Number of trailing latent frames to keep as VAE receptive-field context
+ # for sliding-window streaming decode. 12 matches the nanovllm reference
+ # implementation and covers the longest VAE decoder receptive field.
+ self._n_decode_pad_frames = 12
self._enable_torch_compile = True
self._compile_vae = True
self._max_decode_steps = 2000
@@ -686,7 +694,9 @@ def _finish_prefill(self, state: _RequestState, meta: dict, res_out: torch.Tenso
state.request_start_time = time.perf_counter()
state.prefill_completed = True
- logger.info("PREFILL[%s]: patch norm=%.4f", state.request_id, pred_feat.norm().item())
+ if logger.isEnabledFor(logging.DEBUG):
+ # Only compute the norm (which forces a GPU->CPU sync) if we will log it.
+ logger.debug("PREFILL[%s]: patch norm=%.4f", state.request_id, pred_feat.norm().item())
self._perf.reset()
def _finish_decode(self, state: _RequestState, meta: dict, res_out: torch.Tensor, dev: Any):
@@ -720,26 +730,68 @@ def _finish_decode(self, state: _RequestState, meta: dict, res_out: torch.Tensor
     # -------------------- audio collection --------------------
     def _collect_audio(self, state: _RequestState) -> torch.Tensor | None:
-        patch = state.last_audio_patch_gpu
-        if patch is not None:
-            state.last_audio_patch_gpu = None
-            state.accumulated_patches.append(patch.reshape(1, -1).float())
+        """Per-step sliding-window VAE decode (nanovllm pattern).
-        if not state.accumulated_patches:
+        Each decode step feeds ``[decode_pad, new_patch]`` through the VAE
+        and slices out only the audio region corresponding to the new patch.
+        The pad buffer (last ``_n_decode_pad_frames`` latent frames) provides
+        the receptive-field context needed by the VAE's transposed convolutions,
+        eliminating boundary artifacts between chunks. A non-positive
+        ``_n_decode_pad_frames`` disables the context window.
+
+        Args:
+            state: Per-request state. ``last_audio_patch_gpu`` is consumed;
+                ``decode_pad``, ``audio_chunks`` and ``last_decoded_audio`` are updated.
+
+        Returns:
+            The delta audio chunk (not cumulative) as a 1-D CPU float tensor, so
+            the output processor can stream each chunk to the client independently,
+            or ``None`` when no new patch is pending.
+        """
+        patch = state.last_audio_patch_gpu
+        if patch is None:
             return None
+        state.last_audio_patch_gpu = None
+
+        # patch shape: (patch_size, feat_dim) or (1, patch_size, feat_dim)
+        new_latent = patch.reshape(-1, self._feat_dim).to(torch.float32)
+        n_new = new_latent.shape[0]  # = patch_size (typically 4)
+
+        self._perf.start("vae_decode")
+
+        # Build VAE input: [pad_frames | new_latent]
+        if state.decode_pad is not None:
+            vae_input = torch.cat([state.decode_pad, new_latent], dim=0)
+            pad_frames = state.decode_pad.shape[0]
+        else:
+            vae_input = new_latent
+            pad_frames = 0
+
+        # VAE decode: (1, feat_dim, T_frames) -> (1, 1, T_samples)
+        feat = vae_input.unsqueeze(0).transpose(1, 2).contiguous()
+        with torch.no_grad():
+            audio = self.tts.audio_vae.decode(feat.to(self._device)).reshape(-1)
+
+        # Slice out only the new audio (after the pad region).
+        # Each latent frame maps to ``decode_chunk_size`` audio samples; if the VAE
+        # does not expose that attribute, derive it from the decoded length.
+        dcs = int(getattr(self.tts.audio_vae, "decode_chunk_size", audio.numel() // vae_input.shape[0]))
+        new_audio = audio[pad_frames * dcs : (pad_frames + n_new) * dcs].detach().cpu().float()
+
+        # Roll the pad buffer: keep last N latent frames as context for next step.
+        # ``t[-0:]`` is the whole tensor, so a non-positive N must clear the pad
+        # instead of slicing. clone() makes the pad own its storage: without it
+        # the slice may alias ``patch`` on the first step (when ``.to`` is a
+        # no-op) and otherwise pins the larger concatenated input.
+        n_pad = self._n_decode_pad_frames
+        if n_pad > 0:
+            state.decode_pad = vae_input[-n_pad:].detach().clone()
+        else:
+            state.decode_pad = None
-        n = len(state.accumulated_patches)
-        if n <= 1 or n % self._vae_decode_interval == 0 or state.is_stopping:
-            self._perf.start("vae_decode")
-            all_p = torch.cat(state.accumulated_patches, dim=0)
-            state.accumulated_patches = [all_p]
-            feat = rearrange(all_p.reshape(1, -1, self._feat_dim), "b t d -> b d t")
-            with torch.no_grad():
-                audio = self.tts.audio_vae.decode(feat.to(self._device)).reshape(-1).cpu().float()
-            self._perf.stop("vae_decode")
-            state.last_decoded_audio = audio
-            return audio
-        return state.last_decoded_audio
+        state.audio_chunks.append(new_audio)
+        state.last_decoded_audio = new_audio
+        self._perf.stop("vae_decode")
+        return new_audio
# -------------------- compute_logits --------------------
@@ -830,7 +868,8 @@ def preprocess(
state = self._get_or_create_state(req_id)
state.prefill_text = ""
- state.accumulated_patches = []
+ state.decode_pad = None
+ state.audio_chunks = []
state.prefill_completed = False
state.decode_step_count = 0
state.precomputed_stop_logits = None