Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 91 additions & 39 deletions vllm_omni/entrypoints/openai/serving_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import re
import struct
import tempfile
import time
from pathlib import Path
from typing import Any
Expand Down Expand Up @@ -341,6 +342,71 @@ def _estimate_prompt_len(self, tts_params: dict[str, Any]) -> int:
logger.warning("Failed to estimate TTS prompt length, using fallback 2048: %s", e)
return 2048

def _estimate_fish_ref_code_len(self, ref_audio: object) -> int | None:
"""Estimate Fish Speech semantic token length from raw reference audio."""
from vllm_omni.model_executor.models.fish_speech.dac_utils import (
DAC_HOP_LENGTH,
DAC_SAMPLE_RATE,
)

if not isinstance(ref_audio, (list, tuple)) or len(ref_audio) != 2:
return None
wav, sr = ref_audio
sr = int(sr)
n_samples = len(wav)
if sr <= 0 or n_samples <= 0:
return None
resampled_len = max(1, math.ceil(n_samples * DAC_SAMPLE_RATE / sr))
return max(1, math.ceil(resampled_len / DAC_HOP_LENGTH))

def _estimate_fish_prompt_len(
self,
request: OpenAICreateSpeechRequest,
ref_audio: object,
) -> int:
"""Estimate Fish Speech clone prompt length without encoding reference audio."""
try:
from transformers import AutoTokenizer

if self._fish_speech_tokenizer is None:
model_name = self.engine_client.model_config.model
self._fish_speech_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

tokenizer = self._fish_speech_tokenizer
semantic_len = self._estimate_fish_ref_code_len(ref_audio)
if semantic_len is None:
raise ValueError("Failed to estimate Fish Speech semantic token length")

user_text = f"<|speaker:0|>{request.input}"
user_ids = tokenizer.apply_chat_template(
[{"role": "user", "content": user_text}],
tokenize=True,
add_generation_prompt=True,
)
voice_token_id = tokenizer.encode("<|voice|>", add_special_tokens=False)
audio_start_id = tokenizer.encode("<|audio_start|>", add_special_tokens=False)
audio_end_id = tokenizer.encode("<|audio_end|>", add_special_tokens=False)
prefix_ids = tokenizer.encode(f"<|speaker:0|>{request.ref_text}", add_special_tokens=False)
im_start = tokenizer.encode("<|im_start|>", add_special_tokens=False)
im_end = tokenizer.encode("<|im_end|>", add_special_tokens=False)
system_tag = tokenizer.encode("system\n", add_special_tokens=False)
newline = tokenizer.encode("\n", add_special_tokens=False)
return (
len(im_start)
+ len(system_tag)
+ len(prefix_ids)
+ len(audio_start_id)
+ semantic_len
+ len(audio_end_id)
+ len(im_end)
+ len(newline)
+ len(user_ids)
+ len(voice_token_id)
)
except Exception as e:
logger.warning("Failed to estimate Fish Speech prompt length, using fallback 2048: %s", e)
return 2048

def _get_uploaded_audio_data(self, voice_name: str) -> str | None:
"""Get base64 encoded audio data for uploaded voice."""
voice_name_lower = voice_name.lower()
Expand Down Expand Up @@ -900,53 +966,39 @@ def _build_fish_speech_prompt(
tokenizer = self._fish_speech_tokenizer
model_name = self.engine_client.model_config.model

if ref_audio_data is not None and request.ref_text:
# Voice cloning: encode reference audio and build system message.
from vllm_omni.model_executor.models.fish_speech.dac_encoder import (
encode_reference_audio,
)

wav_samples, sr = ref_audio_data
semantic_token_ids = encode_reference_audio(model_name, wav_samples, sr)

# Build system message with ref text + audio tokens.
audio_start_id = tokenizer.encode("<|audio_start|>", add_special_tokens=False)
audio_end_id = tokenizer.encode("<|audio_end|>", add_special_tokens=False)

# System content: <|speaker:0|>{ref_text}<|audio_start|>{codes}<|audio_end|>
prefix_text = f"<|speaker:0|>{request.ref_text}"
prefix_ids = tokenizer.encode(prefix_text, add_special_tokens=False)
system_content_ids = prefix_ids + audio_start_id + semantic_token_ids + audio_end_id

# Manually build system turn: <|im_start|>system\n{content}<|im_end|>\n
im_start = tokenizer.encode("<|im_start|>", add_special_tokens=False)
im_end = tokenizer.encode("<|im_end|>", add_special_tokens=False)
system_tag = tokenizer.encode("system\n", add_special_tokens=False)
newline = tokenizer.encode("\n", add_special_tokens=False)
system_ids = im_start + system_tag + system_content_ids + im_end + newline

# User turn via chat template.
user_text = f"<|speaker:0|>{request.input}"
user_messages = [{"role": "user", "content": user_text}]
user_ids = tokenizer.apply_chat_template(user_messages, tokenize=True, add_generation_prompt=True)
prompt_ids = system_ids + user_ids
else:
if ref_audio_data is None or not request.ref_text:
# No voice cloning: simple user message.
user_text = f"<|speaker:0|>{request.input}"
messages = [{"role": "user", "content": user_text}]
prompt_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)
voice_token_id = tokenizer.encode("<|voice|>", add_special_tokens=False)
prompt_ids = prompt_ids + voice_token_id

