Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 36 additions & 14 deletions examples/offline_inference/voxcpm2/end2end.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ def parse_args():
default=None,
help="Text matching --prompt-audio for continuation mode.",
)
parser.add_argument(
"--ref-text",
type=str,
default=None,
help="Optional transcript of --reference-audio (enables ref_continuation mode).",
)
return parser.parse_args()


Expand Down Expand Up @@ -103,24 +109,40 @@ def main():
stage_configs_path=args.stage_configs_path,
)

additional: dict = {}
if args.reference_audio:
additional["reference_audio"] = args.reference_audio
if args.prompt_audio and args.prompt_text:
additional["prompt_audio"] = args.prompt_audio
additional["prompt_text"] = args.prompt_text
from transformers import AutoTokenizer

prompt: dict = {"prompt": args.text}
if additional:
prompt["additional_information"] = additional
from vllm_omni.model_executor.models.voxcpm2.voxcpm2_talker import (
build_cjk_split_map,
build_voxcpm2_prompt,
)

tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
split_map = build_cjk_split_map(tokenizer)
hf_config = engine.engine.stage_vllm_configs[0].model_config.hf_config

ref_audio_arg = args.reference_audio or args.prompt_audio
ref_text_arg = args.ref_text or args.prompt_text
ref_wav, ref_sr = (None, None)
if ref_audio_arg:
ref_wav_arr, ref_sr = sf.read(ref_audio_arg)
ref_wav = ref_wav_arr.mean(axis=-1).tolist() if ref_wav_arr.ndim > 1 else ref_wav_arr.tolist()

prompt = build_voxcpm2_prompt(
hf_config=hf_config,
tokenizer=tokenizer,
split_map=split_map,
text=args.text,
ref_audio=ref_wav,
ref_sr=ref_sr,
ref_text=ref_text_arg,
)

print(f"Model : {args.model}")
print(f"Text : {args.text}")
if args.reference_audio:
print(f"Ref audio : {args.reference_audio}")
if args.prompt_audio:
print(f"Prompt audio: {args.prompt_audio}")
print(f"Prompt text : {args.prompt_text}")
if ref_audio_arg:
print(f"Ref audio : {ref_audio_arg}")
if ref_text_arg:
print(f"Ref text : {ref_text_arg}")
print(f"Output dir : {output_dir}")

t_start = time.perf_counter()
Expand Down
29 changes: 20 additions & 9 deletions vllm_omni/entrypoints/openai/serving_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,25 @@ def _estimate_fish_prompt_len(self, text: str, ref_text: str, ref_audio: object)
logger.warning("Failed to estimate Fish Speech prompt length, using fallback 2048: %s", e)
return 2048

async def _build_voxcpm2_prompt(self, request: OpenAICreateSpeechRequest) -> dict[str, Any]:
"""Build prefill prompt for VoxCPM2 TTS (`prompt_token_ids` padded to full prefill length)."""
from vllm_omni.model_executor.models.voxcpm2.voxcpm2_talker import build_voxcpm2_prompt

self._voxcpm2_encode("") # lazy-init tokenizer + split_map
ref_audio = None
ref_sr = None
if request.ref_audio is not None:
ref_audio, ref_sr = await self._resolve_ref_audio(request.ref_audio)
return build_voxcpm2_prompt(
hf_config=self.engine_client.model_config.hf_config,
tokenizer=self._voxcpm2_tokenizer,
split_map=self._voxcpm2_split_map,
text=request.input,
ref_audio=ref_audio,
ref_sr=ref_sr,
ref_text=request.ref_text,
)

