diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py index 9a012abaacb..c927049639a 100644 --- a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -5,6 +5,7 @@ import os import re import struct +import tempfile import time from pathlib import Path from typing import Any @@ -341,6 +342,71 @@ def _estimate_prompt_len(self, tts_params: dict[str, Any]) -> int: logger.warning("Failed to estimate TTS prompt length, using fallback 2048: %s", e) return 2048 + def _estimate_fish_ref_code_len(self, ref_audio: object) -> int | None: + """Estimate Fish Speech semantic token length from raw reference audio.""" + from vllm_omni.model_executor.models.fish_speech.dac_utils import ( + DAC_HOP_LENGTH, + DAC_SAMPLE_RATE, + ) + + if not isinstance(ref_audio, (list, tuple)) or len(ref_audio) != 2: + return None + wav, sr = ref_audio + sr = int(sr) + n_samples = len(wav) + if sr <= 0 or n_samples <= 0: + return None + resampled_len = max(1, math.ceil(n_samples * DAC_SAMPLE_RATE / sr)) + return max(1, math.ceil(resampled_len / DAC_HOP_LENGTH)) + + def _estimate_fish_prompt_len( + self, + request: OpenAICreateSpeechRequest, + ref_audio: object, + ) -> int: + """Estimate Fish Speech clone prompt length without encoding reference audio.""" + try: + from transformers import AutoTokenizer + + if self._fish_speech_tokenizer is None: + model_name = self.engine_client.model_config.model + self._fish_speech_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + + tokenizer = self._fish_speech_tokenizer + semantic_len = self._estimate_fish_ref_code_len(ref_audio) + if semantic_len is None: + raise ValueError("Failed to estimate Fish Speech semantic token length") + + user_text = f"<|speaker:0|>{request.input}" + user_ids = tokenizer.apply_chat_template( + [{"role": "user", "content": user_text}], + tokenize=True, + add_generation_prompt=True, + ) + voice_token_id = 
tokenizer.encode("<|voice|>", add_special_tokens=False) + audio_start_id = tokenizer.encode("<|audio_start|>", add_special_tokens=False) + audio_end_id = tokenizer.encode("<|audio_end|>", add_special_tokens=False) + prefix_ids = tokenizer.encode(f"<|speaker:0|>{request.ref_text}", add_special_tokens=False) + im_start = tokenizer.encode("<|im_start|>", add_special_tokens=False) + im_end = tokenizer.encode("<|im_end|>", add_special_tokens=False) + system_tag = tokenizer.encode("system\n", add_special_tokens=False) + newline = tokenizer.encode("\n", add_special_tokens=False) + return ( + len(im_start) + + len(system_tag) + + len(prefix_ids) + + len(audio_start_id) + + semantic_len + + len(audio_end_id) + + len(im_end) + + len(newline) + + len(user_ids) + + len(voice_token_id) + ) + except Exception as e: + logger.warning("Failed to estimate Fish Speech prompt length, using fallback 2048: %s", e) + return 2048 + def _get_uploaded_audio_data(self, voice_name: str) -> str | None: """Get base64 encoded audio data for uploaded voice.""" voice_name_lower = voice_name.lower() @@ -900,53 +966,39 @@ def _build_fish_speech_prompt( tokenizer = self._fish_speech_tokenizer model_name = self.engine_client.model_config.model - if ref_audio_data is not None and request.ref_text: - # Voice cloning: encode reference audio and build system message. - from vllm_omni.model_executor.models.fish_speech.dac_encoder import ( - encode_reference_audio, - ) - - wav_samples, sr = ref_audio_data - semantic_token_ids = encode_reference_audio(model_name, wav_samples, sr) - - # Build system message with ref text + audio tokens. 
- audio_start_id = tokenizer.encode("<|audio_start|>", add_special_tokens=False) - audio_end_id = tokenizer.encode("<|audio_end|>", add_special_tokens=False) - - # System content: <|speaker:0|>{ref_text}<|audio_start|>{codes}<|audio_end|> - prefix_text = f"<|speaker:0|>{request.ref_text}" - prefix_ids = tokenizer.encode(prefix_text, add_special_tokens=False) - system_content_ids = prefix_ids + audio_start_id + semantic_token_ids + audio_end_id - - # Manually build system turn: <|im_start|>system\n{content}<|im_end|>\n - im_start = tokenizer.encode("<|im_start|>", add_special_tokens=False) - im_end = tokenizer.encode("<|im_end|>", add_special_tokens=False) - system_tag = tokenizer.encode("system\n", add_special_tokens=False) - newline = tokenizer.encode("\n", add_special_tokens=False) - system_ids = im_start + system_tag + system_content_ids + im_end + newline - - # User turn via chat template. - user_text = f"<|speaker:0|>{request.input}" - user_messages = [{"role": "user", "content": user_text}] - user_ids = tokenizer.apply_chat_template(user_messages, tokenize=True, add_generation_prompt=True) - prompt_ids = system_ids + user_ids - else: + if ref_audio_data is None or not request.ref_text: # No voice cloning: simple user message. user_text = f"<|speaker:0|>{request.input}" messages = [{"role": "user", "content": user_text}] prompt_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True) + voice_token_id = tokenizer.encode("<|voice|>", add_special_tokens=False) + prompt_ids = prompt_ids + voice_token_id - # Append <|voice|> token to signal voice generation. 
- voice_token_id = tokenizer.encode("<|voice|>", add_special_tokens=False) - prompt_ids = prompt_ids + voice_token_id + additional_information: dict[str, Any] = { + "text": [request.input], + "max_new_tokens": [request.max_new_tokens or 4096], + } + return { + "prompt_token_ids": prompt_ids, + "additional_information": additional_information, + } - additional_information: dict[str, Any] = { - "text": [request.input], - "max_new_tokens": [request.max_new_tokens or 4096], + wav_samples, sr = ref_audio_data + ph_len = self._estimate_fish_prompt_len(request, ref_audio_data) + with tempfile.NamedTemporaryFile(prefix="fish_ref_", suffix=".npy", delete=False) as f: + np.save(f, np.asarray(wav_samples, dtype=np.float32)) + ref_audio_path = f.name + + additional_information = { + "text": request.input, + "max_new_tokens": request.max_new_tokens or 4096, + "ref_text": request.ref_text, + "ref_audio_path": ref_audio_path, + "ref_audio_sr": int(sr), + "fish_structured_voice_clone": True, } - return { - "prompt_token_ids": prompt_ids, + "prompt_token_ids": [1] * ph_len, "additional_information": additional_information, } diff --git a/vllm_omni/model_executor/models/fish_speech/dac_encoder.py b/vllm_omni/model_executor/models/fish_speech/dac_encoder.py index 084a863ad37..e89815ab433 100644 --- a/vllm_omni/model_executor/models/fish_speech/dac_encoder.py +++ b/vllm_omni/model_executor/models/fish_speech/dac_encoder.py @@ -1,12 +1,12 @@ """DAC codec encoder for Fish Speech S2 Pro voice cloning. Encodes reference audio into VQ codes for use as prompt conditioning. -Runs on CPU in the API server process -- loaded lazily on first use. 
@lru_cache(maxsize=16)
def _get_resample_kernel(
    source_sr: int,
    target_sr: int,
    device_type: str,
    device_index: int | None,
    dtype_name: str,
):
    """Return a cached torchaudio ``Resample`` module for one configuration.

    The arguments are all hashable primitives so that ``lru_cache`` can key on
    them; at most 16 distinct configurations are kept.

    Args:
        source_sr: Sample rate of the incoming audio.
        target_sr: Sample rate to resample to.
        device_type: Device type string passed to ``torch.device``.
        device_index: Optional device index; ``None`` means no explicit index.
        dtype_name: Name of a ``torch`` dtype attribute, e.g. ``"float32"``.

    Returns:
        A ``torchaudio.transforms.Resample`` moved to the requested device and
        dtype.
    """
    import torchaudio

    if device_index is None:
        target_device = torch.device(device_type)
    else:
        target_device = torch.device(device_type, device_index)
    target_dtype = getattr(torch, dtype_name)

    kernel = torchaudio.transforms.Resample(source_sr, target_sr)
    return kernel.to(device=target_device, dtype=target_dtype)
+ if wav_tensor.shape[0] <= 8 and wav_tensor.shape[1] > wav_tensor.shape[0]: + wav_tensor = wav_tensor.mean(dim=0) + elif wav_tensor.shape[-1] <= 8 and wav_tensor.shape[0] > wav_tensor.shape[-1]: + wav_tensor = wav_tensor.mean(dim=-1) + else: + wav_tensor = wav_tensor.mean(dim=0) + elif wav_tensor.ndim > 2: + wav_tensor = wav_tensor.reshape(-1, wav_tensor.shape[-1]).mean(dim=0) + wav_tensor = wav_tensor.flatten() + + if sample_rate != DAC_SAMPLE_RATE: + resampler = _get_resample_kernel( + int(sample_rate), + DAC_SAMPLE_RATE, + device.type, + device.index, + "float32", + ) + wav_tensor = resampler(wav_tensor.unsqueeze(0)).squeeze(0) # Encode: [1, 1, T] -> codes [1, num_codebooks, num_frames] - wav_tensor = torch.from_numpy(wav).unsqueeze(0).unsqueeze(0).float() - feature_lengths = torch.tensor([wav_tensor.shape[-1]]) + wav_tensor = wav_tensor.unsqueeze(0).unsqueeze(0) + feature_lengths = torch.tensor([wav_tensor.shape[-1]], device=device, dtype=torch.long) codes, feature_lengths_out = codec.encode(wav_tensor, feature_lengths) # Extract semantic codebook (index 0) - shape [num_frames]. 
- semantic_codes = codes[0, 0, :].tolist() + semantic_codes = codes[0, 0, :].to(device="cpu", dtype=torch.long).tolist() # Convert to semantic token IDs: <|semantic:{i}|> = 151678 + i SEMANTIC_TOKEN_OFFSET = 151678 @@ -104,7 +144,7 @@ def encode_reference_audio( logger.info( "Encoded reference audio: %d samples @ %dHz -> %d semantic tokens", - len(wav_samples), + int(wav_tensor.shape[-1]), sample_rate, len(semantic_token_ids), ) diff --git a/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py b/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py index 6145815aac8..5bfe5541f7d 100644 --- a/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py +++ b/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py @@ -14,9 +14,11 @@ import dataclasses import math +import os from collections.abc import Iterable from typing import Any +import numpy as np import torch import torch.nn as nn from transformers import AutoTokenizer @@ -33,6 +35,7 @@ from vllm_omni.model_executor.models.output_templates import OmniOutput from .configuration_fish_speech import FishSpeechConfig, FishSpeechFastARConfig, FishSpeechSlowARConfig +from .dac_encoder import _load_dac_codec, encode_reference_audio from .fish_speech_fast_ar import FishSpeechFastAR logger = init_logger(__name__) @@ -360,7 +363,10 @@ def preprocess( dev = input_ids.device if is_first_prefill: - prompt_embeds = self._build_prefill_embeds(input_ids, info_dict) + if bool(info_dict.get("fish_structured_voice_clone", False)): + prompt_embeds = self._build_structured_voice_clone_prefill_embeds(info_dict) + else: + prompt_embeds = self._build_prefill_embeds(input_ids, info_dict) prompt_embeds_buf = prompt_embeds.detach().to("cpu").contiguous() if not prompt_embeds_buf.is_pinned(): prompt_embeds_buf = prompt_embeds_buf.pin_memory() @@ -507,6 +513,52 @@ def _build_prefill_embeds( result = base_embeds + codebook_sum return result.squeeze(0).to(dtype=torch.bfloat16) + def 
_build_structured_voice_clone_prefill_embeds(self, info_dict: dict[str, Any]) -> torch.Tensor: + tokenizer = self._get_tokenizer() + ref_text = info_dict.get("ref_text") + text = info_dict.get("text") + ref_audio_path = info_dict.get("ref_audio_path") + ref_audio_sr = info_dict.get("ref_audio_sr") + if not isinstance(ref_text, str) or not isinstance(text, str): + raise ValueError("Fish Speech structured voice clone requires string text and ref_text") + if not isinstance(ref_audio_path, str) or not ref_audio_path: + raise ValueError("Fish Speech structured voice clone requires ref_audio_path") + if not isinstance(ref_audio_sr, int): + raise ValueError("Fish Speech structured voice clone requires integer ref_audio_sr") + + ref_audio_wav = np.load(ref_audio_path) + os.remove(ref_audio_path) + + semantic_token_ids = encode_reference_audio( + self.model_path, + ref_audio_wav, + ref_audio_sr, + device=self.codebook_embeddings.weight.device, + ) + audio_start_id = tokenizer.encode("<|audio_start|>", add_special_tokens=False) + audio_end_id = tokenizer.encode("<|audio_end|>", add_special_tokens=False) + prefix_ids = tokenizer.encode(f"<|speaker:0|>{ref_text}", add_special_tokens=False) + im_start = tokenizer.encode("<|im_start|>", add_special_tokens=False) + im_end = tokenizer.encode("<|im_end|>", add_special_tokens=False) + system_tag = tokenizer.encode("system\n", add_special_tokens=False) + newline = tokenizer.encode("\n", add_special_tokens=False) + system_ids = ( + im_start + system_tag + prefix_ids + audio_start_id + semantic_token_ids + audio_end_id + im_end + newline + ) + user_text = f"<|speaker:0|>{text}" + user_ids = tokenizer.apply_chat_template( + [{"role": "user", "content": user_text}], + tokenize=True, + add_generation_prompt=True, + ) + voice_token_id = tokenizer.encode("<|voice|>", add_special_tokens=False) + prompt_ids = torch.tensor( + system_ids + user_ids + voice_token_id, + dtype=torch.long, + device=self.codebook_embeddings.weight.device, + ) + 
return self.embed_input_ids(prompt_ids.unsqueeze(0)).squeeze(0).to(dtype=torch.bfloat16) + # -------------------- GPU-side MTP fast-path -------------------- @torch.inference_mode() @@ -681,4 +733,11 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: except Exception as exc: logger.warning("Fish Speech Fast AR compile warmup failed: %s", exc) + codec_device = self.codebook_embeddings.weight.device + _load_dac_codec( + self.model_path, + device=codec_device, + dtype=torch.float32, + ) + return loaded_params