From 29b4199aa550bb5689d78b3a86e488b96a348a38 Mon Sep 17 00:00:00 2001 From: JuanPZuluaga Date: Sat, 16 May 2026 07:43:24 +0000 Subject: [PATCH] cache hot buffers in qwen3_tts talker; fall back on evicted state Signed-off-by: JuanPZuluaga --- .../models/qwen3_tts/qwen3_tts_talker.py | 110 +++++++++++++----- 1 file changed, 80 insertions(+), 30 deletions(-) diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py index a705a1cd48..a3a15ff9d4 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py @@ -4,7 +4,9 @@ import copy import io import os +from collections import OrderedDict from collections.abc import Callable, Iterable, Mapping +from functools import lru_cache from typing import Any from urllib.parse import urlparse @@ -255,6 +257,16 @@ def _dynamic_range_compression(x, c=1, clip_val=1e-5): return torch.log(torch.clamp(x, min=clip_val) * c) +@lru_cache(maxsize=8) +def _cached_mel_filter_bank(sampling_rate: int, n_fft: int, n_mels: int, fmin: int, fmax: int | None) -> torch.Tensor: + return mel_filter_bank(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax) + + +@lru_cache(maxsize=8) +def _cached_hann_window(win_size: int) -> torch.Tensor: + return torch.hann_window(win_size) + + def mel_spectrogram( y: torch.Tensor, n_fft: int, @@ -275,17 +287,11 @@ def mel_spectrogram( logger.warning("Max value of input waveform signal is %s", torch.max(y)) device = y.device if mel_basis is None: - mel_basis = mel_filter_bank( - sr=sampling_rate, - n_fft=n_fft, - n_mels=num_mels, - fmin=fmin, - fmax=fmax, - ).to(device) + mel_basis = _cached_mel_filter_bank(sampling_rate, n_fft, num_mels, fmin, fmax).to(device) elif mel_basis.device != device: mel_basis = mel_basis.to(device) if hann_window is None: - hann_window = torch.hann_window(win_size, device=device) + hann_window = _cached_hann_window(win_size).to(device) elif hann_window.device != device: hann_window = hann_window.to(device) padding = (n_fft - hop_size) // 2 @@ -448,11 +454,29 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): dict(raw_subtalker_sampling) if isinstance(raw_subtalker_sampling, Mapping) else {} ) + self._stacked_codec_embed: torch.Tensor | None = None + # Bounded LRU: caller-supplied orig_sr can otherwise grow this without limit. + self._resampler_cache: OrderedDict[tuple[int, int], AudioResampler] = OrderedDict() + self._resampler_cache_max = 16 + self._tts_pad_embed_cache: torch.Tensor | None = None + # -------------------- vLLM required hooks -------------------- def embed_input_ids(self, input_ids: torch.Tensor, **_: Any) -> torch.Tensor: return self.model.embed_input_ids(input_ids) + def _get_tts_pad_embed_cached(self, device: torch.device) -> torch.Tensor: + """Return a device-pinned tts_pad embedding, rebuilding once per device.""" + cached = self._tts_pad_embed_cache + if cached is not None and cached.device == device: + return cached + pad_id = torch.tensor([[self.config.tts_pad_token_id]], device=device, dtype=torch.long) + with torch.no_grad(): + projected = self.text_projection(self.text_embedding(pad_id)) + cached = projected.reshape(1, -1).to(dtype=torch.bfloat16).detach() + self._tts_pad_embed_cache = cached + return cached + def forward( self, input_ids: torch.Tensor, @@ -519,21 +543,22 @@ def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs: A ref_len = meta.get("ref_code_len") if ref_len is None: continue + if not isinstance(ac, torch.Tensor): + continue + span_len = int(ac.shape[0]) if isinstance(ref_len, torch.Tensor): if ref_len.numel() == 0: raise ValueError("ref_code_len is an empty tensor") - ref_len_val = int(ref_len.reshape(-1)[-1].item()) - elif isinstance(ref_len, list): + ref_len_tail = ref_len.reshape(-1)[-1:].to(dtype=torch.int32, device=ac.device) + ref_code_len_list.append(ref_len_tail.expand(span_len).contiguous()) + continue + if isinstance(ref_len, list): if len(ref_len) != 1: raise ValueError(f"ref_code_len must be scalar or 1-element list, got len={len(ref_len)}") ref_len_val = int(ref_len[0]) else: ref_len_val = int(ref_len) - if isinstance(ac, torch.Tensor): - # Emit ref_code_len per-token span for runner slicing (consumer takes the last value). - ref_code_len_list.append( - torch.full((int(ac.shape[0]),), ref_len_val, dtype=torch.int32, device=ac.device) - ) + ref_code_len_list.append(torch.full((span_len,), ref_len_val, dtype=torch.int32, device=ac.device)) if not audio_codes_list: return OmniOutput(text_hidden_states=hidden, multimodal_outputs={}) @@ -678,13 +703,13 @@ def preprocess( info_update.setdefault("codes", {})["audio"] = zeros return input_ids_out, prompt_embeds, info_update - # Decode: span_len == 1 - # Pop one text-step vector from tailing_text_hidden queue. - # These tensors stay on GPU via gpu_resident_buffer_keys - .to() is a no-op. + # Decode: span_len == 1. Buffers are GPU-resident so .to() is a no-op. tts_pad_embed_buf = embed.get("tts_pad") - if not isinstance(tts_pad_embed_buf, torch.Tensor): - raise RuntimeError("Missing `tts_pad_embed` in additional_information; prefill must run first.") - tts_pad_embed = tts_pad_embed_buf.to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1) + if isinstance(tts_pad_embed_buf, torch.Tensor): + tts_pad_embed = tts_pad_embed_buf.to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1) + else: + # Defensive: rebuild from text_embedding when prefill state was evicted. + tts_pad_embed = self._get_tts_pad_embed_cached(input_ids.device) tail = hs.get("trailing_text") text_offset = max(0, int(meta.get("talker_text_offset", 0) or 0)) @@ -720,9 +745,11 @@ def preprocess( next_text_offset = text_offset last_hidden = hs.get("last") - if not isinstance(last_hidden, torch.Tensor): - raise RuntimeError("Missing hidden_states['last'] in additional_information; postprocess must run.") - past_hidden = last_hidden.to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1) + if isinstance(last_hidden, torch.Tensor): + past_hidden = last_hidden.to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1) + else: + # Defensive: EOS step row is zeroed by the invalid-layer-0 mask and filtered downstream. + past_hidden = torch.zeros_like(text_step) # Use OmniGPUModelRunner talker_mtp fast-path for residual codebooks and per-step inputs_embeds update. last_id_hidden = self.embed_input_ids(input_ids.reshape(1, 1).to(torch.long)).to( @@ -1286,7 +1313,7 @@ def _extract_speaker_embedding(self, wav: np.ndarray, sr: int) -> torch.Tensor: # Resample to 24kHz for speaker encoder. target_sr = int(getattr(self.config.speaker_encoder_config, "sample_rate", 24000)) if sr != target_sr: - resampler = AudioResampler(target_sr=target_sr) + resampler = self._get_resampler(int(sr), target_sr) wav = resampler.resample(wav.astype(np.float32), orig_sr=int(sr)) sr = target_sr @@ -2097,8 +2124,29 @@ def _talker_and_collect_speaker(ws: Iterable[tuple[str, torch.Tensor]]): # eagerly initialized module and satisfy the strict loader check. loaded |= {name for name, _ in self.named_parameters() if name.startswith("speaker_encoder.")} logger.info("Loaded %d weights for Qwen3TTSTalkerForConditionalGeneration", len(loaded)) + self._build_stacked_codec_embed() return loaded + def _build_stacked_codec_embed(self) -> None: + embeds = self.code_predictor.get_input_embeddings() + if not embeds: + return + w = embeds[0].weight + self._stacked_codec_embed = torch.stack([e.weight.detach() for e in embeds], dim=0).to( + device=w.device, dtype=w.dtype + ) + + def _get_resampler(self, orig_sr: int, target_sr: int) -> AudioResampler: + key = (orig_sr, target_sr) + cache = self._resampler_cache + if key in cache: + cache.move_to_end(key) + return cache[key] + cache[key] = AudioResampler(target_sr=target_sr) + if len(cache) > self._resampler_cache_max: + cache.popitem(last=False) + return cache[key] + # -------------------- GPU-side MTP fast-path -------------------- @torch.inference_mode() @@ -2158,11 +2206,13 @@ def talker_mtp( invalid0 = (layer0 < 0) | (layer0 >= int(self._codebook_vocab_size)) audio_codes = torch.where(invalid0.expand_as(audio_codes), torch.zeros_like(audio_codes), audio_codes) - # Sum embeddings of all code groups, then add the current text step. + # Single gather over stacked [Q-1, V, H] replaces Q-1 serial embedding kernels. residual_ids_t = audio_codes[:, 1:] - embeds: list[torch.Tensor] = [last_id_hidden] - for i in range(max_steps): - embeds.append(self.code_predictor.get_input_embeddings()[i](residual_ids_t[:, i : i + 1])) - summed = torch.cat(embeds, dim=1).sum(1, keepdim=True) # [B,1,H] + if self._stacked_codec_embed is None: + self._build_stacked_codec_embed() + embed_weight = self._stacked_codec_embed.to(device=dev) + row_idx = torch.arange(max_steps, device=dev).unsqueeze(0).expand(bsz, -1) + gathered = embed_weight[row_idx, residual_ids_t] + summed = (last_id_hidden.squeeze(1) + gathered.sum(dim=1)).unsqueeze(1) inputs_embeds_out = (summed + text_step).reshape(bsz, -1) return inputs_embeds_out, audio_codes.to(dtype=torch.long)