def _get_uploaded_audio_data(self, voice_name: str) -> str | None:
"""Get base64 encoded audio data for uploaded voice."""
voice_name_lower = voice_name.lower()
Expand Down Expand Up @@ -1524,16 +1543,8 @@ async def _prepare_speech_generation(
if request.instructions:
prompt["instruct"] = request.instructions
elif self._tts_model_type == "voxcpm2":
prompt = await self._build_voxcpm2_prompt(request)
tts_params = {}
additional: dict[str, Any] = {}
if request.ref_audio is not None:
wav_list, sr = await self._resolve_ref_audio(request.ref_audio)
additional["reference_audio"] = [[wav_list, sr]]
# Pre-split multichar Chinese tokens (VoxCPM2 was trained with single-char CJK IDs).
token_ids = self._voxcpm2_encode(request.input)
prompt: dict[str, Any] = {"prompt_token_ids": token_ids}
if additional:
prompt["additional_information"] = additional
elif self._is_tts:
validation_error = self._validate_tts_request(request)
if validation_error:
Expand Down
61 changes: 47 additions & 14 deletions vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import copy
import dataclasses
import logging
import math
import os
import time
from collections.abc import Iterable
Expand Down Expand Up @@ -80,6 +81,44 @@ def split_multichar_chinese(token_ids: list[int], split_map: dict[int, list[int]
return result


def build_voxcpm2_prompt(
hf_config: Any,
tokenizer: Any,
split_map: dict[int, list[int]],
text: str,
ref_audio: Any | None = None,
ref_sr: int | None = None,
ref_text: str | None = None,
) -> dict[str, Any]:
"""Build a VoxCPM2 prefill prompt whose ``prompt_token_ids`` length matches
the talker-side prefill length.

Used by both online serving (``serving_speech._build_voxcpm2_prompt``) and
the offline example, so the talker-side length assertion never fires.
"""
ids = split_multichar_chinese(tokenizer.encode(text, add_special_tokens=True), split_map)
bos = tokenizer.bos_token_id
if ids and ids[0] == bos:
ids = ids[1:]
prefill_len = len(ids) + 1 # + audio_start
additional: dict[str, Any] = {"text_token_ids": [ids]}
if ref_audio is not None:
vae = hf_config.audio_vae_config
patch_samples = hf_config.patch_size * math.prod(vae["encoder_rates"])
ref_len = math.ceil(math.ceil(len(ref_audio) * vae["sample_rate"] / ref_sr) / patch_samples)
if ref_text is not None:
additional["prompt_audio"] = [[ref_audio, ref_sr]]
additional["prompt_text"] = [ref_text]
ref_ids = split_multichar_chinese(tokenizer.encode(ref_text, add_special_tokens=True), split_map)
if ref_ids and ref_ids[0] == bos:
ref_ids = ref_ids[1:]
prefill_len += ref_len + len(ref_ids)
else:
additional["reference_audio"] = [[ref_audio, ref_sr]]
prefill_len += ref_len + 2 # ref_start / ref_end
return {"prompt_token_ids": [1] * prefill_len, "additional_information": additional}


def _encode_raw_audio(
tts: nn.Module,
samples: list[float] | torch.Tensor,
Expand Down Expand Up @@ -793,19 +832,12 @@ def _prepare_residual_prefill(self, state: _RequestState, base_lm_out: torch.Ten

tts_len = text_mask.shape[1]
scaffold_len = base_lm_out.shape[0]

if scaffold_len < tts_len:
# Voice clone / continuation: scaffold only processed vllm tokens.
# Pad to match TTS sequence length (extra positions are masked out).
pad = torch.zeros(
tts_len - scaffold_len,
base_lm_out.shape[-1],
device=base_lm_out.device,
dtype=base_lm_out.dtype,
)
enc_out = torch.cat([base_lm_out, pad], dim=0).unsqueeze(0)
else:
enc_out = base_lm_out.unsqueeze(0)
assert scaffold_len == tts_len, (
Comment thread
Sy0307 marked this conversation as resolved.
f"voxcpm2 prefill length mismatch: scaffold_len={scaffold_len} tts_len={tts_len}; "
"caller must pad prompt_token_ids to the full prefill length "
"(see serving_speech._build_voxcpm2_prompt or the offline example)."
)
enc_out = base_lm_out.unsqueeze(0)

prefix_feat_cond = (
feat[:, -1, ...]
Expand Down Expand Up @@ -1063,7 +1095,8 @@ def preprocess(
for rid in [r for r, s in self._active_states.items() if r not in pending_ids and s.prefill_completed]:
self._cleanup_request(rid)

token_ids = input_ids.tolist()
real = info_dict.get("text_token_ids")
token_ids = input_ids.tolist() if real is None else real[0]
# Fail-fast: unsplit multichar Chinese IDs in input_ids means the
# serving layer didn't pre-split. Silent fixup here would cause
# input_ids/embeds length mismatch (scheduler slot count is fixed).
Expand Down
Loading