# Append <|voice|> token to signal voice generation.
voice_token_id = tokenizer.encode("<|voice|>", add_special_tokens=False)
prompt_ids = prompt_ids + voice_token_id
additional_information: dict[str, Any] = {
"text": [request.input],
"max_new_tokens": [request.max_new_tokens or 4096],
}
return {
"prompt_token_ids": prompt_ids,
"additional_information": additional_information,
}

additional_information: dict[str, Any] = {
"text": [request.input],
"max_new_tokens": [request.max_new_tokens or 4096],
wav_samples, sr = ref_audio_data
ph_len = self._estimate_fish_prompt_len(request, ref_audio_data)
with tempfile.NamedTemporaryFile(prefix="fish_ref_", suffix=".npy", delete=False) as f:
np.save(f, np.asarray(wav_samples, dtype=np.float32))
ref_audio_path = f.name
Comment on lines +988 to +990
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Don't serialize clone audio as a local temp path

This path writes reference audio to a node-local temp file and passes only the filename through additional_information; the serialization layer transports that as a scalar string, not file contents, so deployments where the slow-AR worker is not on the same filesystem (e.g., disaggregated/non-mp executors) cannot open the file in np.load(ref_audio_path) and voice cloning fails.

Useful? React with 👍 / 👎.


additional_information = {
"text": request.input,
"max_new_tokens": request.max_new_tokens or 4096,
"ref_text": request.ref_text,
"ref_audio_path": ref_audio_path,
"ref_audio_sr": int(sr),
"fish_structured_voice_clone": True,
}

return {
"prompt_token_ids": prompt_ids,
"prompt_token_ids": [1] * ph_len,
"additional_information": additional_information,
}

Expand Down
96 changes: 68 additions & 28 deletions vllm_omni/model_executor/models/fish_speech/dac_encoder.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""DAC codec encoder for Fish Speech S2 Pro voice cloning.

Encodes reference audio into VQ codes for use as prompt conditioning.
Runs on CPU in the API server process -- loaded lazily on first use.
"""

from __future__ import annotations

import os
from functools import lru_cache

import numpy as np
import torch
Expand All @@ -20,13 +20,20 @@

logger = init_logger(__name__)

_codec_cache: dict[str, nn.Module] = {}
_codec_cache: dict[tuple[str, str, str], nn.Module] = {}


def _load_dac_codec(model_path: str) -> nn.Module:
"""Load the DAC codec model from codec.pth (cached per model_path)."""
if model_path in _codec_cache:
return _codec_cache[model_path]
def _load_dac_codec(
model_path: str,
*,
device: torch.device | str = "cpu",
dtype: torch.dtype = torch.float32,
) -> nn.Module:
"""Load the DAC codec model from codec.pth."""
device = torch.device(device)
cache_key = (model_path, str(device), str(dtype))
if cache_key in _codec_cache:
return _codec_cache[cache_key]

codec_path = os.path.join(model_path, "codec.pth")
if not os.path.exists(codec_path):
Expand All @@ -47,29 +54,36 @@ def _load_dac_codec(model_path: str) -> nn.Module:
if "generator" in state_dict:
state_dict = state_dict["generator"]
codec.load_state_dict(state_dict, strict=False)
codec = codec.to(device=device, dtype=dtype)
codec.eval()

_codec_cache[model_path] = codec
logger.info("Loaded DAC codec encoder from %s (CPU)", codec_path)
_codec_cache[cache_key] = codec
logger.info("Loaded DAC codec encoder from %s (%s, dtype=%s)", codec_path, device, dtype)
return codec


def _resample(wav: np.ndarray, sr: int, target_sr: int) -> np.ndarray:
"""Resample audio using torchaudio's polyphase resampling."""
if sr == target_sr:
return wav
@lru_cache(maxsize=16)
def _get_resample_kernel(
source_sr: int,
target_sr: int,
device_type: str,
device_index: int | None,
dtype_name: str,
):
import torchaudio

wav_t = torch.from_numpy(wav).unsqueeze(0).float()
wav_t = torchaudio.functional.resample(wav_t, sr, target_sr)
return wav_t.squeeze(0).numpy()
device = torch.device(device_type, device_index) if device_index is not None else torch.device(device_type)
dtype = getattr(torch, dtype_name)
return torchaudio.transforms.Resample(source_sr, target_sr).to(device=device, dtype=dtype)


@torch.no_grad()
def encode_reference_audio(
model_path: str,
wav_samples: list[float] | np.ndarray,
wav_samples: list[float] | np.ndarray | torch.Tensor,
sample_rate: int,
*,
device: torch.device | str | None = None,
) -> list[int]:
"""Encode reference audio into semantic token IDs for prompt conditioning.

Expand All @@ -81,30 +95,56 @@ def encode_reference_audio(
Returns:
List of semantic token IDs (151678 + code_value for each frame).
"""
codec = _load_dac_codec(model_path)

wav = np.asarray(wav_samples, dtype=np.float32)
if wav.ndim > 1:
wav = np.mean(wav, axis=-1)

# Resample to DAC sample rate (44100).
wav = _resample(wav, sample_rate, DAC_SAMPLE_RATE)
if device is None:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
else:
device = torch.device(device)
dtype = torch.float32
codec = _load_dac_codec(model_path, device=device, dtype=dtype)

if isinstance(wav_samples, torch.Tensor):
wav_tensor = wav_samples.detach()
else:
wav_tensor = torch.as_tensor(wav_samples)

wav_tensor = wav_tensor.to(device=device, dtype=dtype)
if wav_tensor.ndim == 2:
# Accept both [channels, samples] and [samples, channels] layouts.
if wav_tensor.shape[0] <= 8 and wav_tensor.shape[1] > wav_tensor.shape[0]:
wav_tensor = wav_tensor.mean(dim=0)
elif wav_tensor.shape[-1] <= 8 and wav_tensor.shape[0] > wav_tensor.shape[-1]:
wav_tensor = wav_tensor.mean(dim=-1)
else:
wav_tensor = wav_tensor.mean(dim=0)
elif wav_tensor.ndim > 2:
wav_tensor = wav_tensor.reshape(-1, wav_tensor.shape[-1]).mean(dim=0)
wav_tensor = wav_tensor.flatten()

if sample_rate != DAC_SAMPLE_RATE:
resampler = _get_resample_kernel(
int(sample_rate),
DAC_SAMPLE_RATE,
device.type,
device.index,
"float32",
)
wav_tensor = resampler(wav_tensor.unsqueeze(0)).squeeze(0)

# Encode: [1, 1, T] -> codes [1, num_codebooks, num_frames]
wav_tensor = torch.from_numpy(wav).unsqueeze(0).unsqueeze(0).float()
feature_lengths = torch.tensor([wav_tensor.shape[-1]])
wav_tensor = wav_tensor.unsqueeze(0).unsqueeze(0)
feature_lengths = torch.tensor([wav_tensor.shape[-1]], device=device, dtype=torch.long)
codes, feature_lengths_out = codec.encode(wav_tensor, feature_lengths)

# Extract semantic codebook (index 0) - shape [num_frames].
semantic_codes = codes[0, 0, :].tolist()
semantic_codes = codes[0, 0, :].to(device="cpu", dtype=torch.long).tolist()

# Convert to semantic token IDs: <|semantic:{i}|> = 151678 + i
SEMANTIC_TOKEN_OFFSET = 151678
semantic_token_ids = [SEMANTIC_TOKEN_OFFSET + int(c) for c in semantic_codes]

logger.info(
"Encoded reference audio: %d samples @ %dHz -> %d semantic tokens",
len(wav_samples),
int(wav_tensor.shape[-1]),
sample_rate,
len(semantic_token_ids),
)
Expand Down
Loading
Loading