diff --git a/examples/offline_inference/voxcpm2/end2end.py b/examples/offline_inference/voxcpm2/end2end.py index 687e596018c..6b6bf78ddf1 100644 --- a/examples/offline_inference/voxcpm2/end2end.py +++ b/examples/offline_inference/voxcpm2/end2end.py @@ -65,6 +65,12 @@ def parse_args(): default=None, help="Text matching --prompt-audio for continuation mode.", ) + parser.add_argument( + "--ref-text", + type=str, + default=None, + help="Optional transcript of --reference-audio (enables ref_continuation mode).", + ) return parser.parse_args() @@ -103,24 +109,40 @@ def main(): stage_configs_path=args.stage_configs_path, ) - additional: dict = {} - if args.reference_audio: - additional["reference_audio"] = args.reference_audio - if args.prompt_audio and args.prompt_text: - additional["prompt_audio"] = args.prompt_audio - additional["prompt_text"] = args.prompt_text + from transformers import AutoTokenizer - prompt: dict = {"prompt": args.text} - if additional: - prompt["additional_information"] = additional + from vllm_omni.model_executor.models.voxcpm2.voxcpm2_talker import ( + build_cjk_split_map, + build_voxcpm2_prompt, + ) + + tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) + split_map = build_cjk_split_map(tokenizer) + hf_config = engine.engine.stage_vllm_configs[0].model_config.hf_config + + ref_audio_arg = args.reference_audio or args.prompt_audio + ref_text_arg = args.ref_text or args.prompt_text + ref_wav, ref_sr = (None, None) + if ref_audio_arg: + ref_wav_arr, ref_sr = sf.read(ref_audio_arg) + ref_wav = ref_wav_arr.mean(axis=-1).tolist() if ref_wav_arr.ndim > 1 else ref_wav_arr.tolist() + + prompt = build_voxcpm2_prompt( + hf_config=hf_config, + tokenizer=tokenizer, + split_map=split_map, + text=args.text, + ref_audio=ref_wav, + ref_sr=ref_sr, + ref_text=ref_text_arg, + ) print(f"Model : {args.model}") print(f"Text : {args.text}") - if args.reference_audio: - print(f"Ref audio : {args.reference_audio}") - if args.prompt_audio: - 
async def _build_voxcpm2_prompt(self, request: OpenAICreateSpeechRequest) -> dict[str, Any]:
    """Assemble the VoxCPM2 prefill prompt for a speech request.

    Delegates to ``build_voxcpm2_prompt`` from the talker module, passing the
    request text, the optional reference audio (resolved through
    ``self._resolve_ref_audio``) and the optional reference transcript.

    Args:
        request: Incoming speech request; ``input``, ``ref_audio`` and
            ``ref_text`` are read from it.

    Returns:
        The prompt dict produced by ``build_voxcpm2_prompt``.
    """
    from vllm_omni.model_executor.models.voxcpm2.voxcpm2_talker import build_voxcpm2_prompt

    # Called only for its side effect: encoding an empty string lazily
    # initialises the tokenizer and the split map used below.
    self._voxcpm2_encode("")

    waveform, sample_rate = None, None
    audio_source = request.ref_audio
    if audio_source is not None:
        waveform, sample_rate = await self._resolve_ref_audio(audio_source)

    builder_kwargs = {
        "hf_config": self.engine_client.model_config.hf_config,
        "tokenizer": self._voxcpm2_tokenizer,
        "split_map": self._voxcpm2_split_map,
        "text": request.input,
        "ref_audio": waveform,
        "ref_sr": sample_rate,
        "ref_text": request.ref_text,
    }
    return build_voxcpm2_prompt(**builder_kwargs)
