diff --git a/examples/offline_inference/voxcpm2/end2end.py b/examples/offline_inference/voxcpm2/end2end.py index ce404bf962d..687e596018c 100644 --- a/examples/offline_inference/voxcpm2/end2end.py +++ b/examples/offline_inference/voxcpm2/end2end.py @@ -74,16 +74,20 @@ def extract_audio(multimodal_output: dict) -> torch.Tensor: The output processor concatenates per-step delta tensors under ``model_outputs``. Falls back to ``audio`` for backwards compat. """ - audio = multimodal_output.get("model_outputs") or multimodal_output.get("audio") + audio = multimodal_output.get("model_outputs") + if audio is None: + audio = multimodal_output.get("audio") if audio is None: raise ValueError(f"No audio key in multimodal_output: {list(multimodal_output.keys())}") if isinstance(audio, list): - # Take the last valid tensor (most complete audio) + # Defensive: usually the output processor consolidates into a single + # tensor at request completion, but concatenate here too in case the + # caller consumes intermediate (pre-consolidation) outputs. valid = [torch.as_tensor(a).float().cpu().reshape(-1) for a in audio if a is not None] if not valid: raise ValueError("Audio list is empty or all elements are None.") - return valid[-1] + return torch.cat(valid, dim=0) if len(valid) > 1 else valid[0] return torch.as_tensor(audio).float().cpu().reshape(-1) diff --git a/examples/online_serving/voxcpm2/gradio_demo.py b/examples/online_serving/voxcpm2/gradio_demo.py new file mode 100644 index 00000000000..a33a2d9245f --- /dev/null +++ b/examples/online_serving/voxcpm2/gradio_demo.py @@ -0,0 +1,602 @@ +"""Gradio demo for VoxCPM2 TTS with gapless streaming audio playback. + +Uses a custom AudioWorklet-based player for gap-free streaming +(adapted from the Qwen3-TTS demo). Audio is streamed from the vLLM +server through a same-origin proxy and played via the Web Audio API's +AudioWorklet, which maintains a FIFO buffer queue and plays samples at +the audio clock rate. + +Usage: + # Start the vLLM server first: + python -m vllm_omni.entrypoints.openai.api_server \ + --model openbmb/VoxCPM2 \ + --stage-configs-path vllm_omni/model_executor/stage_configs/voxcpm2.yaml \ + --host 0.0.0.0 --port 8000 + + # Then launch the demo: + python gradio_demo.py --api-base http://localhost:8000 +""" + +from __future__ import annotations + +import argparse +import base64 +import io +import json +import logging + +import gradio as gr +import httpx +import numpy as np +import soundfile as sf +from fastapi import FastAPI, Request +from fastapi.responses import Response, StreamingResponse + +logger = logging.getLogger(__name__) + +SAMPLE_RATE = 48000 + +# ── AudioWorklet processor (loaded in browser via Blob URL) ────────── +WORKLET_JS = r""" +class TTSPlaybackProcessor extends AudioWorkletProcessor { + constructor() { + super(); + this.queue = []; + this.buf = null; + this.pos = 0; + this.playing = false; + this.played = 0; + this.port.onmessage = (e) => { + if (e.data && e.data.type === 'clear') { + this.queue = []; this.buf = null; this.pos = 0; this.played = 0; + if (this.playing) { this.playing = false; this.port.postMessage({type:'stopped'}); } + return; + } + this.queue.push(e.data); + }; + } + process(inputs, outputs) { + const out = outputs[0][0]; + for (let i = 0; i < out.length; i++) { + if (!this.buf || this.pos >= this.buf.length) { + if (this.queue.length > 0) { + this.buf = this.queue.shift(); this.pos = 0; + } else { + for (let j = i; j < out.length; j++) out[j] = 0; + if (this.playing) { this.playing = false; this.port.postMessage({type:'stopped', played:this.played}); } + return true; + } + } + out[i] = this.buf[this.pos++] / 32768; + this.played++; + } + if (!this.playing) { this.playing = true; this.port.postMessage({type:'started'}); } + return true; + } +} +registerProcessor('tts-playback-processor', TTSPlaybackProcessor); +""" + +PLAYER_HTML = """ +
+
+
+ Ready + +
+ + + +
+""" + + +def _build_player_js() -> str: + return f""" + +""" + + +def _encode_audio(audio_data: tuple) -> str: + sr, audio_np = audio_data + if audio_np.dtype in (np.float32, np.float64): + audio_np = np.clip(audio_np, -1.0, 1.0) + audio_np = (audio_np * 32767).astype(np.int16) + elif audio_np.dtype != np.int16: + audio_np = audio_np.astype(np.int16) + buf = io.BytesIO() + sf.write(buf, audio_np, sr, format="WAV") + return f"data:audio/wav;base64,{base64.b64encode(buf.getvalue()).decode()}" + + +def create_app(api_base: str): + app = FastAPI() + _pending: dict[str, dict] = {} + + @app.post("/proxy/v1/audio/speech") + async def proxy_speech(request: Request): + body = await request.json() + req_id = body.get("_req_id") + if req_id and req_id in _pending: + body = _pending.pop(req_id) + logger.info("Proxy: %s", {k: (f"<{len(str(v))} chars>" if k == "ref_audio" else v) for k, v in body.items()}) + try: + client = httpx.AsyncClient(timeout=300) + resp = await client.send( + client.build_request( + "POST", + f"{api_base}/v1/audio/speech", + json=body, + headers={"Authorization": "Bearer EMPTY", "Content-Type": "application/json"}, + ), + stream=True, + ) + except Exception as exc: + logger.exception("Proxy connection error") + await client.aclose() + return Response(content=str(exc), status_code=502) + if resp.status_code != 200: + content = await resp.aread() + await resp.aclose() + await client.aclose() + return Response(content=content, status_code=resp.status_code) + + async def relay(): + try: + async for chunk in resp.aiter_bytes(): + yield chunk + finally: + await resp.aclose() + await client.aclose() + + return StreamingResponse(relay(), media_type="application/octet-stream") + + css = """ + #generate-btn button { width: 100%; } + #streaming-player { border: 1px solid var(--border-color-primary) !important; border-radius: var(--block-radius) !important; padding: var(--block-padding) !important; } + """ + theme = gr.themes.Default( + primary_hue=gr.themes.Color( + c50="#f0f5ff", + c100="#dce6f9", + c200="#b8cef3", + c300="#8eb2eb", + c400="#6496e0", + c500="#4A90D9", + c600="#3a7bc8", + c700="#2d66b0", + c800="#1f4f8f", + c900="#163a6e", + c950="#0e2650", + ), + ) + + with gr.Blocks(title="VoxCPM2 TTS Demo") as demo: + gr.HTML(f""" +
+ vLLM-Omni +
+

