Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 79 additions & 30 deletions vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os
from collections import OrderedDict
from collections.abc import Callable, Iterable, Mapping
from functools import lru_cache
from typing import Any
from urllib.parse import urlparse

Expand Down Expand Up @@ -258,6 +259,16 @@ def _dynamic_range_compression(x, c=1, clip_val=1e-5):
return torch.log(torch.clamp(x, min=clip_val) * c)


@lru_cache(maxsize=8)
def _cached_mel_filter_bank(sampling_rate: int, n_fft: int, n_mels: int, fmin: int, fmax: int | None) -> torch.Tensor:
    """Return the mel filter bank for the given parameters, memoized.

    Thin wrapper around ``mel_filter_bank`` so that repeated calls with the
    same arguments reuse the previously built result instead of recomputing
    it. Up to 8 distinct argument combinations are retained (LRU eviction).

    Args:
        sampling_rate: Forwarded to ``mel_filter_bank`` as ``sr``.
        n_fft: Forwarded as ``n_fft``.
        n_mels: Forwarded as ``n_mels``.
        fmin: Forwarded as ``fmin``.
        fmax: Forwarded as ``fmax``; may be ``None``.

    Returns:
        Whatever ``mel_filter_bank`` returns for these arguments.

    NOTE: the returned object is shared between every caller that hits the
    same cache entry, so it must be treated as read-only.
    """
    return mel_filter_bank(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)


@lru_cache(maxsize=8)
def _cached_hann_window(win_size: int) -> torch.Tensor:
    """Return ``torch.hann_window(win_size)``, memoized per window size.

    Up to 8 distinct window sizes are retained (LRU eviction). The window is
    created with torch's default device and dtype; callers move it to the
    device they need.

    NOTE: the returned tensor is shared between every caller that hits the
    same cache entry, so it must be treated as read-only.
    """
    return torch.hann_window(win_size)


Comment on lines +262 to +271
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add some kind of global cache for these kinds of objects? Other systems also compute a mel_filter_bank or do resampling. @linyueqian

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense to me. This should be a common module.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I'll work on another PR to make this available to all the audio models.