self._resolve_ref_audio(request.ref_audio) - additional["reference_audio"] = [[wav_list, sr]] - # Pre-split multichar Chinese tokens (VoxCPM2 was trained with single-char CJK IDs). - token_ids = self._voxcpm2_encode(request.input) - prompt: dict[str, Any] = {"prompt_token_ids": token_ids} - if additional: - prompt["additional_information"] = additional elif self._is_tts: validation_error = self._validate_tts_request(request) if validation_error: diff --git a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py index b666e41ebc9..7d54cd17f54 100644 --- a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py +++ b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py @@ -13,6 +13,7 @@ import copy import dataclasses import logging +import math import os import time from collections.abc import Iterable @@ -80,6 +81,44 @@ def split_multichar_chinese(token_ids: list[int], split_map: dict[int, list[int] return result +def build_voxcpm2_prompt( + hf_config: Any, + tokenizer: Any, + split_map: dict[int, list[int]], + text: str, + ref_audio: Any | None = None, + ref_sr: int | None = None, + ref_text: str | None = None, +) -> dict[str, Any]: + """Build a VoxCPM2 prefill prompt whose ``prompt_token_ids`` length matches + the talker-side prefill length. + + Used by both online serving (``serving_speech._build_voxcpm2_prompt``) and + the offline example, so the talker-side length assertion never fires. 
+ """ + ids = split_multichar_chinese(tokenizer.encode(text, add_special_tokens=True), split_map) + bos = tokenizer.bos_token_id + if ids and ids[0] == bos: + ids = ids[1:] + prefill_len = len(ids) + 1 # + audio_start + additional: dict[str, Any] = {"text_token_ids": [ids]} + if ref_audio is not None: + vae = hf_config.audio_vae_config + patch_samples = hf_config.patch_size * math.prod(vae["encoder_rates"]) + ref_len = math.ceil(math.ceil(len(ref_audio) * vae["sample_rate"] / ref_sr) / patch_samples) + if ref_text is not None: + additional["prompt_audio"] = [[ref_audio, ref_sr]] + additional["prompt_text"] = [ref_text] + ref_ids = split_multichar_chinese(tokenizer.encode(ref_text, add_special_tokens=True), split_map) + if ref_ids and ref_ids[0] == bos: + ref_ids = ref_ids[1:] + prefill_len += ref_len + len(ref_ids) + else: + additional["reference_audio"] = [[ref_audio, ref_sr]] + prefill_len += ref_len + 2 # ref_start / ref_end + return {"prompt_token_ids": [1] * prefill_len, "additional_information": additional} + + def _encode_raw_audio( tts: nn.Module, samples: list[float] | torch.Tensor, @@ -793,19 +832,12 @@ def _prepare_residual_prefill(self, state: _RequestState, base_lm_out: torch.Ten tts_len = text_mask.shape[1] scaffold_len = base_lm_out.shape[0] - - if scaffold_len < tts_len: - # Voice clone / continuation: scaffold only processed vllm tokens. - # Pad to match TTS sequence length (extra positions are masked out). - pad = torch.zeros( - tts_len - scaffold_len, - base_lm_out.shape[-1], - device=base_lm_out.device, - dtype=base_lm_out.dtype, - ) - enc_out = torch.cat([base_lm_out, pad], dim=0).unsqueeze(0) - else: - enc_out = base_lm_out.unsqueeze(0) + assert scaffold_len == tts_len, ( + f"voxcpm2 prefill length mismatch: scaffold_len={scaffold_len} tts_len={tts_len}; " + "caller must pad prompt_token_ids to the full prefill length " + "(see serving_speech._build_voxcpm2_prompt or the offline example)." 
+ ) + enc_out = base_lm_out.unsqueeze(0) prefix_feat_cond = ( feat[:, -1, ...] @@ -1063,7 +1095,8 @@ def preprocess( for rid in [r for r, s in self._active_states.items() if r not in pending_ids and s.prefill_completed]: self._cleanup_request(rid) - token_ids = input_ids.tolist() + real = info_dict.get("text_token_ids") + token_ids = input_ids.tolist() if real is None else real[0] # Fail-fast: unsplit multichar Chinese IDs in input_ids means the # serving layer didn't pre-split. Silent fixup here would cause # input_ids/embeds length mismatch (scheduler slot count is fixed).