VoxCPM2 Streaming Demo

+ + Served by vLLM-Omni + · {api_base} + · 48 kHz + +
+
+ """) + + gr.Markdown( + "**Three modes:** " + "**Voice Design** (control instruction only) · " + "**Controllable Cloning** (ref audio + optional style control) · " + "**Ultimate Cloning** (ref audio + transcript for audio continuation)" + ) + + with gr.Row(): + with gr.Column(scale=3): + text_input = gr.Textbox( + label="Target Text", + placeholder="Enter text to synthesize...", + lines=4, + ) + control_instruction = gr.Textbox( + label="Control Instruction (optional)", + placeholder="e.g. A warm young woman / Excited and fast-paced", + lines=2, + info="Describe voice style, emotion, pace. Works for both Voice Design and Controllable Cloning.", + ) + + with gr.Accordion("Voice Cloning", open=False): + ref_audio = gr.Audio( + label="Reference Audio (upload for cloning)", + type="numpy", + sources=["upload", "microphone"], + ) + ref_audio_url = gr.Textbox( + label="or Reference Audio URL", + placeholder="https://example.com/reference.wav", + ) + ultimate_clone = gr.Checkbox( + label="Ultimate Cloning Mode", + value=False, + info="Provide transcript of ref audio for audio continuation (disables control instruction)", + ) + prompt_text = gr.Textbox( + label="Reference Audio Transcript", + placeholder="Transcript of your reference audio (for ultimate cloning)", + lines=2, + visible=False, + ) + + with gr.Row(): + stream_checkbox = gr.Checkbox( + label="Stream (gapless)", + value=True, + info="AudioWorklet streaming", + ) + with gr.Row(): + generate_btn = gr.Button( + "Generate Speech", + variant="primary", + size="lg", + elem_id="generate-btn", + scale=3, + ) + reset_btn = gr.Button("Reset", variant="secondary", size="lg", scale=1) + + with gr.Column(scale=2): + player_html = gr.HTML( + value=PLAYER_HTML, + visible=True, + label="streaming player", + elem_id="streaming-player", + ) + audio_output = gr.Audio( + label="generated audio", + interactive=False, + autoplay=True, + visible=False, + ) + gr.Examples( + examples=[ + ["Hello, this is a VoxCPM2 demo running on vLLM-Omni.", ""], + [ + "I have a dream that my four little children will one day live in a nation " + "where they will not be judged by the color of their skin but by the content " + "of their character.", + "", + ], + [ + "I never asked you to stay. It's not like I care or anything. " + "But why does it still hurt so much now that you're gone?", + "A young girl with a soft, sweet voice. Speaks slowly with a melancholic tone.", + ], + ], + inputs=[text_input, control_instruction], + label="examples", + ) + gr.HTML(""" +
+ + vLLM-Omni + +
+ """) + + hidden_payload = gr.Textbox(visible=False, elem_id="tts-payload") + + def on_ultimate_toggle(checked): + return ( + gr.update(visible=checked), # prompt_text + gr.update(interactive=not checked), # control_instruction + ) + + ultimate_clone.change( + fn=on_ultimate_toggle, + inputs=[ultimate_clone], + outputs=[prompt_text, control_instruction], + ) + + def on_stream_change(stream: bool): + if stream: + return gr.update(visible=True), gr.update(visible=False) + return gr.update(visible=False), gr.update(visible=True) + + stream_checkbox.change( + fn=on_stream_change, + inputs=[stream_checkbox], + outputs=[player_html, audio_output], + ) + + def on_reset(): + return "", "", None, "", False, "", PLAYER_HTML + + reset_btn.click( + fn=on_reset, + outputs=[ + text_input, + control_instruction, + audio_output, + hidden_payload, + ultimate_clone, + prompt_text, + player_html, + ], + js="() => { if (window.ttsStop) window.ttsStop(); }", + ) + + def on_generate(stream_enabled, text, ctrl_instr, ref_a, ref_url, ult_clone, p_text): + import time as _time + + if not text or not text.strip(): + raise gr.Error("Please enter text to synthesize.") + + # VoxCPM2 uses "(instruction)text" format for control + ctrl = ctrl_instr.strip() if ctrl_instr and not ult_clone else "" + final_text = f"({ctrl}){text.strip()}" if ctrl else text.strip() + + payload: dict = { + "input": final_text, + "voice": "default", + "response_format": "pcm" if stream_enabled else "wav", + "stream": stream_enabled, + } + + # Reference audio for cloning + ref_url_s = ref_url.strip() if ref_url else "" + if ref_url_s: + payload["ref_audio"] = ref_url_s + elif ref_a is not None: + payload["ref_audio"] = _encode_audio(ref_a) + + # Ultimate cloning: prompt_audio + prompt_text for continuation + if ult_clone and p_text and p_text.strip(): + if ref_url_s: + payload["prompt_audio"] = ref_url_s + elif ref_a is not None: + payload["prompt_audio"] = payload.get("ref_audio", "") + payload["prompt_text"] = p_text.strip() + + if stream_enabled: + if ref_a is not None and not ref_url_s: + req_id = f"req-{int(_time.time() * 1000)}" + _pending[req_id] = payload + browser_payload = {"_req_id": req_id, "_nonce": int(_time.time() * 1000)} + return json.dumps(browser_payload), gr.update() + payload["_nonce"] = int(_time.time() * 1000) + return json.dumps(payload), gr.update() + else: + try: + with httpx.Client(timeout=300.0) as client: + resp = client.post( + f"{api_base}/v1/audio/speech", + json=payload, + headers={"Content-Type": "application/json", "Authorization": "Bearer EMPTY"}, + ) + except httpx.ConnectError: + raise gr.Error(f"Cannot connect to server at {api_base}.") + if resp.status_code != 200: + raise gr.Error(f"Server error ({resp.status_code}): {resp.text[:200]}") + audio_np, sr = sf.read(io.BytesIO(resp.content)) + if audio_np.ndim > 1: + audio_np = audio_np[:, 0] + return "", (sr, audio_np.astype(np.float32)) + + generate_btn.click( + fn=on_generate, + inputs=[ + stream_checkbox, + text_input, + control_instruction, + ref_audio, + ref_audio_url, + ultimate_clone, + prompt_text, + ], + outputs=[hidden_payload, audio_output], + ).then( + fn=lambda p: p, + inputs=[hidden_payload], + outputs=[hidden_payload], + js="(p) => { if (p && p.trim()) { const d = JSON.parse(p); delete d._nonce; window.ttsGenerate(d); } return p; }", + ) + + demo.queue() + + return gr.mount_gradio_app(app, demo, path="/", css=css, theme=theme, head=_build_player_js()) + + +def main(): + parser = argparse.ArgumentParser(description="VoxCPM2 streaming Gradio demo") + parser.add_argument("--api-base", default="http://localhost:8000", help="vLLM API server URL") + parser.add_argument("--host", default="0.0.0.0", help="Gradio server host") + parser.add_argument("--port", type=int, default=7860, help="Gradio server port") + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO) + print(f"Connecting to vLLM server at: {args.api_base}") + + import uvicorn + + uvicorn.run(create_app(args.api_base), host=args.host, port=args.port) + + +if __name__ == "__main__": + main() diff --git a/tests/e2e/offline_inference/test_voxcpm2.py b/tests/e2e/offline_inference/test_voxcpm2.py index 4e4f635d5c4..6ec4630a45e 100644 --- a/tests/e2e/offline_inference/test_voxcpm2.py +++ b/tests/e2e/offline_inference/test_voxcpm2.py @@ -33,14 +33,16 @@ def _extract_audio(multimodal_output: dict) -> torch.Tensor: """Extract the final complete audio tensor from multimodal output.""" assert isinstance(multimodal_output, dict), f"Expected dict, got {type(multimodal_output)}" - # Output processor accumulates per-step full audio under "audio". - audio = multimodal_output.get("audio") or multimodal_output.get("model_outputs") + # Output processor accumulates per-step audio chunks under "audio". + audio = multimodal_output.get("audio") + if audio is None: + audio = multimodal_output.get("model_outputs") assert audio is not None, f"No audio key, got {list(multimodal_output.keys())}" if isinstance(audio, list): - valid = [x for x in audio if isinstance(x, torch.Tensor) and x.numel() > 100] + valid = [torch.as_tensor(x).float().cpu().reshape(-1) for x in audio if x is not None] assert valid, "No valid audio tensors in output list" - audio = valid[-1] + audio = torch.cat(valid, dim=0) if len(valid) > 1 else valid[0] assert isinstance(audio, torch.Tensor), f"Expected Tensor, got {type(audio)}" return audio diff --git a/vllm_omni/engine/output_processor.py b/vllm_omni/engine/output_processor.py index 43d02e85b84..badd799fc94 100644 --- a/vllm_omni/engine/output_processor.py +++ b/vllm_omni/engine/output_processor.py @@ -118,9 +118,10 @@ def _consolidate_multimodal_tensors(self) -> None: if isinstance(v, list) and v and isinstance(v[0], torch.Tensor): try: if k == "audio": - # When the audio tensor shape is inconsistent, torch.cat will fail. - # We need to use torch.cat in -1 dimension. - continue + # Concatenate delta audio chunks (1-D) into the full waveform. + # Each entry is a per-step slice; flatten to -1 so chunks with + # inconsistent leading dims can still be joined on the sample axis. + self.mm_accumulated[k] = torch.cat([t.reshape(-1) for t in v], dim=0) elif k == "sr": # Sample rate is a constant scalar, keep last value. self.mm_accumulated[k] = v[-1] diff --git a/vllm_omni/model_executor/models/voxcpm2/minicpm4_paged.py b/vllm_omni/model_executor/models/voxcpm2/minicpm4_paged.py index 7ea5bc229dc..40bacfff6c7 100644 --- a/vllm_omni/model_executor/models/voxcpm2/minicpm4_paged.py +++ b/vllm_omni/model_executor/models/voxcpm2/minicpm4_paged.py @@ -308,31 +308,28 @@ def forward( return hidden_states def compile_selective(self) -> list[str]: - """Compile MLP + o_proj; keep RMSNorm/RoPE eager for precision.""" - compiled: list[str] = [] - for i, layer in enumerate(self.layers): - if i in self._compiled_layers: - continue - try: - layer.mlp = torch.compile( - layer.mlp, - mode="default", - fullgraph=True, - ) - layer.self_attn.o_proj = torch.compile( - layer.self_attn.o_proj, - mode="default", - fullgraph=True, - ) - layer.self_attn._fused_qkv_weight = None - self._compiled_layers.add(i) - if i == 0: - compiled.append(f"layers.*.mlp (×{len(self.layers)})") - compiled.append(f"layers.*.self_attn.o_proj (×{len(self.layers)})") - except Exception as e: - logger.warning("compile_selective: layer %d failed: %s", i, e) - break - return compiled + """Compile the full model forward as one graph. + + Earlier versions compiled ``layer.mlp`` + ``layer.self_attn.o_proj`` + (PR #2690) and then the whole ``layer`` (perf/voxcpm2-streaming-vae). + Both still paid one Dynamo dispatch per layer per decode step. + V3 profiling showed 1,332 per-layer dispatches (~28 layers × ~47 + decode steps) costing ~726 ms of CPU self-time for a long prompt. + + Compiling ``forward`` at the model level lets Dynamo unroll the + 28-layer Python loop inside the graph. Graph breaks at + PagedAttention produce sub-graphs but Dynamo memoises the whole + trace once, so the per-step dispatch drops from 28 to just a few. + """ + if self._compiled_layers: + return [] + # Null the fused-qkv caches so the compile sees the real weight layout. + for layer in self.layers: + layer.self_attn._fused_qkv_weight = None + self.forward = torch.compile(self.forward, mode="default", fullgraph=False) + # Mark every layer as compiled so idempotent callers don't double-wrap. + self._compiled_layers.update(range(len(self.layers))) + return ["forward (whole model)"] def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: """Load weights from native checkpoint (base_lm. prefix pre-stripped).""" @@ -415,22 +412,14 @@ def forward( return hidden_states def compile_selective(self) -> list[str]: - """Compile MLP + o_proj (same as base_lm).""" - compiled: list[str] = [] - for i, layer in enumerate(self.layers): - if i in self._compiled_layers: - continue - try: - layer.mlp = torch.compile(layer.mlp, mode="default", fullgraph=True) - layer.self_attn.o_proj = torch.compile(layer.self_attn.o_proj, mode="default", fullgraph=True) - layer.self_attn._fused_qkv_weight = None - self._compiled_layers.add(i) - if i == 0: - compiled.append(f"layers.*.mlp (×{len(self.layers)})") - compiled.append(f"layers.*.self_attn.o_proj (×{len(self.layers)})") - except Exception as e: - logger.warning("compile_selective: residual layer %d failed: %s", i, e) - return compiled + """Compile the full residual model forward as one graph (same strategy as base_lm).""" + if self._compiled_layers: + return [] + for layer in self.layers: + layer.self_attn._fused_qkv_weight = None + self.forward = torch.compile(self.forward, mode="default", fullgraph=False) + self._compiled_layers.update(range(len(self.layers))) + return ["forward (whole residual)"] def load_weights_from_native(self, native_residual_lm: nn.Module) -> int: """Load weights from native residual_lm. Returns param count.""" diff --git a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py index 0898ca59ae4..94f06589046 100644 --- a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py +++ b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py @@ -11,6 +11,7 @@ from __future__ import annotations import dataclasses +import logging import os import time from collections.abc import Iterable @@ -19,7 +20,6 @@ import librosa import torch import torch.nn as nn -from einops import rearrange from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.models.utils import ( @@ -86,7 +86,11 @@ class _RequestState: curr_prefix_feat_cond: torch.Tensor | None = None last_audio_patch_gpu: torch.Tensor | None = None precomputed_stop_logits: torch.Tensor | None = None - accumulated_patches: list[torch.Tensor] = dataclasses.field(default_factory=list) + # Rolling tail of previously-decoded latents used as VAE receptive-field context. + # Shape (n_pad_frames, feat_dim) on GPU. None before first decode. + decode_pad: torch.Tensor | None = None + # Audio chunks already emitted (CPU float32), concatenated for cumulative output. + audio_chunks: list[torch.Tensor] = dataclasses.field(default_factory=list) decode_step_count: int = 0 request_start_time: float = 0.0 prefill_completed: bool = False @@ -229,11 +233,11 @@ def _optimized_solve_euler( buffers.x_in[b : 2 * b].copy_(x) buffers.mu_in[:b].copy_(mu) buffers.mu_in[b : 2 * b].zero_() - buffers.t_in[:b].fill_(t.item()) - buffers.t_in[b : 2 * b].fill_(t.item()) + # Broadcast the 0-dim GPU scalar directly instead of + # ``.fill_(t.item())`` — ``.item()`` forces a GPU->CPU sync. + buffers.t_in[: 2 * b].copy_(t) if mean_mode: - buffers.dt_in[:b].fill_(dt.item()) - buffers.dt_in[b : 2 * b].fill_(dt.item()) + buffers.dt_in[: 2 * b].copy_(dt) else: buffers.dt_in.zero_() buffers.cond_in[:b].copy_(cond[:b]) @@ -263,9 +267,10 @@ def _optimized_solve_euler( else: buffers.x_in[:b].copy_(x) buffers.mu_in[:b].copy_(mu) - buffers.t_in[:b].fill_(t.item()) + # Broadcast the 0-dim GPU scalar; ``.fill_(t.item())`` would sync. + buffers.t_in[:b].copy_(t) if mean_mode: - buffers.dt_in[:b].fill_(dt.item()) + buffers.dt_in[:b].copy_(dt) else: buffers.dt_in[:b].zero_() buffers.cond_in[:b].copy_(cond[:b]) @@ -320,7 +325,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self._inference_timesteps = 10 self._cfg_value = 2.0 self._cfg_cutoff_ratio = 1.0 - self._vae_decode_interval = 5 + # Number of trailing latent frames to keep as VAE receptive-field context + # for sliding-window streaming decode. 12 matches the nanovllm reference + # implementation and covers the longest VAE decoder receptive field. + self._n_decode_pad_frames = 12 self._enable_torch_compile = True self._compile_vae = True self._max_decode_steps = 2000 @@ -686,7 +694,9 @@ def _finish_prefill(self, state: _RequestState, meta: dict, res_out: torch.Tenso state.request_start_time = time.perf_counter() state.prefill_completed = True - logger.info("PREFILL[%s]: patch norm=%.4f", state.request_id, pred_feat.norm().item()) + if logger.isEnabledFor(logging.DEBUG): + # Only compute the norm (which forces a GPU->CPU sync) if we will log it. + logger.debug("PREFILL[%s]: patch norm=%.4f", state.request_id, pred_feat.norm().item()) self._perf.reset() def _finish_decode(self, state: _RequestState, meta: dict, res_out: torch.Tensor, dev: Any): @@ -720,26 +730,54 @@ def _finish_decode(self, state: _RequestState, meta: dict, res_out: torch.Tensor # -------------------- audio collection -------------------- def _collect_audio(self, state: _RequestState) -> torch.Tensor | None: - patch = state.last_audio_patch_gpu - if patch is not None: - state.last_audio_patch_gpu = None - state.accumulated_patches.append(patch.reshape(1, -1).float()) + """Per-step sliding-window VAE decode (nanovllm pattern). - if not state.accumulated_patches: + Each decode step feeds ``[decode_pad, new_patch]`` through the VAE + and slices out only the audio region corresponding to the new patch. + The pad buffer (last ``_n_decode_pad_frames`` latent frames) provides + the receptive-field context needed by the VAE's transposed convolutions, + eliminating boundary artifacts between chunks. + + Returns the delta audio chunk (not cumulative) so the output processor + can stream each chunk to the client independently. + """ + patch = state.last_audio_patch_gpu + if patch is None: return None + state.last_audio_patch_gpu = None + + # patch shape: (patch_size, feat_dim) or (1, patch_size, feat_dim) + new_latent = patch.reshape(-1, self._feat_dim).to(torch.float32) + n_new = new_latent.shape[0] # = patch_size (typically 4) + + self._perf.start("vae_decode") + + # Build VAE input: [pad_frames | new_latent] + if state.decode_pad is not None: + vae_input = torch.cat([state.decode_pad, new_latent], dim=0) + pad_frames = state.decode_pad.shape[0] + else: + vae_input = new_latent + pad_frames = 0 + + # VAE decode: (1, feat_dim, T_frames) -> (1, 1, T_samples) + feat = vae_input.unsqueeze(0).transpose(1, 2).contiguous() + with torch.no_grad(): + audio = self.tts.audio_vae.decode(feat.to(self._device)).reshape(-1) + + # Slice out only the new audio (after the pad region). + # Each latent frame maps to decoder_chunk_size audio samples. + dcs = int(getattr(self.tts.audio_vae, "decode_chunk_size", audio.numel() // vae_input.shape[0])) + new_audio = audio[pad_frames * dcs : (pad_frames + n_new) * dcs].detach().cpu().float() + + # Roll the pad buffer: keep last N latent frames as context for next step. + all_latents = vae_input # [pad + new] + state.decode_pad = all_latents[-self._n_decode_pad_frames :].detach() - n = len(state.accumulated_patches) - if n <= 1 or n % self._vae_decode_interval == 0 or state.is_stopping: - self._perf.start("vae_decode") - all_p = torch.cat(state.accumulated_patches, dim=0) - state.accumulated_patches = [all_p] - feat = rearrange(all_p.reshape(1, -1, self._feat_dim), "b t d -> b d t") - with torch.no_grad(): - audio = self.tts.audio_vae.decode(feat.to(self._device)).reshape(-1).cpu().float() - self._perf.stop("vae_decode") - state.last_decoded_audio = audio - return audio - return state.last_decoded_audio + state.audio_chunks.append(new_audio) + state.last_decoded_audio = new_audio + self._perf.stop("vae_decode") + return new_audio # -------------------- compute_logits -------------------- @@ -830,7 +868,8 @@ def preprocess( state = self._get_or_create_state(req_id) state.prefill_text = "" - state.accumulated_patches = [] + state.decode_pad = None + state.audio_chunks = [] state.prefill_completed = False state.decode_step_count = 0 state.precomputed_stop_logits = None