def mel_spectrogram(
y: torch.Tensor,
n_fft: int,
Expand All @@ -278,17 +289,11 @@ def mel_spectrogram(
logger.warning("Max value of input waveform signal is %s", torch.max(y))
device = y.device
if mel_basis is None:
mel_basis = mel_filter_bank(
sr=sampling_rate,
n_fft=n_fft,
n_mels=num_mels,
fmin=fmin,
fmax=fmax,
).to(device)
mel_basis = _cached_mel_filter_bank(sampling_rate, n_fft, num_mels, fmin, fmax).to(device)
elif mel_basis.device != device:
mel_basis = mel_basis.to(device)
if hann_window is None:
hann_window = torch.hann_window(win_size, device=device)
hann_window = _cached_hann_window(win_size).to(device)
elif hann_window.device != device:
hann_window = hann_window.to(device)
padding = (n_fft - hop_size) // 2
Expand Down Expand Up @@ -464,11 +469,29 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
dict(raw_subtalker_sampling) if isinstance(raw_subtalker_sampling, Mapping) else {}
)

self._stacked_codec_embed: torch.Tensor | None = None
# Bounded LRU: caller-supplied orig_sr can otherwise grow this without limit.
self._resampler_cache: OrderedDict[tuple[int, int], AudioResampler] = OrderedDict()
self._resampler_cache_max = 16
self._tts_pad_embed_cache: torch.Tensor | None = None

# -------------------- vLLM required hooks --------------------

def embed_input_ids(self, input_ids: torch.Tensor, **_: Any) -> torch.Tensor:
    """Embed ``input_ids`` by delegating to ``self.model.embed_input_ids``.

    Any extra keyword arguments are accepted for interface compatibility and
    ignored; only ``input_ids`` is forwarded.
    """
    return self.model.embed_input_ids(input_ids)

def _get_tts_pad_embed_cached(self, device: torch.device) -> torch.Tensor:
    """Return the projected ``tts_pad`` embedding on ``device``, rebuilding once per device.

    The embedding is ``text_projection(text_embedding(pad_id))`` for
    ``config.tts_pad_token_id``, flattened to a single row and cast to
    bfloat16. The most recent result is stored in
    ``self._tts_pad_embed_cache`` and returned as-is while its device equals
    the requested one; asking for a different device recomputes it and
    replaces the stored tensor.

    Args:
        device: Device the returned tensor must live on.

    Returns:
        A detached bfloat16 tensor of shape ``(1, H)``, where ``H`` is the
        flattened size of the projection output.
    """
    cached = self._tts_pad_embed_cache
    if cached is not None and cached.device == device:
        return cached

    pad_id = torch.tensor([[self.config.tts_pad_token_id]], device=device, dtype=torch.long)
    # No autograd graph is needed for a constant lookup.
    with torch.no_grad():
        projected = self.text_projection(self.text_embedding(pad_id))
    cached = projected.reshape(1, -1).to(dtype=torch.bfloat16).detach()
    self._tts_pad_embed_cache = cached
    return cached

def forward(
self,
input_ids: torch.Tensor,
Expand Down Expand Up @@ -535,21 +558,22 @@ def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs: A
ref_len = meta.get("ref_code_len")
if ref_len is None:
continue
if not isinstance(ac, torch.Tensor):
continue
span_len = int(ac.shape[0])
if isinstance(ref_len, torch.Tensor):
if ref_len.numel() == 0:
raise ValueError("ref_code_len is an empty tensor")
ref_len_val = int(ref_len.reshape(-1)[-1].item())
elif isinstance(ref_len, list):
ref_len_tail = ref_len.reshape(-1)[-1:].to(dtype=torch.int32, device=ac.device)
ref_code_len_list.append(ref_len_tail.expand(span_len).contiguous())
continue
if isinstance(ref_len, list):
if len(ref_len) != 1:
raise ValueError(f"ref_code_len must be scalar or 1-element list, got len={len(ref_len)}")
ref_len_val = int(ref_len[0])
else:
ref_len_val = int(ref_len)
if isinstance(ac, torch.Tensor):
# Emit ref_code_len per-token span for runner slicing (consumer takes the last value).
ref_code_len_list.append(
torch.full((int(ac.shape[0]),), ref_len_val, dtype=torch.int32, device=ac.device)
)
ref_code_len_list.append(torch.full((span_len,), ref_len_val, dtype=torch.int32, device=ac.device))

if not audio_codes_list:
return OmniOutput(text_hidden_states=hidden, multimodal_outputs={})
Expand Down Expand Up @@ -696,13 +720,13 @@ def preprocess(
info_update.setdefault("codes", {})["audio"] = zeros
return input_ids_out, prompt_embeds, info_update

# Decode: span_len == 1
# Pop one text-step vector from tailing_text_hidden queue.
# These tensors stay on GPU via gpu_resident_buffer_keys - .to() is a no-op.
# Decode: span_len == 1. Buffers are GPU-resident so .to() is a no-op.
tts_pad_embed_buf = embed.get("tts_pad")
if not isinstance(tts_pad_embed_buf, torch.Tensor):
raise RuntimeError("Missing `tts_pad_embed` in additional_information; prefill must run first.")
tts_pad_embed = tts_pad_embed_buf.to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1)
if isinstance(tts_pad_embed_buf, torch.Tensor):
tts_pad_embed = tts_pad_embed_buf.to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1)
else:
# Defensive: rebuild from text_embedding when prefill state was evicted.
tts_pad_embed = self._get_tts_pad_embed_cached(input_ids.device)

tail = hs.get("trailing_text")
text_offset = max(0, int(meta.get("talker_text_offset", 0) or 0))
Expand Down Expand Up @@ -738,9 +762,11 @@ def preprocess(
next_text_offset = text_offset

last_hidden = hs.get("last")
if not isinstance(last_hidden, torch.Tensor):
raise RuntimeError("Missing hidden_states['last'] in additional_information; postprocess must run.")
past_hidden = last_hidden.to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1)
if isinstance(last_hidden, torch.Tensor):
past_hidden = last_hidden.to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1)
else:
# Defensive: EOS step row is zeroed by the invalid-layer-0 mask and filtered downstream.
past_hidden = torch.zeros_like(text_step)

