From 593641c66aa8c7d55e96d885bcc093e2e9502c00 Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Sat, 18 Apr 2026 04:20:15 +0800 Subject: [PATCH 1/2] [Bugfix][VoxCPM2] Fix voice-clone decode loop by padding prefill prompt MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The talker's prefill embedding includes ref_audio / prompt_audio regions on top of the target text, but vLLM only forwarded ``len(prompt_token_ids)`` slots through base_lm. The remaining ref/prompt positions were zero-padded, so lm_hidden at the audio_start position read zero and stop_head never fired — decode ran to MAX_DECODE_STEPS=2000 and the response was ~320s of noise. Serving now pads ``prompt_token_ids`` to the full prefill length (text + audio_start + ref/prompt region) using AudioVAE parameters read from hf_config, and ships the real text tokens through ``additional_information['text_token_ids']``. ``ref_audio + ref_text`` is routed to native continuation mode; ``ref_audio`` alone keeps reference-only mode. To prevent silent regressions from layout drift, the talker now asserts ``scaffold_len == tts_len`` at prefill entry — any mismatch crashes immediately instead of degrading to noisy audio. Signed-off-by: Sy03 <1370724210@qq.com> --- .../entrypoints/openai/serving_speech.py | 35 ++++++++++++++----- .../models/voxcpm2/voxcpm2_talker.py | 7 +++- 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py index 3eaf18111c0..aabdd6879da 100644 --- a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -457,6 +457,31 @@ def _estimate_fish_prompt_len(self, text: str, ref_text: str, ref_audio: object) logger.warning("Failed to estimate Fish Speech prompt length, using fallback 2048: %s", e) return 2048 + async def _build_voxcpm2_prompt(self, request: OpenAICreateSpeechRequest) -> dict[str, Any]: + """Build prefill prompt for VoxCPM2 TTS (`prompt_token_ids` padded to full prefill length).""" + token_ids = self._voxcpm2_encode(request.input) + bos = self._voxcpm2_tokenizer.bos_token_id + if token_ids and token_ids[0] == bos: + token_ids = token_ids[1:] + prefill_len = len(token_ids) + 1 # + audio_start + additional: dict[str, Any] = {"text_token_ids": [token_ids]} + if request.ref_audio is not None: + wav_list, sr = await self._resolve_ref_audio(request.ref_audio) + hf_cfg = self.engine_client.model_config.hf_config + vae = hf_cfg.audio_vae_config + patch_samples = hf_cfg.patch_size * math.prod(vae["encoder_rates"]) + ref_len = math.ceil(math.ceil(len(wav_list) * vae["sample_rate"] / sr) / patch_samples) + if request.ref_text is not None: + additional["prompt_audio"] = [[wav_list, sr]] + additional["prompt_text"] = [request.ref_text] + prefill_len += ref_len + len( + self._voxcpm2_tokenizer.encode(request.ref_text, add_special_tokens=False) + ) + else: + additional["reference_audio"] = [[wav_list, sr]] + prefill_len += ref_len + 2 # ref_start / ref_end + return {"prompt_token_ids": [1] * prefill_len, "additional_information": additional} + def _get_uploaded_audio_data(self, voice_name: str) -> str | None: """Get base64 encoded audio data for uploaded voice.""" voice_name_lower = voice_name.lower() @@ -1524,16 +1549,8 @@ async def _prepare_speech_generation( if request.instructions: prompt["instruct"] = request.instructions elif self._tts_model_type == "voxcpm2": + prompt = await self._build_voxcpm2_prompt(request) tts_params = {} - additional: dict[str, Any] = {} - if request.ref_audio is not None: - wav_list, sr = await self._resolve_ref_audio(request.ref_audio) - additional["reference_audio"] = [[wav_list, sr]] - # Pre-split multichar Chinese tokens (VoxCPM2 was trained with single-char CJK IDs). - token_ids = self._voxcpm2_encode(request.input) - prompt: dict[str, Any] = {"prompt_token_ids": token_ids} - if additional: - prompt["additional_information"] = additional elif self._is_tts: validation_error = self._validate_tts_request(request) if validation_error: diff --git a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py index b666e41ebc9..8a653adfba3 100644 --- a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py +++ b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py @@ -793,6 +793,10 @@ def _prepare_residual_prefill(self, state: _RequestState, base_lm_out: torch.Ten tts_len = text_mask.shape[1] scaffold_len = base_lm_out.shape[0] + assert scaffold_len == tts_len, ( + f"voxcpm2 prefill length mismatch: scaffold_len={scaffold_len} tts_len={tts_len}; " + "serving layer must pad prompt_token_ids to the full prefill length." + ) if scaffold_len < tts_len: # Voice clone / continuation: scaffold only processed vllm tokens. @@ -1063,7 +1067,8 @@ def preprocess( for rid in [r for r, s in self._active_states.items() if r not in pending_ids and s.prefill_completed]: self._cleanup_request(rid) - token_ids = input_ids.tolist() + real = info_dict.get("text_token_ids") + token_ids = real[0] if real else input_ids.tolist() # Fail-fast: unsplit multichar Chinese IDs in input_ids means the # serving layer didn't pre-split. Silent fixup here would cause # input_ids/embeds length mismatch (scheduler slot count is fixed). From 2e84a0799816fbf1e5c1a933a327956c779515f3 Mon Sep 17 00:00:00 2001 From: Sy03 <1370724210@qq.com> Date: Sat, 18 Apr 2026 16:25:09 +0800 Subject: [PATCH 2/2] [Bugfix][VoxCPM2] Address review feedback: shared prompt helper + strict assert MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extract ``build_voxcpm2_prompt`` into ``voxcpm2_talker.py`` so online serving and the offline ``end2end.py`` share one tokenizer/CJK-split code path. This removes the length prediction drift between ``_voxcpm2_tokenizer.encode`` (used by serving) and ``tts.text_tokenizer(prompt_text)`` (used by the talker) that a future CJK-range change would have exposed via the prefill-length assertion. - ``serving_speech._build_voxcpm2_prompt`` delegates to the shared helper. - ``examples/offline_inference/voxcpm2/end2end.py`` builds the same padded ``prompt_token_ids`` via the helper, so ``--reference-audio`` and the new ``--ref-text`` flag no longer trip the talker assert. - ``voxcpm2_talker._prepare_residual_prefill`` keeps a strict ``scaffold_len == tts_len`` assert with no fallback pad: zero-padding ``base_lm_out`` turned ``lm_hidden`` at audio_start into zeros and caused the original voice-clone decode loop, so a hard failure is correct. - ``voxcpm2_talker.preprocess`` tightens ``token_ids = real[0] if real else …`` to ``real is None`` so an explicit empty ``text_token_ids`` list surfaces a bug instead of silently using ``input_ids``. Verified on H20 against openbmb/VoxCPM2 (online + offline, zero-shot / ref-only / ref+ref_text): all 6 cases decode within 1–13 s with no MAX_DECODE_STEPS warning, no prefill-length assertion, and Whisper ASR matches the target text. Signed-off-by: Sy03 <1370724210@qq.com> --- examples/offline_inference/voxcpm2/end2end.py | 50 +++++++++++----- .../entrypoints/openai/serving_speech.py | 36 +++++------- .../models/voxcpm2/voxcpm2_talker.py | 58 ++++++++++++++----- 3 files changed, 94 insertions(+), 50 deletions(-) diff --git a/examples/offline_inference/voxcpm2/end2end.py b/examples/offline_inference/voxcpm2/end2end.py index 687e596018c..6b6bf78ddf1 100644 --- a/examples/offline_inference/voxcpm2/end2end.py +++ b/examples/offline_inference/voxcpm2/end2end.py @@ -65,6 +65,12 @@ def parse_args(): default=None, help="Text matching --prompt-audio for continuation mode.", ) + parser.add_argument( + "--ref-text", + type=str, + default=None, + help="Optional transcript of --reference-audio (enables ref_continuation mode).", + ) return parser.parse_args() @@ -103,24 +109,40 @@ def main(): stage_configs_path=args.stage_configs_path, ) - additional: dict = {} - if args.reference_audio: - additional["reference_audio"] = args.reference_audio - if args.prompt_audio and args.prompt_text: - additional["prompt_audio"] = args.prompt_audio - additional["prompt_text"] = args.prompt_text + from transformers import AutoTokenizer - prompt: dict = {"prompt": args.text} - if additional: - prompt["additional_information"] = additional + from vllm_omni.model_executor.models.voxcpm2.voxcpm2_talker import ( + build_cjk_split_map, + build_voxcpm2_prompt, + ) + + tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) + split_map = build_cjk_split_map(tokenizer) + hf_config = engine.engine.stage_vllm_configs[0].model_config.hf_config + + ref_audio_arg = args.reference_audio or args.prompt_audio + ref_text_arg = args.ref_text or args.prompt_text + ref_wav, ref_sr = (None, None) + if ref_audio_arg: + ref_wav_arr, ref_sr = sf.read(ref_audio_arg) + ref_wav = ref_wav_arr.mean(axis=-1).tolist() if ref_wav_arr.ndim > 1 else ref_wav_arr.tolist() + + prompt = build_voxcpm2_prompt( + hf_config=hf_config, + tokenizer=tokenizer, + split_map=split_map, + text=args.text, + ref_audio=ref_wav, + ref_sr=ref_sr, + ref_text=ref_text_arg, + ) print(f"Model : {args.model}") print(f"Text : {args.text}") - if args.reference_audio: - print(f"Ref audio : {args.reference_audio}") - if args.prompt_audio: - print(f"Prompt audio: {args.prompt_audio}") - print(f"Prompt text : {args.prompt_text}") + if ref_audio_arg: + print(f"Ref audio : {ref_audio_arg}") + if ref_text_arg: + print(f"Ref text : {ref_text_arg}") print(f"Output dir : {output_dir}") t_start = time.perf_counter() diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py index aabdd6879da..ba8292f0c27 100644 --- a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -459,28 +459,22 @@ def _estimate_fish_prompt_len(self, text: str, ref_text: str, ref_audio: object) async def _build_voxcpm2_prompt(self, request: OpenAICreateSpeechRequest) -> dict[str, Any]: """Build prefill prompt for VoxCPM2 TTS (`prompt_token_ids` padded to full prefill length).""" - token_ids = self._voxcpm2_encode(request.input) - bos = self._voxcpm2_tokenizer.bos_token_id - if token_ids and token_ids[0] == bos: - token_ids = token_ids[1:] - prefill_len = len(token_ids) + 1 # + audio_start - additional: dict[str, Any] = {"text_token_ids": [token_ids]} + from vllm_omni.model_executor.models.voxcpm2.voxcpm2_talker import build_voxcpm2_prompt + + self._voxcpm2_encode("") # lazy-init tokenizer + split_map + ref_audio = None + ref_sr = None if request.ref_audio is not None: - wav_list, sr = await self._resolve_ref_audio(request.ref_audio) - hf_cfg = self.engine_client.model_config.hf_config - vae = hf_cfg.audio_vae_config - patch_samples = hf_cfg.patch_size * math.prod(vae["encoder_rates"]) - ref_len = math.ceil(math.ceil(len(wav_list) * vae["sample_rate"] / sr) / patch_samples) - if request.ref_text is not None: - additional["prompt_audio"] = [[wav_list, sr]] - additional["prompt_text"] = [request.ref_text] - prefill_len += ref_len + len( - self._voxcpm2_tokenizer.encode(request.ref_text, add_special_tokens=False) - ) - else: - additional["reference_audio"] = [[wav_list, sr]] - prefill_len += ref_len + 2 # ref_start / ref_end - return {"prompt_token_ids": [1] * prefill_len, "additional_information": additional} + ref_audio, ref_sr = await self._resolve_ref_audio(request.ref_audio) + return build_voxcpm2_prompt( + hf_config=self.engine_client.model_config.hf_config, + tokenizer=self._voxcpm2_tokenizer, + split_map=self._voxcpm2_split_map, + text=request.input, + ref_audio=ref_audio, + ref_sr=ref_sr, + ref_text=request.ref_text, + ) def _get_uploaded_audio_data(self, voice_name: str) -> str | None: """Get base64 encoded audio data for uploaded voice.""" diff --git a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py index 8a653adfba3..7d54cd17f54 100644 --- a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py +++ b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py @@ -13,6 +13,7 @@ import copy import dataclasses import logging +import math import os import time from collections.abc import Iterable @@ -80,6 +81,44 @@ def split_multichar_chinese(token_ids: list[int], split_map: dict[int, list[int] return result +def build_voxcpm2_prompt( + hf_config: Any, + tokenizer: Any, + split_map: dict[int, list[int]], + text: str, + ref_audio: Any | None = None, + ref_sr: int | None = None, + ref_text: str | None = None, +) -> dict[str, Any]: + """Build a VoxCPM2 prefill prompt whose ``prompt_token_ids`` length matches + the talker-side prefill length. + + Used by both online serving (``serving_speech._build_voxcpm2_prompt``) and + the offline example, so the talker-side length assertion never fires. + """ + ids = split_multichar_chinese(tokenizer.encode(text, add_special_tokens=True), split_map) + bos = tokenizer.bos_token_id + if ids and ids[0] == bos: + ids = ids[1:] + prefill_len = len(ids) + 1 # + audio_start + additional: dict[str, Any] = {"text_token_ids": [ids]} + if ref_audio is not None: + vae = hf_config.audio_vae_config + patch_samples = hf_config.patch_size * math.prod(vae["encoder_rates"]) + ref_len = math.ceil(math.ceil(len(ref_audio) * vae["sample_rate"] / ref_sr) / patch_samples) + if ref_text is not None: + additional["prompt_audio"] = [[ref_audio, ref_sr]] + additional["prompt_text"] = [ref_text] + ref_ids = split_multichar_chinese(tokenizer.encode(ref_text, add_special_tokens=True), split_map) + if ref_ids and ref_ids[0] == bos: + ref_ids = ref_ids[1:] + prefill_len += ref_len + len(ref_ids) + else: + additional["reference_audio"] = [[ref_audio, ref_sr]] + prefill_len += ref_len + 2 # ref_start / ref_end + return {"prompt_token_ids": [1] * prefill_len, "additional_information": additional} + + def _encode_raw_audio( tts: nn.Module, samples: list[float] | torch.Tensor, @@ -795,21 +834,10 @@ def _prepare_residual_prefill(self, state: _RequestState, base_lm_out: torch.Ten scaffold_len = base_lm_out.shape[0] assert scaffold_len == tts_len, ( f"voxcpm2 prefill length mismatch: scaffold_len={scaffold_len} tts_len={tts_len}; " - "serving layer must pad prompt_token_ids to the full prefill length." + "caller must pad prompt_token_ids to the full prefill length " + "(see serving_speech._build_voxcpm2_prompt or the offline example)." ) - - if scaffold_len < tts_len: - # Voice clone / continuation: scaffold only processed vllm tokens. - # Pad to match TTS sequence length (extra positions are masked out). - pad = torch.zeros( - tts_len - scaffold_len, - base_lm_out.shape[-1], - device=base_lm_out.device, - dtype=base_lm_out.dtype, - ) - enc_out = torch.cat([base_lm_out, pad], dim=0).unsqueeze(0) - else: - enc_out = base_lm_out.unsqueeze(0) + enc_out = base_lm_out.unsqueeze(0) prefix_feat_cond = ( feat[:, -1, ...] @@ -1068,7 +1096,7 @@ def preprocess( self._cleanup_request(rid) real = info_dict.get("text_token_ids") - token_ids = real[0] if real else input_ids.tolist() + token_ids = input_ids.tolist() if real is None else real[0] # Fail-fast: unsplit multichar Chinese IDs in input_ids means the # serving layer didn't pre-split. Silent fixup here would cause # input_ids/embeds length mismatch (scheduler slot count is fixed).