# Use OmniGPUModelRunner talker_mtp fast-path for residual codebooks and per-step inputs_embeds update.
last_id_hidden = self.embed_input_ids(input_ids.reshape(1, 1).to(torch.long)).to(
Expand Down Expand Up @@ -1304,7 +1330,7 @@ def _extract_speaker_embedding(self, wav: np.ndarray, sr: int) -> torch.Tensor:
# Resample to 24kHz for speaker encoder.
target_sr = int(getattr(self.config.speaker_encoder_config, "sample_rate", 24000))
if sr != target_sr:
resampler = AudioResampler(target_sr=target_sr)
resampler = self._get_resampler(int(sr), target_sr)
wav = resampler.resample(wav.astype(np.float32), orig_sr=int(sr))
sr = target_sr

Expand Down Expand Up @@ -2215,8 +2241,29 @@ def _talker_and_collect_speaker(ws: Iterable[tuple[str, torch.Tensor]]):
# eagerly initialized module and satisfy the strict loader check.
loaded |= {name for name, _ in self.named_parameters() if name.startswith("speaker_encoder.")}
logger.info("Loaded %d weights for Qwen3TTSTalkerForConditionalGeneration", len(loaded))
self._build_stacked_codec_embed()
return loaded

def _build_stacked_codec_embed(self) -> None:
    """Stack the code predictor's input-embedding weights into one tensor.

    Reads the embedding modules from ``self.code_predictor.get_input_embeddings()``
    and stores their detached weights, stacked along a new leading dimension,
    in ``self._stacked_codec_embed``. The result is placed on the device and
    dtype of the first embedding's weight.

    Does nothing (leaving ``self._stacked_codec_embed`` untouched) when there
    are no embeddings.
    """
    embeds = self.code_predictor.get_input_embeddings()
    if not embeds:
        return
    w = embeds[0].weight
    # torch.stack copies, so this is a snapshot of the weights at call time.
    self._stacked_codec_embed = torch.stack([e.weight.detach() for e in embeds], dim=0).to(
        device=w.device, dtype=w.dtype
    )

def _get_resampler(self, orig_sr: int, target_sr: int) -> AudioResampler:
    """Return an ``AudioResampler`` for ``(orig_sr, target_sr)`` from a bounded LRU cache.

    On a hit the entry is marked most-recently-used and returned. On a miss a
    new ``AudioResampler(target_sr=target_sr)`` is created and stored; if the
    cache then exceeds ``self._resampler_cache_max`` entries, the
    least-recently-used entry is evicted.

    Args:
        orig_sr: Source sample rate; only used as part of the cache key.
        target_sr: Target sample rate passed to ``AudioResampler``.

    Returns:
        The cached or newly created resampler.
    """
    key = (orig_sr, target_sr)
    cache = self._resampler_cache
    resampler = cache.get(key)
    if resampler is not None:
        cache.move_to_end(key)
        return resampler

    resampler = AudioResampler(target_sr=target_sr)
    cache[key] = resampler
    if len(cache) > self._resampler_cache_max:
        cache.popitem(last=False)
    # Return the local reference rather than re-reading cache[key]: with a
    # limit <= 0 the entry just inserted is the one that gets evicted.
    return resampler

# -------------------- GPU-side MTP fast-path --------------------

@torch.inference_mode()
Expand Down Expand Up @@ -2276,11 +2323,13 @@ def talker_mtp(
invalid0 = (layer0 < 0) | (layer0 >= int(self._codebook_vocab_size))
audio_codes = torch.where(invalid0.expand_as(audio_codes), torch.zeros_like(audio_codes), audio_codes)

# Sum embeddings of all code groups, then add the current text step.
# Single gather over stacked [Q-1, V, H] replaces Q-1 serial embedding kernels.
residual_ids_t = audio_codes[:, 1:]
embeds: list[torch.Tensor] = [last_id_hidden]
for i in range(max_steps):
embeds.append(self.code_predictor.get_input_embeddings()[i](residual_ids_t[:, i : i + 1]))
summed = torch.cat(embeds, dim=1).sum(1, keepdim=True) # [B,1,H]
if self._stacked_codec_embed is None:
self._build_stacked_codec_embed()
embed_weight = self._stacked_codec_embed.to(device=dev)
row_idx = torch.arange(max_steps, device=dev).unsqueeze(0).expand(bsz, -1)
gathered = embed_weight[row_idx, residual_ids_t]
summed = (last_id_hidden.squeeze(1) + gathered.sum(dim=1)).unsqueeze(1)
inputs_embeds_out = (summed + text_step).reshape(bsz, -1)
return inputs_embeds_out, audio_codes.to(dtype=torch.long)
Loading