From 526b890a2f385d6790c06e74aa96b14b25c6ed0a Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Mon, 5 Jan 2026 14:55:57 +0800 Subject: [PATCH 01/24] perf glmasr Signed-off-by: JaredforReal --- vllm/model_executor/models/glmasr.py | 442 +++++++- vllm/model_executor/models/glmasr_utils.py | 1076 +++++++++++++++++--- 2 files changed, 1351 insertions(+), 167 deletions(-) diff --git a/vllm/model_executor/models/glmasr.py b/vllm/model_executor/models/glmasr.py index cec328ca7c54..681603bec482 100644 --- a/vllm/model_executor/models/glmasr.py +++ b/vllm/model_executor/models/glmasr.py @@ -1,14 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import logging from collections.abc import Iterable, Mapping, Sequence +from functools import lru_cache from typing import Annotated, Any, Literal, TypeAlias, cast import numpy as np import torch import torch.nn as nn from transformers import BatchFeature -from transformers.models.glmasr import GlmAsrConfig, GlmAsrEncoder, GlmAsrProcessor +from transformers.models.glmasr import GlmAsrConfig, GlmAsrProcessor from transformers.models.whisper import WhisperFeatureExtractor from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig @@ -35,6 +37,7 @@ MultiModalDataParser, ) from vllm.multimodal.processing import ( + BaseMultiModalProcessor, PromptReplacement, PromptUpdate, PromptUpdateDetails, @@ -47,7 +50,6 @@ from .audioflamingo3 import ( AudioFlamingo3MultiModalDataParser, - AudioFlamingo3MultiModalProcessor, AudioFlamingo3ProcessingInfo, ) from .audioflamingo3 import ( @@ -57,6 +59,7 @@ DEFAULT_CONV_PARAMS, DEFAULT_MAX_AUDIO_LEN_S, DEFAULT_MERGE_FACTOR, + GlmAsrEncoder, _flatten_audio_features_by_length, _get_audio_output_lengths_for_tower, _get_num_features_for_item, @@ -73,6 +76,240 @@ from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix from .whisper import ISO639_1_SUPPORTED_LANGS +logger = logging.getLogger(__name__) + + +# ============================================================================= +# GPU-accelerated Whisper Feature Extractor +# ============================================================================= + + +@lru_cache(maxsize=1) +def _get_mel_filters( + n_fft: int = 400, + n_mels: int = 80, + sampling_rate: int = 16000, + device: torch.device | None = None, +) -> torch.Tensor: + """ + Compute mel filterbank matrix (cached). + Matches WhisperFeatureExtractor's mel_filter_bank with slaney norm/scale. + """ + if device is None: + device = torch.device("cpu") + # Frequency bins + n_freqs = n_fft // 2 + 1 + all_freqs = torch.linspace(0, sampling_rate // 2, n_freqs, device=device) + + # Mel scale conversion (slaney) + min_mel = 0.0 + max_mel = 2595.0 * np.log10(1.0 + (sampling_rate / 2) / 700.0) + mels = torch.linspace(min_mel, max_mel, n_mels + 2, device=device) + mel_freqs = 700.0 * (10.0 ** (mels / 2595.0) - 1.0) + + # Create filterbank + mel_filters = torch.zeros(n_mels, n_freqs, device=device) + for i in range(n_mels): + lower = mel_freqs[i] + center = mel_freqs[i + 1] + upper = mel_freqs[i + 2] + + # Lower slope + lower_slope = (all_freqs - lower) / (center - lower + 1e-10) + # Upper slope + upper_slope = (upper - all_freqs) / (upper - center + 1e-10) + + mel_filters[i] = torch.maximum( + torch.zeros_like(all_freqs), + torch.minimum(lower_slope, upper_slope), + ) + + # Slaney normalization + enorm = 2.0 / (mel_freqs[2 : n_mels + 2] - mel_freqs[:n_mels]) + mel_filters *= enorm.unsqueeze(1) + + return mel_filters + + +class GPUWhisperFeatureExtractor: + """ + GPU-accelerated Whisper feature extractor using PyTorch. + Computes log-mel spectrogram matching WhisperFeatureExtractor output. + + Key parameters (Whisper defaults): + - n_fft: 400 (25ms window at 16kHz) + - hop_length: 160 (10ms hop at 16kHz) + - n_mels: 80 + - chunk_length: 30 seconds + - sampling_rate: 16000 + """ + + def __init__( + self, + feature_size: int = 80, + sampling_rate: int = 16000, + hop_length: int = 160, + chunk_length: int = 30, + n_fft: int = 400, + padding_value: float = 0.0, + device: str | torch.device = "cuda", + ): + self.feature_size = feature_size + self.sampling_rate = sampling_rate + self.hop_length = hop_length + self.chunk_length = chunk_length + self.n_fft = n_fft + self.padding_value = padding_value + self.device = torch.device(device) if isinstance(device, str) else device + + # Derived parameters + self.n_samples = chunk_length * sampling_rate # 480000 for 30s + self.nb_max_frames = self.n_samples // hop_length # 3000 frames + + # Pre-compute window and mel filters on device + self._window: torch.Tensor | None = None + self._mel_filters: torch.Tensor | None = None + + def _ensure_buffers(self, device: torch.device) -> None: + """Lazily initialize buffers on the target device.""" + if self._window is None or self._window.device != device: + self._window = torch.hann_window(self.n_fft, device=device) + + if self._mel_filters is None or self._mel_filters.device != device: + self._mel_filters = _get_mel_filters( + n_fft=self.n_fft, + n_mels=self.feature_size, + sampling_rate=self.sampling_rate, + device=device, + ) + + def __call__( + self, + raw_speech: list[np.ndarray] | np.ndarray | torch.Tensor, + sampling_rate: int | None = None, + padding: str = "max_length", + max_length: int | None = None, + return_attention_mask: bool = True, + return_tensors: str = "pt", + device: str | torch.device | None = None, + ) -> BatchFeature: + """ + Extract log-mel spectrogram features from audio. + + Args: + raw_speech: Audio waveform(s), can be list of arrays or batched + sampling_rate: Expected sample rate (must match self.sampling_rate) + padding: Padding strategy ('max_length' or 'longest') + max_length: Max samples (default: self.n_samples = 30s * 16kHz) + return_attention_mask: Whether to return attention mask + return_tensors: Output format ('pt' for PyTorch) + device: Device for computation (default: self.device) + + Returns: + BatchFeature with 'input_features' and optionally 'attention_mask' + """ + if sampling_rate is not None and sampling_rate != self.sampling_rate: + raise ValueError( + f"Expected sampling_rate={self.sampling_rate}, got {sampling_rate}" + ) + + device = torch.device(device) if device else self.device + max_length = max_length or self.n_samples + + # Convert inputs to list of 1D tensors + if isinstance(raw_speech, np.ndarray): + raw_speech = [raw_speech] if raw_speech.ndim == 1 else list(raw_speech) + elif isinstance(raw_speech, torch.Tensor): + raw_speech = ( + [raw_speech.numpy()] + if raw_speech.ndim == 1 + else [s.numpy() for s in raw_speech] + ) + + batch_size = len(raw_speech) + + # Get actual lengths before padding + lengths = [len(s) for s in raw_speech] + + # Pad/truncate to max_length + if padding == "max_length": + target_length = max_length + else: # 'longest' + target_length = min(max(lengths), max_length) + + # Create padded batch tensor + padded_waveforms = torch.zeros( + batch_size, target_length, dtype=torch.float32, device=device + ) + attention_mask = torch.zeros( + batch_size, target_length, dtype=torch.int32, device=device + ) + + for i, waveform in enumerate(raw_speech): + if isinstance(waveform, np.ndarray): + waveform = torch.from_numpy(waveform) + waveform = waveform.to(device=device, dtype=torch.float32) + + # Truncate if needed + actual_len = min(len(waveform), target_length) + padded_waveforms[i, :actual_len] = waveform[:actual_len] + attention_mask[i, :actual_len] = 1 + + # Extract features on GPU + input_features = self._extract_fbank_features(padded_waveforms) + + # Rescale attention mask from samples to frames + # STFT produces L//hop_length + 1 frames, but we drop the last one + frame_attention_mask = attention_mask[:, :: self.hop_length] + # Trim to match actual frame count (we drop last frame in _extract) + if attention_mask.shape[1] % self.hop_length != 0: + frame_attention_mask = frame_attention_mask[:, :-1] + + result: dict[str, Any] = {"input_features": input_features} + if return_attention_mask: + result["attention_mask"] = frame_attention_mask + + return BatchFeature(data=result, tensor_type=return_tensors) + + def _extract_fbank_features(self, waveforms: torch.Tensor) -> torch.Tensor: + """ + Compute log-mel spectrogram for batched waveforms. + + Args: + waveforms: [batch, samples] float32 tensor on target device + + Returns: + [batch, n_mels, frames] float32 tensor (log-mel spectrogram) + """ + device = waveforms.device + self._ensure_buffers(device) + + # STFT: [batch, samples] -> [batch, n_fft//2+1, frames] complex + stft = torch.stft( + waveforms, + n_fft=self.n_fft, + hop_length=self.hop_length, + window=self._window, + return_complex=True, + ) + + # Power spectrogram, drop last frame (matching HF implementation) + magnitudes = stft[..., :-1].abs() ** 2 # [batch, n_freqs, frames] + + # Apply mel filterbank: [n_mels, n_freqs] @ [batch, n_freqs, frames] + # -> [batch, n_mels, frames] + mel_spec = torch.matmul(self._mel_filters, magnitudes) + + # Log scale with floor + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + + # Per-sample normalization (max - 8.0 floor, then scale) + max_val = log_spec.amax(dim=(1, 2), keepdim=True) + log_spec = torch.maximum(log_spec, max_val - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + + return log_spec + class GlmAsrFeatureInputs(TensorSchema): """ @@ -203,7 +440,35 @@ def _parse_audio_data( return super()._parse_audio_data(data) -class GlmAsrMultiModalProcessor(AudioFlamingo3MultiModalProcessor): +class GlmAsrMultiModalProcessor(BaseMultiModalProcessor["GlmAsrProcessingInfo"]): + """ + GLM-ASR processor that inherits directly from BaseMultiModalProcessor + for better performance and cleaner implementation. + Uses GPU-accelerated feature extraction for improved throughput. + """ + + # Shared GPU feature extractor instance (lazy initialized) + _gpu_feature_extractor: GPUWhisperFeatureExtractor | None = None + + @classmethod + def _get_gpu_feature_extractor( + cls, + hf_feature_extractor: WhisperFeatureExtractor, + device: str = "cuda", + ) -> GPUWhisperFeatureExtractor: + """Get or create GPU feature extractor matching HF config.""" + if cls._gpu_feature_extractor is None: + cls._gpu_feature_extractor = GPUWhisperFeatureExtractor( + feature_size=hf_feature_extractor.feature_size, + sampling_rate=hf_feature_extractor.sampling_rate, + hop_length=hf_feature_extractor.hop_length, + chunk_length=hf_feature_extractor.chunk_length, + n_fft=hf_feature_extractor.n_fft, + padding_value=hf_feature_extractor.padding_value, + device=device, + ) + return cls._gpu_feature_extractor + def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self.info.get_feature_extractor() return GlmAsrMultiModalDataParser(target_sr=feature_extractor.sampling_rate) @@ -228,6 +493,7 @@ def _calculate_chunk_counts( chunk_counts.append(min(n_chunks, max_windows)) return chunk_counts + # @torch.compile(fullgraph=True) def _call_hf_processor( self, prompt: str, @@ -235,6 +501,9 @@ def _call_hf_processor( mm_kwargs: Mapping[str, Any], tok_kwargs: Mapping[str, object], ) -> BatchFeature: + """ + Call processor with GPU-accelerated feature extraction. + """ # Normalize input: handle deprecated key and list conversion. if "audios" in mm_data: mm_data["audio"] = mm_data.pop("audios") @@ -248,26 +517,131 @@ def _call_hf_processor( prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") - # Get processor for chunk counts calculation + # Get processor for tokenizer and config processor = self.info.get_hf_processor(**mm_kwargs) + hf_feature_extractor = processor.feature_extractor + tokenizer = processor.tokenizer + + # ===== Audio chunking (CPU, fast) ===== + sampling_rate = hf_feature_extractor.sampling_rate + chunk_length = hf_feature_extractor.chunk_length + max_audio_len = getattr(processor, "max_audio_len", DEFAULT_MAX_AUDIO_LEN_S) + window_size = int(sampling_rate * chunk_length) + max_windows = int(max_audio_len // chunk_length) + + per_sample_windows: list[int] = [] + flat_chunks: list[np.ndarray] = [] + + for audio_el in audio_list: + # Convert to numpy if needed + if isinstance(audio_el, torch.Tensor): + audio_el = audio_el.numpy() + elif isinstance(audio_el, list): + audio_el = np.array(audio_el, dtype=np.float32) + + n_samples = int(audio_el.shape[0]) + n_win = max(1, (n_samples + window_size - 1) // window_size) + if n_win > max_windows: + n_win = max_windows + + per_sample_windows.append(n_win) + time_cap = min(n_samples, n_win * window_size) + + for i in range(n_win): + start = i * window_size + end = min((i + 1) * window_size, time_cap) + flat_chunks.append(audio_el[start:end]) + + # ===== GPU Feature Extraction ===== + # Check if CUDA is available, fallback to CPU if not + use_gpu = torch.cuda.is_available() + device = "cuda" if use_gpu else "cpu" + + if use_gpu: + # Use GPU-accelerated feature extractor + gpu_extractor = self._get_gpu_feature_extractor( + hf_feature_extractor, device=device + ) + audio_inputs = gpu_extractor( + flat_chunks, + sampling_rate=sampling_rate, + return_attention_mask=True, + return_tensors="pt", + ) + else: + # Fallback to HF CPU implementation + audio_inputs = hf_feature_extractor( + flat_chunks, + sampling_rate=sampling_rate, + return_tensors="pt", + padding=True, + return_attention_mask=True, + ) - # Call parent method (it will handle sampling_rate) - outputs = super()._call_hf_processor( - prompt=prompt, - mm_data=mm_data, - mm_kwargs=mm_kwargs, - tok_kwargs=tok_kwargs, + # ===== Process attention mask ===== + padding_mask = audio_inputs.pop("attention_mask") + input_features_mask = padding_mask + + # ===== Compute audio token lengths ===== + chunk_lengths = padding_mask.sum(-1) # [num_chunks] + audio_lengths = torch.stack( + [ + chunk_lengths[ + sum(per_sample_windows[:i]) : sum(per_sample_windows[: i + 1]) + ].sum() + for i in range(len(per_sample_windows)) + ] ) - # Postprocess: rename mask and add chunk counts. - if "input_features_mask" in outputs: - outputs["feature_attention_mask"] = outputs.pop("input_features_mask") + # Apply convolution formula to get token counts + merge_factor = 4 + for padding, kernel_size, stride in [(1, 3, 1), (1, 3, 2)]: + audio_lengths = ( + audio_lengths + 2 * padding - (kernel_size - 1) - 1 + ) // stride + 1 + audio_tokens_lengths = (audio_lengths - merge_factor) // merge_factor + 1 - # Override chunk counts calculation with GLM-ASR specific logic - chunk_counts = self._calculate_chunk_counts( - audio_list, processor.feature_extractor, processor + # ===== Expand audio tokens in text ===== + import regex as re + + audio_token = getattr(processor, "audio_token", "<|pad|>") + text_list = [prompt] + + for i, audio_length in enumerate(audio_tokens_lengths): + if i < len(text_list): + expanded = re.sub( + re.escape(audio_token), + audio_token * int(audio_length), + text_list[i], + ) + text_list[i] = expanded + + # ===== Tokenize text ===== + text_inputs = tokenizer( + text_list, + return_tensors="pt", + padding=True, + **tok_kwargs, ) - outputs["chunk_counts"] = torch.tensor(chunk_counts, dtype=torch.long) + + # ===== Combine outputs ===== + # Move input_features to CPU for compatibility + input_features = audio_inputs["input_features"] + if input_features.device.type != "cpu": + input_features = input_features.cpu() + if input_features_mask.device.type != "cpu": + input_features_mask = input_features_mask.cpu() + + outputs = BatchFeature( + data={ + **text_inputs, + "input_features": input_features, + "feature_attention_mask": input_features_mask, + }, + tensor_type="pt", + ) + + outputs["chunk_counts"] = torch.tensor(per_sample_windows, dtype=torch.long) return outputs @@ -352,7 +726,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.multimodal_config = multimodal_config - self.audio_tower = GlmAsrEncoder(config.audio_config) + # Use optimized vLLM native encoder + self.audio_tower = GlmAsrEncoder( + config.audio_config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "audio_tower"), + ) self.multi_modal_projector = GlmAsrMultiModalProjector( config, quant_config=quant_config, @@ -419,12 +798,31 @@ def _process_audio_input( audio_input.get("chunk_counts"), num_chunks=num_chunks ) + # Convert input_features to model dtype (e.g., bfloat16) to match model weights + input_features = input_features.to(dtype=self.audio_tower.conv1.weight.dtype) + + # audio_tower returns [batch_size, seq_len, hidden_size] where hidden_size=1280 audio_hidden_states = self.audio_tower(input_features).last_hidden_state + + # GLM-ASR merges consecutive frames: 4 frames with hidden_size=1280 + # -> 1 frame with intermediate_size=5120 + hidden_size = self.config.audio_config.hidden_size + intermediate_size = self.config.audio_config.intermediate_size + merge_ratio = intermediate_size // hidden_size + + # Truncate sequence length to be divisible by merge_ratio + seq_len = audio_hidden_states.shape[1] + seq_len_truncated = (seq_len // merge_ratio) * merge_ratio + if seq_len_truncated < seq_len: + audio_hidden_states = audio_hidden_states[:, :seq_len_truncated, :] + + # Reshape to merge consecutive frames audio_hidden_states = audio_hidden_states.reshape( num_chunks, -1, - self.config.audio_config.intermediate_size, + intermediate_size, ) + audio_features = self.multi_modal_projector(audio_hidden_states) merge_factor = getattr(self.config, "merge_factor", DEFAULT_MERGE_FACTOR) @@ -444,7 +842,9 @@ def _process_audio_input( chunk_embeddings = torch.split( masked_audio_features, audio_output_lengths.flatten().tolist() ) - return _group_audio_embeddings(chunk_embeddings, chunk_counts) + result = _group_audio_embeddings(chunk_embeddings, chunk_counts) + + return result def get_language_model(self) -> torch.nn.Module: return self.language_model @@ -453,7 +853,9 @@ def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: audio_input = self._parse_and_validate_audio_input(**kwargs) if audio_input is None: return [] + masked_audio_features = self._process_audio_input(audio_input) + return masked_audio_features def forward( diff --git a/vllm/model_executor/models/glmasr_utils.py b/vllm/model_executor/models/glmasr_utils.py index f65d05252e26..681603bec482 100644 --- a/vllm/model_executor/models/glmasr_utils.py +++ b/vllm/model_executor/models/glmasr_utils.py @@ -1,165 +1,947 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Sequence -from typing import cast +import logging +from collections.abc import Iterable, Mapping, Sequence +from functools import lru_cache +from typing import Annotated, Any, Literal, TypeAlias, cast +import numpy as np import torch import torch.nn as nn +from transformers import BatchFeature +from transformers.models.glmasr import GlmAsrConfig, GlmAsrProcessor +from transformers.models.whisper import WhisperFeatureExtractor -DEFAULT_MAX_AUDIO_LEN_S = 655 -DEFAULT_MERGE_FACTOR = 4 -# Default convolution parameters: (padding, kernel_size, stride) -# These correspond to the two conv layers in GlmAsrEncoder -DEFAULT_CONV_PARAMS = [(1, 3, 1), (1, 3, 2)] +from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig +from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs.data import PromptType +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ( + DictEmbeddingItems, + ModalityData, + ModalityDataItems, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.tokenizers import cached_tokenizer_from_config +from vllm.transformers_utils.processor import cached_processor_from_config +from vllm.utils.tensor_schema import TensorSchema, TensorShape +from .audioflamingo3 import ( + AudioFlamingo3MultiModalDataParser, + AudioFlamingo3ProcessingInfo, +) +from .audioflamingo3 import ( + _audioflamingo3_field_config as _glmasr_field_config, +) +from .glmasr_utils import ( + DEFAULT_CONV_PARAMS, + DEFAULT_MAX_AUDIO_LEN_S, + DEFAULT_MERGE_FACTOR, + GlmAsrEncoder, + _flatten_audio_features_by_length, + _get_audio_output_lengths_for_tower, + _get_num_features_for_item, + _group_audio_embeddings, + _normalize_chunk_counts, +) +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, + SupportsTranscription, +) +from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix +from .whisper import ISO639_1_SUPPORTED_LANGS -def _calculate_conv_output_length( - input_length: torch.Tensor, padding: int, kernel_size: int, stride: int -) -> torch.Tensor: - """Calculate Conv1d output length using standard formula.""" - # Standard formula: floor((input + 2*padding - kernel_size) / stride) + 1 - return (input_length + 2 * padding - kernel_size) // stride + 1 - - -def _as_list_chunk_counts( - chunk_counts: torch.Tensor | list[int] | list[torch.Tensor], -) -> list[int]: - if isinstance(chunk_counts, torch.Tensor): - return chunk_counts.tolist() - if chunk_counts and isinstance(chunk_counts[0], torch.Tensor): - tensor_counts = cast(list[torch.Tensor], chunk_counts) - return [int(c.item()) for c in tensor_counts] - return [int(c) for c in chunk_counts] - - -def _normalize_chunk_counts( - chunk_counts: torch.Tensor | list[int] | list[torch.Tensor] | None, - num_chunks: int, -) -> list[int]: - if chunk_counts is None: - return [1] * num_chunks - return _as_list_chunk_counts(chunk_counts) - - -def _get_audio_output_lengths_from_lengths( - audio_lengths: torch.Tensor, - merge_factor: int, - conv_params: list[tuple[int, int, int]], +logger = logging.getLogger(__name__) + + +# ============================================================================= +# GPU-accelerated Whisper Feature Extractor +# ============================================================================= + + +@lru_cache(maxsize=1) +def _get_mel_filters( + n_fft: int = 400, + n_mels: int = 80, + sampling_rate: int = 16000, + device: torch.device | None = None, ) -> torch.Tensor: - for padding, kernel_size, stride in conv_params: - audio_lengths = _calculate_conv_output_length( - audio_lengths, padding, kernel_size, stride + """ + Compute mel filterbank matrix (cached). + Matches WhisperFeatureExtractor's mel_filter_bank with slaney norm/scale. + """ + if device is None: + device = torch.device("cpu") + # Frequency bins + n_freqs = n_fft // 2 + 1 + all_freqs = torch.linspace(0, sampling_rate // 2, n_freqs, device=device) + + # Mel scale conversion (slaney) + min_mel = 0.0 + max_mel = 2595.0 * np.log10(1.0 + (sampling_rate / 2) / 700.0) + mels = torch.linspace(min_mel, max_mel, n_mels + 2, device=device) + mel_freqs = 700.0 * (10.0 ** (mels / 2595.0) - 1.0) + + # Create filterbank + mel_filters = torch.zeros(n_mels, n_freqs, device=device) + for i in range(n_mels): + lower = mel_freqs[i] + center = mel_freqs[i + 1] + upper = mel_freqs[i + 2] + + # Lower slope + lower_slope = (all_freqs - lower) / (center - lower + 1e-10) + # Upper slope + upper_slope = (upper - all_freqs) / (upper - center + 1e-10) + + mel_filters[i] = torch.maximum( + torch.zeros_like(all_freqs), + torch.minimum(lower_slope, upper_slope), ) - return (audio_lengths - merge_factor) // merge_factor + 1 + # Slaney normalization + enorm = 2.0 / (mel_freqs[2 : n_mels + 2] - mel_freqs[:n_mels]) + mel_filters *= enorm.unsqueeze(1) -def _get_audio_output_lengths_from_mask( - mask: torch.Tensor, - merge_factor: int, - conv_params: list[tuple[int, int, int]], -) -> torch.Tensor: - audio_lengths = mask.sum(-1) - return _get_audio_output_lengths_from_lengths( - audio_lengths, merge_factor, conv_params - ) + return mel_filters -def _get_audio_output_lengths_for_tower( - audio_tower: nn.Module, - audio_lengths: torch.Tensor, - merge_factor: int, - conv_params: list[tuple[int, int, int]], -) -> torch.Tensor: - if hasattr(audio_tower, "_get_feat_extract_output_lengths"): - _, audio_output_lengths = audio_tower._get_feat_extract_output_lengths( - audio_lengths +class GPUWhisperFeatureExtractor: + """ + GPU-accelerated Whisper feature extractor using PyTorch. + Computes log-mel spectrogram matching WhisperFeatureExtractor output. + + Key parameters (Whisper defaults): + - n_fft: 400 (25ms window at 16kHz) + - hop_length: 160 (10ms hop at 16kHz) + - n_mels: 80 + - chunk_length: 30 seconds + - sampling_rate: 16000 + """ + + def __init__( + self, + feature_size: int = 80, + sampling_rate: int = 16000, + hop_length: int = 160, + chunk_length: int = 30, + n_fft: int = 400, + padding_value: float = 0.0, + device: str | torch.device = "cuda", + ): + self.feature_size = feature_size + self.sampling_rate = sampling_rate + self.hop_length = hop_length + self.chunk_length = chunk_length + self.n_fft = n_fft + self.padding_value = padding_value + self.device = torch.device(device) if isinstance(device, str) else device + + # Derived parameters + self.n_samples = chunk_length * sampling_rate # 480000 for 30s + self.nb_max_frames = self.n_samples // hop_length # 3000 frames + + # Pre-compute window and mel filters on device + self._window: torch.Tensor | None = None + self._mel_filters: torch.Tensor | None = None + + def _ensure_buffers(self, device: torch.device) -> None: + """Lazily initialize buffers on the target device.""" + if self._window is None or self._window.device != device: + self._window = torch.hann_window(self.n_fft, device=device) + + if self._mel_filters is None or self._mel_filters.device != device: + self._mel_filters = _get_mel_filters( + n_fft=self.n_fft, + n_mels=self.feature_size, + sampling_rate=self.sampling_rate, + device=device, + ) + + def __call__( + self, + raw_speech: list[np.ndarray] | np.ndarray | torch.Tensor, + sampling_rate: int | None = None, + padding: str = "max_length", + max_length: int | None = None, + return_attention_mask: bool = True, + return_tensors: str = "pt", + device: str | torch.device | None = None, + ) -> BatchFeature: + """ + Extract log-mel spectrogram features from audio. + + Args: + raw_speech: Audio waveform(s), can be list of arrays or batched + sampling_rate: Expected sample rate (must match self.sampling_rate) + padding: Padding strategy ('max_length' or 'longest') + max_length: Max samples (default: self.n_samples = 30s * 16kHz) + return_attention_mask: Whether to return attention mask + return_tensors: Output format ('pt' for PyTorch) + device: Device for computation (default: self.device) + + Returns: + BatchFeature with 'input_features' and optionally 'attention_mask' + """ + if sampling_rate is not None and sampling_rate != self.sampling_rate: + raise ValueError( + f"Expected sampling_rate={self.sampling_rate}, got {sampling_rate}" + ) + + device = torch.device(device) if device else self.device + max_length = max_length or self.n_samples + + # Convert inputs to list of 1D tensors + if isinstance(raw_speech, np.ndarray): + raw_speech = [raw_speech] if raw_speech.ndim == 1 else list(raw_speech) + elif isinstance(raw_speech, torch.Tensor): + raw_speech = ( + [raw_speech.numpy()] + if raw_speech.ndim == 1 + else [s.numpy() for s in raw_speech] + ) + + batch_size = len(raw_speech) + + # Get actual lengths before padding + lengths = [len(s) for s in raw_speech] + + # Pad/truncate to max_length + if padding == "max_length": + target_length = max_length + else: # 'longest' + target_length = min(max(lengths), max_length) + + # Create padded batch tensor + padded_waveforms = torch.zeros( + batch_size, target_length, dtype=torch.float32, device=device + ) + attention_mask = torch.zeros( + batch_size, target_length, dtype=torch.int32, device=device ) - return audio_output_lengths - return _get_audio_output_lengths_from_lengths( - audio_lengths, merge_factor, conv_params - ) + for i, waveform in enumerate(raw_speech): + if isinstance(waveform, np.ndarray): + waveform = torch.from_numpy(waveform) + waveform = waveform.to(device=device, dtype=torch.float32) -def _flatten_audio_features_by_length( - audio_features: torch.Tensor, - audio_output_lengths: torch.Tensor, -) -> torch.Tensor: - num_chunks, max_audio_tokens, embed_dim = audio_features.shape - audio_output_lengths = audio_output_lengths.unsqueeze(1) - audio_features_mask = ( - torch.arange(max_audio_tokens) - .expand(num_chunks, max_audio_tokens) - .to(audio_output_lengths.device) - < audio_output_lengths - ) - return audio_features[audio_features_mask].view(-1, embed_dim) - - -def _group_audio_embeddings( - chunk_embeddings: Sequence[torch.Tensor], - chunk_counts: Sequence[int], -) -> tuple[torch.Tensor, ...]: - grouped_embeddings = [] - current_idx = 0 - for count in chunk_counts: - audio_chunks = chunk_embeddings[current_idx : current_idx + count] - grouped_embeddings.append(torch.cat(audio_chunks, dim=0)) - current_idx += count - return tuple(grouped_embeddings) - - -def _normalize_to_tensor(mask: torch.Tensor | list[torch.Tensor]) -> torch.Tensor: - """Convert mask to tensor, handling both list and tensor formats.""" - if isinstance(mask, list): - return ( - torch.stack(mask) - if mask and isinstance(mask[0], torch.Tensor) - else torch.tensor(mask) - ) - return mask - - -def _extract_mask_for_item( - feature_attention_mask: torch.Tensor | list[torch.Tensor], - chunk_counts: torch.Tensor | list[int] | None, - item_idx: int, -) -> torch.Tensor: - """Extract attention mask for a specific audio item.""" - if chunk_counts is None: - # Single item per audio - mask = feature_attention_mask[item_idx] - if isinstance(feature_attention_mask, torch.Tensor): - return mask.unsqueeze(0) - return _normalize_to_tensor(mask) - - # Multiple chunks per audio: calculate slice indices - counts = _as_list_chunk_counts(chunk_counts) - start_idx = sum(counts[:item_idx]) - end_idx = start_idx + counts[item_idx] - - # Extract slice - if isinstance(feature_attention_mask, torch.Tensor): - return feature_attention_mask[start_idx:end_idx] - mask_slice = feature_attention_mask[start_idx:end_idx] - return _normalize_to_tensor(mask_slice) - - -def _get_num_features_for_item( - feature_attention_mask: torch.Tensor | None, - chunk_counts: torch.Tensor | list[int] | None, - item_idx: int, - audio_embeds: list[torch.Tensor] | None, - merge_factor: int, - conv_params: list[tuple[int, int, int]], -) -> int: - """Get number of features for a specific audio item.""" - if feature_attention_mask is not None: - mask = _extract_mask_for_item(feature_attention_mask, chunk_counts, item_idx) - audio_output_lengths = _get_audio_output_lengths_from_mask( - mask, merge_factor, conv_params - ) - return audio_output_lengths.sum().item() - if audio_embeds is not None: - return audio_embeds[item_idx].shape[0] - raise ValueError("Either feature_attention_mask or audio_embeds must be provided") + # Truncate if needed + actual_len = min(len(waveform), target_length) + padded_waveforms[i, :actual_len] = waveform[:actual_len] + attention_mask[i, :actual_len] = 1 + + # Extract features on GPU + input_features = self._extract_fbank_features(padded_waveforms) + + # Rescale attention mask from samples to frames + # STFT produces L//hop_length + 1 frames, but we drop the last one + frame_attention_mask = attention_mask[:, :: self.hop_length] + # Trim to match actual frame count (we drop last frame in _extract) + if attention_mask.shape[1] % self.hop_length != 0: + frame_attention_mask = frame_attention_mask[:, :-1] + + result: dict[str, Any] = {"input_features": input_features} + if return_attention_mask: + result["attention_mask"] = frame_attention_mask + + return BatchFeature(data=result, tensor_type=return_tensors) + + def _extract_fbank_features(self, waveforms: torch.Tensor) -> torch.Tensor: + """ + Compute log-mel spectrogram for batched waveforms. + + Args: + waveforms: [batch, samples] float32 tensor on target device + + Returns: + [batch, n_mels, frames] float32 tensor (log-mel spectrogram) + """ + device = waveforms.device + self._ensure_buffers(device) + + # STFT: [batch, samples] -> [batch, n_fft//2+1, frames] complex + stft = torch.stft( + waveforms, + n_fft=self.n_fft, + hop_length=self.hop_length, + window=self._window, + return_complex=True, + ) + + # Power spectrogram, drop last frame (matching HF implementation) + magnitudes = stft[..., :-1].abs() ** 2 # [batch, n_freqs, frames] + + # Apply mel filterbank: [n_mels, n_freqs] @ [batch, n_freqs, frames] + # -> [batch, n_mels, frames] + mel_spec = torch.matmul(self._mel_filters, magnitudes) + + # Log scale with floor + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + + # Per-sample normalization (max - 8.0 floor, then scale) + max_val = log_spec.amax(dim=(1, 2), keepdim=True) + log_spec = torch.maximum(log_spec, max_val - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + + return log_spec + + +class GlmAsrFeatureInputs(TensorSchema): + """ + Dimensions: + - num_chunks: Number of audio chunks (flattened) + - nmb: Number of mel bins + - num_audios: Number of original audio files + """ + + type: Literal["audio_features"] + input_features: Annotated[ + torch.Tensor | list[torch.Tensor], + TensorShape("num_chunks", "nmb", "chunk_length", dynamic_dims={"chunk_length"}), + ] + feature_attention_mask: Annotated[ + torch.Tensor | list[torch.Tensor], + TensorShape("num_chunks", "chunk_length", dynamic_dims={"chunk_length"}), + ] + chunk_counts: Annotated[ + torch.Tensor | list[torch.Tensor], + TensorShape("num_audios"), + ] + + +class GlmAsrEmbeddingInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size + - naf: Number of audio features + - hs: Hidden size (must match the hidden size of language model + backbone) + """ + + type: Literal["audio_embeds"] = "audio_embeds" + audio_embeds: Annotated[ + list[torch.Tensor], + TensorShape("bn", "naf", "hs", dynamic_dims={"naf"}), + ] + + +GlmAsrInputs: TypeAlias = GlmAsrFeatureInputs | GlmAsrEmbeddingInputs + + +class GlmAsrMultiModalProjector(nn.Module): + def __init__( + self, + config: GlmAsrConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__() + self.linear_1 = ColumnParallelLinear( + input_size=config.audio_config.intermediate_size, + output_size=config.text_config.hidden_size * 2, + quant_config=quant_config, + prefix=f"{prefix}.linear_1", + ) + self.act = get_act_fn(config.projector_hidden_act) + self.linear_2 = RowParallelLinear( + input_size=config.text_config.hidden_size * 2, + output_size=config.text_config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.linear_2", + ) + + def forward(self, audio_features: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.linear_1(audio_features) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.linear_2(hidden_states) + return hidden_states + + +class GlmAsrProcessingInfo(AudioFlamingo3ProcessingInfo): + def get_hf_config(self) -> GlmAsrConfig: + return self.ctx.get_hf_config(GlmAsrConfig) + + def get_hf_processor(self, **kwargs: object) -> GlmAsrProcessor: + return self.ctx.get_hf_processor(GlmAsrProcessor, **kwargs) + + def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor: + # Reuse parent implementation, but add type annotation and assertion + feature_extractor = super().get_feature_extractor(**kwargs) + assert isinstance(feature_extractor, WhisperFeatureExtractor) + return feature_extractor + + +class GlmAsrDummyInputsBuilder(BaseDummyInputsBuilder[GlmAsrProcessingInfo]): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_audios = mm_counts.get("audio", 0) + hf_processor = self.info.get_hf_processor() + return hf_processor.audio_token * num_audios + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions] | None = None, + ) -> MultiModalDataDict: + feature_extractor = self.info.get_feature_extractor() + sampling_rate = feature_extractor.sampling_rate + num_audios = mm_counts.get("audio", 0) + audio_overrides = mm_options.get("audio") if mm_options else None + + max_audio_len = getattr( + self.info.get_hf_processor(), "max_audio_len", DEFAULT_MAX_AUDIO_LEN_S + ) + audio_len = int(max_audio_len * sampling_rate) + + return { + "audio": self._get_dummy_audios( + length=audio_len, num_audios=num_audios, overrides=audio_overrides + ) + } + + +class GlmAsrMultiModalDataParser(AudioFlamingo3MultiModalDataParser): + def _parse_audio_data( + self, + data: dict[str, torch.Tensor] | ModalityData[Any], + ) -> ModalityDataItems[Any, Any] | None: + if isinstance(data, dict): + return DictEmbeddingItems( + data, + modality="audio", + required_fields={"audio_embeds"}, + fields_factory=_glmasr_field_config, + ) + return super()._parse_audio_data(data) + + +class GlmAsrMultiModalProcessor(BaseMultiModalProcessor["GlmAsrProcessingInfo"]): + """ + GLM-ASR processor that inherits directly from BaseMultiModalProcessor + for better performance and cleaner implementation. + Uses GPU-accelerated feature extraction for improved throughput. + """ + + # Shared GPU feature extractor instance (lazy initialized) + _gpu_feature_extractor: GPUWhisperFeatureExtractor | None = None + + @classmethod + def _get_gpu_feature_extractor( + cls, + hf_feature_extractor: WhisperFeatureExtractor, + device: str = "cuda", + ) -> GPUWhisperFeatureExtractor: + """Get or create GPU feature extractor matching HF config.""" + if cls._gpu_feature_extractor is None: + cls._gpu_feature_extractor = GPUWhisperFeatureExtractor( + feature_size=hf_feature_extractor.feature_size, + sampling_rate=hf_feature_extractor.sampling_rate, + hop_length=hf_feature_extractor.hop_length, + chunk_length=hf_feature_extractor.chunk_length, + n_fft=hf_feature_extractor.n_fft, + padding_value=hf_feature_extractor.padding_value, + device=device, + ) + return cls._gpu_feature_extractor + + def _get_data_parser(self) -> MultiModalDataParser: + feature_extractor = self.info.get_feature_extractor() + return GlmAsrMultiModalDataParser(target_sr=feature_extractor.sampling_rate) + + def _calculate_chunk_counts( + self, + audio_list: list[Any], + feature_extractor: WhisperFeatureExtractor, + processor: GlmAsrProcessor, + ) -> list[int]: + """Calculate chunk counts for each audio.""" + sampling_rate = feature_extractor.sampling_rate + chunk_length = feature_extractor.chunk_length + max_audio_len = getattr(processor, "max_audio_len", DEFAULT_MAX_AUDIO_LEN_S) + window_size = int(sampling_rate * chunk_length) + max_windows = int(max_audio_len // chunk_length) + + chunk_counts = [] + for audio in audio_list: + n_samples = len(audio) if isinstance(audio, list) else audio.shape[0] + n_chunks = max(1, (n_samples + window_size - 1) // window_size) + chunk_counts.append(min(n_chunks, max_windows)) + return chunk_counts + + # @torch.compile(fullgraph=True) + def _call_hf_processor( + self, + prompt: str, + mm_data: dict[str, object], + mm_kwargs: Mapping[str, Any], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + """ + Call processor with GPU-accelerated feature extraction. + """ + # Normalize input: handle deprecated key and list conversion. + if "audios" in mm_data: + mm_data["audio"] = mm_data.pop("audios") + + audio = mm_data.get("audio", []) + audio_list = [audio] if audio and not isinstance(audio, list) else audio + + # Early return for text-only. + if not audio_list: + prompt_ids = self.info.get_tokenizer().encode(prompt) + prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + # Get processor for tokenizer and config + processor = self.info.get_hf_processor(**mm_kwargs) + hf_feature_extractor = processor.feature_extractor + tokenizer = processor.tokenizer + + # ===== Audio chunking (CPU, fast) ===== + sampling_rate = hf_feature_extractor.sampling_rate + chunk_length = hf_feature_extractor.chunk_length + max_audio_len = getattr(processor, "max_audio_len", DEFAULT_MAX_AUDIO_LEN_S) + window_size = int(sampling_rate * chunk_length) + max_windows = int(max_audio_len // chunk_length) + + per_sample_windows: list[int] = [] + flat_chunks: list[np.ndarray] = [] + + for audio_el in audio_list: + # Convert to numpy if needed + if isinstance(audio_el, torch.Tensor): + audio_el = audio_el.numpy() + elif isinstance(audio_el, list): + audio_el = np.array(audio_el, dtype=np.float32) + + n_samples = int(audio_el.shape[0]) + n_win = max(1, (n_samples + window_size - 1) // window_size) + if n_win > max_windows: + n_win = max_windows + + per_sample_windows.append(n_win) + time_cap = min(n_samples, n_win * window_size) + + for i in range(n_win): + start = i * window_size + end = min((i + 1) * window_size, time_cap) + flat_chunks.append(audio_el[start:end]) + + # ===== GPU Feature Extraction ===== + # Check if CUDA is available, fallback to CPU if not + use_gpu = torch.cuda.is_available() + device = "cuda" if use_gpu else "cpu" + + if use_gpu: + # Use GPU-accelerated feature extractor + gpu_extractor = self._get_gpu_feature_extractor( + hf_feature_extractor, device=device + ) + audio_inputs = gpu_extractor( + flat_chunks, + sampling_rate=sampling_rate, + return_attention_mask=True, + return_tensors="pt", + ) + else: + # Fallback to HF CPU implementation + audio_inputs = hf_feature_extractor( + flat_chunks, + sampling_rate=sampling_rate, + return_tensors="pt", + padding=True, + return_attention_mask=True, + ) + + # ===== Process attention mask ===== + padding_mask = audio_inputs.pop("attention_mask") + input_features_mask = padding_mask + + # ===== Compute audio token lengths ===== + chunk_lengths = padding_mask.sum(-1) # [num_chunks] + audio_lengths = torch.stack( + [ + chunk_lengths[ + sum(per_sample_windows[:i]) : sum(per_sample_windows[: i + 1]) + ].sum() + for i in range(len(per_sample_windows)) + ] + ) + + # Apply convolution formula to get token counts + merge_factor = 4 + for padding, kernel_size, stride in [(1, 3, 1), (1, 3, 2)]: + audio_lengths = ( + audio_lengths + 2 * padding - (kernel_size - 1) - 1 + ) // stride + 1 + audio_tokens_lengths = (audio_lengths - merge_factor) // merge_factor + 1 + + # ===== Expand audio tokens in text ===== + import regex as re + + audio_token = getattr(processor, "audio_token", "<|pad|>") + text_list = [prompt] + + for i, audio_length in enumerate(audio_tokens_lengths): + if i < len(text_list): + expanded = re.sub( + re.escape(audio_token), + audio_token * int(audio_length), + text_list[i], + ) + text_list[i] = expanded + + # ===== Tokenize text ===== + text_inputs = tokenizer( + text_list, + return_tensors="pt", + padding=True, + **tok_kwargs, + ) + + # ===== Combine outputs ===== + # Move input_features to CPU for compatibility + input_features = audio_inputs["input_features"] + if input_features.device.type != "cpu": + input_features = input_features.cpu() + if input_features_mask.device.type != "cpu": + input_features_mask = input_features_mask.cpu() + + outputs = BatchFeature( + data={ + **text_inputs, + "input_features": input_features, + "feature_attention_mask": input_features_mask, + }, + tensor_type="pt", + ) + + outputs["chunk_counts"] = torch.tensor(per_sample_windows, dtype=torch.long) + + return outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return _glmasr_field_config(hf_inputs) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + config = self.info.get_hf_config() + + audio_token = getattr(processor, "audio_token", "<|pad|>") + audio_token_id = vocab.get(audio_token) + if audio_token_id is None: + audio_token_id = processor.audio_token_id + + merge_factor = getattr(config, "merge_factor", DEFAULT_MERGE_FACTOR) + out_mm_data = out_mm_kwargs.get_data() + feature_attention_mask = out_mm_data.get("feature_attention_mask") + chunk_counts = out_mm_data.get("chunk_counts") + + def get_replacement_glmasr(item_idx: int): + conv_params = getattr(config, "conv_params", DEFAULT_CONV_PARAMS) + audio_embeds = out_mm_data.get("audio_embeds") + num_features = _get_num_features_for_item( + feature_attention_mask, + chunk_counts, + item_idx, + audio_embeds, + merge_factor, + conv_params, + ) + + if num_features == 0: + raise ValueError("Audio is too short") + + audio_tokens = [audio_token_id] * int(num_features) + return PromptUpdateDetails.select_token_id( + audio_tokens, + embed_token_id=audio_token_id, + ) + + return [ + PromptReplacement( + modality="audio", + target=audio_token, + replacement=get_replacement_glmasr, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor( + GlmAsrMultiModalProcessor, + info=GlmAsrProcessingInfo, + dummy_inputs=GlmAsrDummyInputsBuilder, +) +class GlmAsrForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsTranscription +): + supported_languages = ISO639_1_SUPPORTED_LANGS + + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.multimodal_config = multimodal_config + + # Use optimized vLLM native encoder + self.audio_tower = GlmAsrEncoder( + config.audio_config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "audio_tower"), + ) + self.multi_modal_projector = GlmAsrMultiModalProjector( + config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "multi_modal_projector"), + ) + self.quant_config = quant_config + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + architectures=["LlamaForCausalLM"], + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> str | None: + if modality.startswith("audio"): + return "<|begin_of_audio|><|pad|><|end_of_audio|>" + + raise ValueError("Only audio modality is supported") + + def get_mm_mapping(self) -> MultiModelKeys: + return MultiModelKeys.from_string_field( + language_model="language_model.", + connector="multi_modal_projector.", + tower_model="audio_tower.", + ) + + def _parse_and_validate_audio_input(self, **kwargs: object) -> GlmAsrInputs | None: + audio_embeds = kwargs.pop("audio_embeds", None) + if audio_embeds is not None: + return GlmAsrEmbeddingInputs(type="audio_embeds", audio_embeds=audio_embeds) + + input_features = kwargs.pop("input_features", None) + if input_features is None: + return None + + return GlmAsrFeatureInputs( + type="audio_features", + input_features=input_features, + feature_attention_mask=kwargs.pop("feature_attention_mask", None), + chunk_counts=kwargs.pop("chunk_counts", None), + ) + + def _process_audio_input( + self, audio_input: GlmAsrInputs + ) -> torch.Tensor | tuple[torch.Tensor, ...]: + if audio_input["type"] == "audio_embeds": + return tuple(audio_input["audio_embeds"]) + + input_features = audio_input["input_features"] + feature_attention_mask = audio_input["feature_attention_mask"] + + if isinstance(input_features, list): + input_features = torch.cat(input_features, dim=0) + feature_attention_mask = torch.cat(feature_attention_mask, dim=0) + + num_chunks = input_features.shape[0] + chunk_counts = _normalize_chunk_counts( + audio_input.get("chunk_counts"), num_chunks=num_chunks + ) + + # Convert input_features to model dtype (e.g., bfloat16) to match model weights + input_features = input_features.to(dtype=self.audio_tower.conv1.weight.dtype) + + # audio_tower returns [batch_size, seq_len, hidden_size] where hidden_size=1280 + audio_hidden_states = self.audio_tower(input_features).last_hidden_state + + # GLM-ASR merges consecutive frames: 4 frames with hidden_size=1280 + # -> 1 frame with intermediate_size=5120 + hidden_size = self.config.audio_config.hidden_size + intermediate_size = self.config.audio_config.intermediate_size + merge_ratio = intermediate_size // hidden_size + + # Truncate sequence length to be divisible by merge_ratio + seq_len = audio_hidden_states.shape[1] + seq_len_truncated = (seq_len // merge_ratio) * merge_ratio + if seq_len_truncated < seq_len: + audio_hidden_states = audio_hidden_states[:, :seq_len_truncated, :] + + # Reshape to merge consecutive frames + audio_hidden_states = audio_hidden_states.reshape( + num_chunks, + -1, + intermediate_size, + ) + + audio_features = self.multi_modal_projector(audio_hidden_states) + + merge_factor = getattr(self.config, "merge_factor", DEFAULT_MERGE_FACTOR) + conv_params = getattr(self.config, "conv_params", DEFAULT_CONV_PARAMS) + + audio_output_lengths = _get_audio_output_lengths_for_tower( + self.audio_tower, + feature_attention_mask.sum(-1), + merge_factor, + conv_params, + ) + + masked_audio_features = _flatten_audio_features_by_length( + audio_features, audio_output_lengths + ) + + chunk_embeddings = torch.split( + masked_audio_features, audio_output_lengths.flatten().tolist() + ) + result = _group_audio_embeddings(chunk_embeddings, chunk_counts) + + return result + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: + audio_input = self._parse_and_validate_audio_input(**kwargs) + if audio_input is None: + return [] + + masked_audio_features = self._process_audio_input(audio_input) + + return masked_audio_features + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ) -> torch.Tensor | IntermediateTensors: + if intermediate_tensors is not None: + inputs_embeds = None + + hidden_states = self.language_model.model( + input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + return self.language_model.compute_logits(hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + skip_prefixes = ["audio_tower.embed_positions"] + loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) + return loader.load_weights(weights) + + @classmethod + def _get_audio_token(cls, model_config: ModelConfig) -> str: + """Get the audio token from processor. + + Similar to get_placeholder_str but returns single token. + """ + processor = cached_processor_from_config(model_config) + return getattr(processor, "audio_token", "<|pad|>") + + @classmethod + def get_speech_to_text_config( + cls, model_config: ModelConfig, task_type: str + ) -> SpeechToTextConfig: + processor = cached_processor_from_config(model_config) + feature_extractor = processor.feature_extractor + max_audio_clip_s = getattr(processor, "max_audio_len", DEFAULT_MAX_AUDIO_LEN_S) + return SpeechToTextConfig( + max_audio_clip_s=max_audio_clip_s, + sample_rate=feature_extractor.sampling_rate, + ) + + @classmethod + def get_generation_prompt( + cls, + audio: np.ndarray, + model_config: ModelConfig, + stt_config: SpeechToTextConfig, + language: str | None, + task_type: Literal["transcribe", "translate"], + request_prompt: str, + to_language: str | None, + ) -> PromptType: + """Get the generation prompt to be used for transcription requests.""" + tokenizer = cached_tokenizer_from_config(model_config) + audio_token = cls._get_audio_token(model_config) + + if task_type == "translate": + full_lang_name_to = cls.supported_languages.get(to_language, to_language) + user_content = f"{audio_token}translate the speech to {full_lang_name_to}" + elif task_type == "transcribe": + user_content = ( + f"{audio_token}can you transcribe the speech into a written format?" + ) + else: + raise ValueError(f"Unsupported task type {task_type}") + + messages = [{"role": "user", "content": user_content}] + prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + prompt_token_ids = tokenizer.encode(prompt) + prompt_dict = { + "prompt_token_ids": prompt_token_ids, + "multi_modal_data": {"audio": audio}, + } + return cast(PromptType, prompt_dict) From fb1048e37109e033318977b206280dee1a47fa40 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Mon, 5 Jan 2026 15:50:47 +0800 Subject: [PATCH 02/24] fix glmasr_utils Signed-off-by: JaredforReal --- vllm/model_executor/models/glmasr_utils.py | 1440 +++++++++----------- 1 file changed, 632 insertions(+), 808 deletions(-) diff --git a/vllm/model_executor/models/glmasr_utils.py b/vllm/model_executor/models/glmasr_utils.py index 681603bec482..24d65ae54aec 100644 --- a/vllm/model_executor/models/glmasr_utils.py +++ b/vllm/model_executor/models/glmasr_utils.py @@ -2,946 +2,770 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import logging -from collections.abc import Iterable, Mapping, Sequence -from functools import lru_cache -from typing import Annotated, Any, Literal, TypeAlias, cast +from collections.abc import Iterable, Sequence +from typing import cast -import numpy as np import torch import torch.nn as nn -from transformers import BatchFeature -from transformers.models.glmasr import GlmAsrConfig, GlmAsrProcessor -from transformers.models.whisper import WhisperFeatureExtractor -from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig -from vllm.config.multimodal import BaseDummyOptions -from vllm.inputs.data import PromptType +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import ( ColumnParallelLinear, + QKVParallelLinear, RowParallelLinear, ) -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import ( - MultiModalDataDict, - MultiModalFieldConfig, - MultiModalKwargsItems, +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, ) -from vllm.multimodal.parse import ( - DictEmbeddingItems, - ModalityData, - ModalityDataItems, - MultiModalDataItems, - MultiModalDataParser, -) -from vllm.multimodal.processing import ( - BaseMultiModalProcessor, - PromptReplacement, - PromptUpdate, - PromptUpdateDetails, -) -from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.sequence import IntermediateTensors -from vllm.tokenizers import cached_tokenizer_from_config -from vllm.transformers_utils.processor import cached_processor_from_config -from vllm.utils.tensor_schema import TensorSchema, TensorShape - -from .audioflamingo3 import ( - AudioFlamingo3MultiModalDataParser, - AudioFlamingo3ProcessingInfo, -) -from .audioflamingo3 import ( - _audioflamingo3_field_config as _glmasr_field_config, -) -from .glmasr_utils import ( - DEFAULT_CONV_PARAMS, - DEFAULT_MAX_AUDIO_LEN_S, - DEFAULT_MERGE_FACTOR, - GlmAsrEncoder, - _flatten_audio_features_by_length, - _get_audio_output_lengths_for_tower, - _get_num_features_for_item, - _group_audio_embeddings, - _normalize_chunk_counts, -) -from .interfaces import ( - MultiModalEmbeddings, - SupportsLoRA, - SupportsMultiModal, - SupportsPP, - SupportsTranscription, -) -from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix -from .whisper import ISO639_1_SUPPORTED_LANGS logger = logging.getLogger(__name__) +DEFAULT_MAX_AUDIO_LEN_S = 655 +DEFAULT_MERGE_FACTOR = 4 +# Default convolution parameters: (padding, kernel_size, stride) +# These correspond to the two conv layers in GlmAsrEncoder +DEFAULT_CONV_PARAMS = [(1, 3, 1), (1, 3, 2)] + + +class _GlmAsrEncoderOutput: + """Simple output container compatible with transformers' BaseModelOutput.""" + + __slots__ = ("last_hidden_state",) + + def __init__(self, last_hidden_state: torch.Tensor): + self.last_hidden_state = last_hidden_state + + +def _calculate_conv_output_length( + input_length: torch.Tensor, padding: int, kernel_size: int, stride: int +) -> torch.Tensor: + """Calculate Conv1d output length using standard formula.""" + # Standard formula: floor((input + 2*padding - kernel_size) / stride) + 1 + return (input_length + 2 * padding - kernel_size) // stride + 1 + + +def _as_list_chunk_counts( + chunk_counts: torch.Tensor | list[int] | list[torch.Tensor], +) -> list[int]: + if isinstance(chunk_counts, torch.Tensor): + return chunk_counts.tolist() + if chunk_counts and isinstance(chunk_counts[0], torch.Tensor): + tensor_counts = cast(list[torch.Tensor], chunk_counts) + return [int(c.item()) for c in tensor_counts] + return [int(c) for c in chunk_counts] + + +def _normalize_chunk_counts( + chunk_counts: torch.Tensor | list[int] | list[torch.Tensor] | None, + num_chunks: int, +) -> list[int]: + if chunk_counts is None: + return [1] * num_chunks + return _as_list_chunk_counts(chunk_counts) + + +def _get_audio_output_lengths_from_lengths( + audio_lengths: torch.Tensor, + merge_factor: int, + conv_params: list[tuple[int, int, int]], +) -> torch.Tensor: + for padding, kernel_size, stride in conv_params: + audio_lengths = _calculate_conv_output_length( + audio_lengths, padding, kernel_size, stride + ) + return (audio_lengths - merge_factor) // merge_factor + 1 -# ============================================================================= -# GPU-accelerated Whisper Feature Extractor -# ============================================================================= + +def _get_audio_output_lengths_from_mask( + mask: torch.Tensor, + merge_factor: int, + conv_params: list[tuple[int, int, int]], +) -> torch.Tensor: + audio_lengths = mask.sum(-1) + return _get_audio_output_lengths_from_lengths( + audio_lengths, merge_factor, conv_params + ) -@lru_cache(maxsize=1) -def _get_mel_filters( - n_fft: int = 400, - n_mels: int = 80, - sampling_rate: int = 16000, - device: torch.device | None = None, +def _get_audio_output_lengths_for_tower( + audio_tower: nn.Module, + audio_lengths: torch.Tensor, + merge_factor: int, + conv_params: list[tuple[int, int, int]], ) -> torch.Tensor: """ - Compute mel filterbank matrix (cached). - Matches WhisperFeatureExtractor's mel_filter_bank with slaney norm/scale. + Calculate the output lengths after audio processing. + + The output length accounts for: + 1. Convolution layers (downsampling) + 2. Merge factor (further downsampling during projection) + + Args: + audio_tower: The audio encoder module + audio_lengths: Input feature lengths [batch_size] + merge_factor: Factor for merging adjacent features + conv_params: List of (padding, kernel_size, stride) for each conv layer + + Returns: + Output lengths after all processing [batch_size] """ - if device is None: - device = torch.device("cpu") - # Frequency bins - n_freqs = n_fft // 2 + 1 - all_freqs = torch.linspace(0, sampling_rate // 2, n_freqs, device=device) - - # Mel scale conversion (slaney) - min_mel = 0.0 - max_mel = 2595.0 * np.log10(1.0 + (sampling_rate / 2) / 700.0) - mels = torch.linspace(min_mel, max_mel, n_mels + 2, device=device) - mel_freqs = 700.0 * (10.0 ** (mels / 2595.0) - 1.0) - - # Create filterbank - mel_filters = torch.zeros(n_mels, n_freqs, device=device) - for i in range(n_mels): - lower = mel_freqs[i] - center = mel_freqs[i + 1] - upper = mel_freqs[i + 2] - - # Lower slope - lower_slope = (all_freqs - lower) / (center - lower + 1e-10) - # Upper slope - upper_slope = (upper - all_freqs) / (upper - center + 1e-10) - - mel_filters[i] = torch.maximum( - torch.zeros_like(all_freqs), - torch.minimum(lower_slope, upper_slope), + # First, calculate the output length after convolutions + if hasattr(audio_tower, "_get_feat_extract_output_lengths"): + _, conv_output_lengths = audio_tower._get_feat_extract_output_lengths( + audio_lengths ) + else: + conv_output_lengths = audio_lengths + for padding, kernel_size, stride in conv_params: + conv_output_lengths = _calculate_conv_output_length( + conv_output_lengths, padding, kernel_size, stride + ) - # Slaney normalization - enorm = 2.0 / (mel_freqs[2 : n_mels + 2] - mel_freqs[:n_mels]) - mel_filters *= enorm.unsqueeze(1) + # Then, apply merge_factor to get final output length + # Formula: (conv_output_lengths - merge_factor) // merge_factor + 1 + return (conv_output_lengths - merge_factor) // merge_factor + 1 - return mel_filters + +def _flatten_audio_features_by_length( + audio_features: torch.Tensor, + audio_output_lengths: torch.Tensor, +) -> torch.Tensor: + num_chunks, max_audio_tokens, embed_dim = audio_features.shape + audio_output_lengths = audio_output_lengths.unsqueeze(1) + audio_features_mask = ( + torch.arange(max_audio_tokens) + .expand(num_chunks, max_audio_tokens) + .to(audio_output_lengths.device) + < audio_output_lengths + ) + return audio_features[audio_features_mask].view(-1, embed_dim) + + +def _group_audio_embeddings( + chunk_embeddings: Sequence[torch.Tensor], + chunk_counts: Sequence[int], +) -> tuple[torch.Tensor, ...]: + grouped_embeddings = [] + current_idx = 0 + for count in chunk_counts: + audio_chunks = chunk_embeddings[current_idx : current_idx + count] + grouped_embeddings.append(torch.cat(audio_chunks, dim=0)) + current_idx += count + return tuple(grouped_embeddings) + + +def _normalize_to_tensor(mask: torch.Tensor | list[torch.Tensor]) -> torch.Tensor: + """Convert mask to tensor, handling both list and tensor formats.""" + if isinstance(mask, list): + return ( + torch.stack(mask) + if mask and isinstance(mask[0], torch.Tensor) + else torch.tensor(mask) + ) + return mask -class GPUWhisperFeatureExtractor: +def _extract_mask_for_item( + feature_attention_mask: torch.Tensor | list[torch.Tensor], + chunk_counts: torch.Tensor | list[int] | None, + item_idx: int, +) -> torch.Tensor: + """Extract attention mask for a specific audio item.""" + if chunk_counts is None: + # Single item per audio + mask = feature_attention_mask[item_idx] + if isinstance(feature_attention_mask, torch.Tensor): + return mask.unsqueeze(0) + return _normalize_to_tensor(mask) + + # Multiple chunks per audio: calculate slice indices + counts = _as_list_chunk_counts(chunk_counts) + start_idx = sum(counts[:item_idx]) + end_idx = start_idx + counts[item_idx] + + # Extract slice + if isinstance(feature_attention_mask, torch.Tensor): + return feature_attention_mask[start_idx:end_idx] + mask_slice = feature_attention_mask[start_idx:end_idx] + return _normalize_to_tensor(mask_slice) + + +def _get_num_features_for_item( + feature_attention_mask: torch.Tensor | None, + chunk_counts: torch.Tensor | list[int] | None, + item_idx: int, + audio_embeds: list[torch.Tensor] | None, + merge_factor: int, + conv_params: list[tuple[int, int, int]], +) -> int: + """Get number of features for a specific audio item.""" + if feature_attention_mask is not None: + mask = _extract_mask_for_item(feature_attention_mask, chunk_counts, item_idx) + audio_output_lengths = _get_audio_output_lengths_from_mask( + mask, merge_factor, conv_params + ) + return audio_output_lengths.sum().item() + if audio_embeds is not None: + return audio_embeds[item_idx].shape[0] + raise ValueError("Either feature_attention_mask or audio_embeds must be provided") + + +# ============================================================================ +# Optimized vLLM Native GlmAsrEncoder Implementation +# ============================================================================ + + +def _rotate_half(x: torch.Tensor) -> torch.Tensor: + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def _apply_rotary_pos_emb( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: """ - GPU-accelerated Whisper feature extractor using PyTorch. - Computes log-mel spectrogram matching WhisperFeatureExtractor output. - - Key parameters (Whisper defaults): - - n_fft: 400 (25ms window at 16kHz) - - hop_length: 160 (10ms hop at 16kHz) - - n_mels: 80 - - chunk_length: 30 seconds - - sampling_rate: 16000 + Apply rotary position embeddings to query and key tensors. + + Follows transformers' apply_rotary_pos_emb exactly. + Supports partial rotary where only the first rotary_dim of head_dim is rotated. + + Args: + q: [batch, num_heads, seq_len, head_dim] + k: [batch, num_kv_heads, seq_len, head_dim] + cos: [batch, seq_len, rotary_dim] + sin: [batch, seq_len, rotary_dim] """ + # unsqueeze_dim=1 to add head dimension: [batch, 1, seq_len, rotary_dim] + cos = cos.unsqueeze(1) + sin = sin.unsqueeze(1) - def __init__( - self, - feature_size: int = 80, - sampling_rate: int = 16000, - hop_length: int = 160, - chunk_length: int = 30, - n_fft: int = 400, - padding_value: float = 0.0, - device: str | torch.device = "cuda", - ): - self.feature_size = feature_size - self.sampling_rate = sampling_rate - self.hop_length = hop_length - self.chunk_length = chunk_length - self.n_fft = n_fft - self.padding_value = padding_value - self.device = torch.device(device) if isinstance(device, str) else device - - # Derived parameters - self.n_samples = chunk_length * sampling_rate # 480000 for 30s - self.nb_max_frames = self.n_samples // hop_length # 3000 frames - - # Pre-compute window and mel filters on device - self._window: torch.Tensor | None = None - self._mel_filters: torch.Tensor | None = None - - def _ensure_buffers(self, device: torch.device) -> None: - """Lazily initialize buffers on the target device.""" - if self._window is None or self._window.device != device: - self._window = torch.hann_window(self.n_fft, device=device) - - if self._mel_filters is None or self._mel_filters.device != device: - self._mel_filters = _get_mel_filters( - n_fft=self.n_fft, - n_mels=self.feature_size, - sampling_rate=self.sampling_rate, - device=device, - ) + # Get the rotary dimension from cos/sin + rotary_dim = cos.shape[-1] - def __call__( - self, - raw_speech: list[np.ndarray] | np.ndarray | torch.Tensor, - sampling_rate: int | None = None, - padding: str = "max_length", - max_length: int | None = None, - return_attention_mask: bool = True, - return_tensors: str = "pt", - device: str | torch.device | None = None, - ) -> BatchFeature: - """ - Extract log-mel spectrogram features from audio. + # Split into rotary and pass-through parts + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] - Args: - raw_speech: Audio waveform(s), can be list of arrays or batched - sampling_rate: Expected sample rate (must match self.sampling_rate) - padding: Padding strategy ('max_length' or 'longest') - max_length: Max samples (default: self.n_samples = 30s * 16kHz) - return_attention_mask: Whether to return attention mask - return_tensors: Output format ('pt' for PyTorch) - device: Device for computation (default: self.device) + # Apply rotary embeddings on the first half or full tensor + q_embed = (q_rot * cos) + (_rotate_half(q_rot) * sin) + k_embed = (k_rot * cos) + (_rotate_half(k_rot) * sin) - Returns: - BatchFeature with 'input_features' and optionally 'attention_mask' - """ - if sampling_rate is not None and sampling_rate != self.sampling_rate: - raise ValueError( - f"Expected sampling_rate={self.sampling_rate}, got {sampling_rate}" - ) + # Concatenate back to full shape + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) - device = torch.device(device) if device else self.device - max_length = max_length or self.n_samples - - # Convert inputs to list of 1D tensors - if isinstance(raw_speech, np.ndarray): - raw_speech = [raw_speech] if raw_speech.ndim == 1 else list(raw_speech) - elif isinstance(raw_speech, torch.Tensor): - raw_speech = ( - [raw_speech.numpy()] - if raw_speech.ndim == 1 - else [s.numpy() for s in raw_speech] - ) + return q_embed, k_embed - batch_size = len(raw_speech) - # Get actual lengths before padding - lengths = [len(s) for s in raw_speech] +class GlmAsrRotaryEmbedding(nn.Module): + """ + Rotary Position Embedding for GLM-ASR encoder. + + Optimized with pre-computed cos/sin cache for better performance. + Falls back to dynamic computation only when sequence length exceeds cache. + """ - # Pad/truncate to max_length - if padding == "max_length": - target_length = max_length - else: # 'longest' - target_length = min(max(lengths), max_length) + def __init__(self, config, device: torch.device | None = None): + super().__init__() + self.config = config + self.max_seq_len_cached = config.max_position_embeddings - # Create padded batch tensor - padded_waveforms = torch.zeros( - batch_size, target_length, dtype=torch.float32, device=device - ) - attention_mask = torch.zeros( - batch_size, target_length, dtype=torch.int32, device=device + # Compute inverse frequencies following transformers implementation + head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads ) - for i, waveform in enumerate(raw_speech): - if isinstance(waveform, np.ndarray): - waveform = torch.from_numpy(waveform) - waveform = waveform.to(device=device, dtype=torch.float32) + # Handle rope_parameters if present (for compatibility with transformers config) + if hasattr(config, "rope_parameters") and config.rope_parameters: + base = config.rope_parameters.get("rope_theta", 10000.0) + partial_rotary_factor = config.rope_parameters.get( + "partial_rotary_factor", 1.0 + ) + dim = int(head_dim * partial_rotary_factor) + self.attention_scaling = config.rope_parameters.get( + "attention_scaling", 1.0 + ) + else: + base = getattr(config, "rope_theta", 10000.0) + dim = head_dim + self.attention_scaling = 1.0 + + self.dim = dim + self.base = base + + # Compute the inverse frequencies exactly as transformers does + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, dim, 2, dtype=torch.int64).to( + device=device, dtype=torch.float + ) + / dim + ) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) - # Truncate if needed - actual_len = min(len(waveform), target_length) - padded_waveforms[i, :actual_len] = waveform[:actual_len] - attention_mask[i, :actual_len] = 1 + # Pre-compute cos/sin cache for efficiency + self._set_cos_sin_cache(self.max_seq_len_cached, device) - # Extract features on GPU - input_features = self._extract_fbank_features(padded_waveforms) + def _set_cos_sin_cache( + self, seq_len: int, device: torch.device | None = None + ) -> None: + """Pre-compute cos and sin cache for given sequence length.""" + self.max_seq_len_cached = seq_len - # Rescale attention mask from samples to frames - # STFT produces L//hop_length + 1 frames, but we drop the last one - frame_attention_mask = attention_mask[:, :: self.hop_length] - # Trim to match actual frame count (we drop last frame in _extract) - if attention_mask.shape[1] % self.hop_length != 0: - frame_attention_mask = frame_attention_mask[:, :-1] + # Create position indices + t = torch.arange(seq_len, device=device, dtype=torch.float32) + # Compute frequencies: [seq_len, dim/2] + freqs = torch.outer(t, self.inv_freq.to(device=device, dtype=torch.float32)) + # Double the frequencies: [seq_len, dim] + emb = torch.cat((freqs, freqs), dim=-1) - result: dict[str, Any] = {"input_features": input_features} - if return_attention_mask: - result["attention_mask"] = frame_attention_mask + # Compute and cache cos/sin + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling - return BatchFeature(data=result, tensor_type=return_tensors) + self.register_buffer("cos_cached", cos, persistent=False) + self.register_buffer("sin_cached", sin, persistent=False) - def _extract_fbank_features(self, waveforms: torch.Tensor) -> torch.Tensor: + def forward( + self, x: torch.Tensor, position_ids: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: """ - Compute log-mel spectrogram for batched waveforms. + Compute rotary embeddings with caching optimization. Args: - waveforms: [batch, samples] float32 tensor on target device + x: Input tensor [batch_size, seq_len, hidden_size] + position_ids: Position indices [batch_size, seq_len] Returns: - [batch, n_mels, frames] float32 tensor (log-mel spectrogram) + Tuple of (cos, sin) tensors with shape [batch_size, seq_len, rotary_dim] """ - device = waveforms.device - self._ensure_buffers(device) - - # STFT: [batch, samples] -> [batch, n_fft//2+1, frames] complex - stft = torch.stft( - waveforms, - n_fft=self.n_fft, - hop_length=self.hop_length, - window=self._window, - return_complex=True, - ) + seq_len = position_ids.shape[-1] - # Power spectrogram, drop last frame (matching HF implementation) - magnitudes = stft[..., :-1].abs() ** 2 # [batch, n_freqs, frames] + # Extend cache if needed + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len, device=x.device) - # Apply mel filterbank: [n_mels, n_freqs] @ [batch, n_freqs, frames] - # -> [batch, n_mels, frames] - mel_spec = torch.matmul(self._mel_filters, magnitudes) + # Use cached values - index with position_ids for correctness + # For encoder, position_ids is typically [0, 1, 2, ..., seq_len-1] + # so we can directly slice the cache + cos = self.cos_cached[:seq_len].unsqueeze(0) # [1, seq_len, dim] + sin = self.sin_cached[:seq_len].unsqueeze(0) # [1, seq_len, dim] - # Log scale with floor - log_spec = torch.clamp(mel_spec, min=1e-10).log10() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - # Per-sample normalization (max - 8.0 floor, then scale) - max_val = log_spec.amax(dim=(1, 2), keepdim=True) - log_spec = torch.maximum(log_spec, max_val - 8.0) - log_spec = (log_spec + 4.0) / 4.0 - return log_spec - - -class GlmAsrFeatureInputs(TensorSchema): - """ - Dimensions: - - num_chunks: Number of audio chunks (flattened) - - nmb: Number of mel bins - - num_audios: Number of original audio files +def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ + Repeat key/value tensors for Grouped Query Attention. - type: Literal["audio_features"] - input_features: Annotated[ - torch.Tensor | list[torch.Tensor], - TensorShape("num_chunks", "nmb", "chunk_length", dynamic_dims={"chunk_length"}), - ] - feature_attention_mask: Annotated[ - torch.Tensor | list[torch.Tensor], - TensorShape("num_chunks", "chunk_length", dynamic_dims={"chunk_length"}), - ] - chunk_counts: Annotated[ - torch.Tensor | list[torch.Tensor], - TensorShape("num_audios"), - ] - - -class GlmAsrEmbeddingInputs(TensorSchema): - """ - Dimensions: - - bn: Batch size - - naf: Number of audio features - - hs: Hidden size (must match the hidden size of language model - backbone) - """ + Args: + hidden_states: [batch, num_kv_heads, seq_len, head_dim] + n_rep: Number of repetitions - type: Literal["audio_embeds"] = "audio_embeds" - audio_embeds: Annotated[ - list[torch.Tensor], - TensorShape("bn", "naf", "hs", dynamic_dims={"naf"}), - ] + Returns: + [batch, num_kv_heads * n_rep, seq_len, head_dim] + """ + if n_rep == 1: + return hidden_states + batch, num_kv_heads, slen, head_dim = hidden_states.shape + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_kv_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim) -GlmAsrInputs: TypeAlias = GlmAsrFeatureInputs | GlmAsrEmbeddingInputs +class GlmAsrAttention(nn.Module): + """ + Optimized Multi-headed Grouped Query Attention for GLM-ASR. + Uses vLLM's QKVParallelLinear for better performance. + """ -class GlmAsrMultiModalProjector(nn.Module): def __init__( self, - config: GlmAsrConfig, + config, quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() - self.linear_1 = ColumnParallelLinear( - input_size=config.audio_config.intermediate_size, - output_size=config.text_config.hidden_size * 2, + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_kv_groups = self.num_heads // self.num_kv_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_rank = self.num_heads // self.tp_size + self.num_kv_heads_per_rank = max(1, self.num_kv_heads // self.tp_size) + + # Use QKVParallelLinear for fused QKV projection + # Note: GLM-ASR uses bias on Q and V, but not K + # For simplicity with QKVParallelLinear, we use bias=True for all + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.num_heads, + self.num_kv_heads, + bias=True, quant_config=quant_config, - prefix=f"{prefix}.linear_1", + prefix=f"{prefix}.qkv_proj", ) - self.act = get_act_fn(config.projector_hidden_act) - self.linear_2 = RowParallelLinear( - input_size=config.text_config.hidden_size * 2, - output_size=config.text_config.hidden_size, + + self.o_proj = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, quant_config=quant_config, - prefix=f"{prefix}.linear_2", + prefix=f"{prefix}.o_proj", ) - def forward(self, audio_features: torch.Tensor) -> torch.Tensor: - hidden_states, _ = self.linear_1(audio_features) - hidden_states = self.act(hidden_states) - hidden_states, _ = self.linear_2(hidden_states) - return hidden_states + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + """ + Args: + hidden_states: [batch_size, seq_len, hidden_size] + position_embeddings: Tuple of (cos, sin) for RoPE + Returns: + [batch_size, seq_len, hidden_size] + """ + batch_size, seq_len, _ = hidden_states.shape + + # QKV projection - fused for efficiency + qkv, _ = self.qkv_proj(hidden_states) + + # Split into q, k, v + q_size = self.num_heads_per_rank * self.head_dim + kv_size = self.num_kv_heads_per_rank * self.head_dim + q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1) + + # Reshape and transpose + # [batch, seq, num_heads * head_dim] -> [batch, num_heads, seq, head_dim] + q = q.view( + batch_size, seq_len, self.num_heads_per_rank, self.head_dim + ).transpose(1, 2) + k = k.view( + batch_size, seq_len, self.num_kv_heads_per_rank, self.head_dim + ).transpose(1, 2) + # v doesn't go through RoPE, so make it contiguous now for SDPA + v = ( + v.view(batch_size, seq_len, self.num_kv_heads_per_rank, self.head_dim) + .transpose(1, 2) + .contiguous() + ) -class GlmAsrProcessingInfo(AudioFlamingo3ProcessingInfo): - def get_hf_config(self) -> GlmAsrConfig: - return self.ctx.get_hf_config(GlmAsrConfig) + # Apply rotary position embeddings + cos, sin = position_embeddings + q, k = _apply_rotary_pos_emb(q, k, cos, sin) + + # Handle GQA: repeat k/v if needed + if self.num_kv_groups > 1: + k = _repeat_kv(k, self.num_kv_groups) + v = _repeat_kv(v, self.num_kv_groups) + + # Ensure contiguous for optimal SDPA/Flash Attention performance + # Non-contiguous tensors can cause fallback to slower implementations + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + + # Scaled dot-product attention (uses Flash Attention when available) + attn_output = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=False, + ) - def get_hf_processor(self, **kwargs: object) -> GlmAsrProcessor: - return self.ctx.get_hf_processor(GlmAsrProcessor, **kwargs) + # Reshape back + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, seq_len, -1) - def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor: - # Reuse parent implementation, but add type annotation and assertion - feature_extractor = super().get_feature_extractor(**kwargs) - assert isinstance(feature_extractor, WhisperFeatureExtractor) - return feature_extractor + # Output projection + output, _ = self.o_proj(attn_output) + return output -class GlmAsrDummyInputsBuilder(BaseDummyInputsBuilder[GlmAsrProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: - num_audios = mm_counts.get("audio", 0) - hf_processor = self.info.get_hf_processor() - return hf_processor.audio_token * num_audios +class GlmAsrMLP(nn.Module): + """ + Optimized MLP for GLM-ASR encoder. + Uses vLLM's parallel linear layers for better performance. + """ - def get_dummy_mm_data( + def __init__( self, - seq_len: int, - mm_counts: Mapping[str, int], - mm_options: Mapping[str, BaseDummyOptions] | None = None, - ) -> MultiModalDataDict: - feature_extractor = self.info.get_feature_extractor() - sampling_rate = feature_extractor.sampling_rate - num_audios = mm_counts.get("audio", 0) - audio_overrides = mm_options.get("audio") if mm_options else None - - max_audio_len = getattr( - self.info.get_hf_processor(), "max_audio_len", DEFAULT_MAX_AUDIO_LEN_S + config, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + self.fc1 = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc1", ) - audio_len = int(max_audio_len * sampling_rate) - return { - "audio": self._get_dummy_audios( - length=audio_len, num_audios=num_audios, overrides=audio_overrides - ) - } + self.act_fn = get_act_fn(config.hidden_act) + self.fc2 = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) -class GlmAsrMultiModalDataParser(AudioFlamingo3MultiModalDataParser): - def _parse_audio_data( - self, - data: dict[str, torch.Tensor] | ModalityData[Any], - ) -> ModalityDataItems[Any, Any] | None: - if isinstance(data, dict): - return DictEmbeddingItems( - data, - modality="audio", - required_fields={"audio_embeds"}, - fields_factory=_glmasr_field_config, - ) - return super()._parse_audio_data(data) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + return hidden_states -class GlmAsrMultiModalProcessor(BaseMultiModalProcessor["GlmAsrProcessingInfo"]): +class GlmAsrEncoderLayer(nn.Module): """ - GLM-ASR processor that inherits directly from BaseMultiModalProcessor - for better performance and cleaner implementation. - Uses GPU-accelerated feature extraction for improved throughput. + Optimized Transformer encoder layer for GLM-ASR. + Combines attention and MLP with residual connections and layer norms. """ - # Shared GPU feature extractor instance (lazy initialized) - _gpu_feature_extractor: GPUWhisperFeatureExtractor | None = None - - @classmethod - def _get_gpu_feature_extractor( - cls, - hf_feature_extractor: WhisperFeatureExtractor, - device: str = "cuda", - ) -> GPUWhisperFeatureExtractor: - """Get or create GPU feature extractor matching HF config.""" - if cls._gpu_feature_extractor is None: - cls._gpu_feature_extractor = GPUWhisperFeatureExtractor( - feature_size=hf_feature_extractor.feature_size, - sampling_rate=hf_feature_extractor.sampling_rate, - hop_length=hf_feature_extractor.hop_length, - chunk_length=hf_feature_extractor.chunk_length, - n_fft=hf_feature_extractor.n_fft, - padding_value=hf_feature_extractor.padding_value, - device=device, - ) - return cls._gpu_feature_extractor - - def _get_data_parser(self) -> MultiModalDataParser: - feature_extractor = self.info.get_feature_extractor() - return GlmAsrMultiModalDataParser(target_sr=feature_extractor.sampling_rate) - - def _calculate_chunk_counts( - self, - audio_list: list[Any], - feature_extractor: WhisperFeatureExtractor, - processor: GlmAsrProcessor, - ) -> list[int]: - """Calculate chunk counts for each audio.""" - sampling_rate = feature_extractor.sampling_rate - chunk_length = feature_extractor.chunk_length - max_audio_len = getattr(processor, "max_audio_len", DEFAULT_MAX_AUDIO_LEN_S) - window_size = int(sampling_rate * chunk_length) - max_windows = int(max_audio_len // chunk_length) - - chunk_counts = [] - for audio in audio_list: - n_samples = len(audio) if isinstance(audio, list) else audio.shape[0] - n_chunks = max(1, (n_samples + window_size - 1) // window_size) - chunk_counts.append(min(n_chunks, max_windows)) - return chunk_counts - - # @torch.compile(fullgraph=True) - def _call_hf_processor( + def __init__( self, - prompt: str, - mm_data: dict[str, object], - mm_kwargs: Mapping[str, Any], - tok_kwargs: Mapping[str, object], - ) -> BatchFeature: - """ - Call processor with GPU-accelerated feature extraction. - """ - # Normalize input: handle deprecated key and list conversion. - if "audios" in mm_data: - mm_data["audio"] = mm_data.pop("audios") - - audio = mm_data.get("audio", []) - audio_list = [audio] if audio and not isinstance(audio, list) else audio - - # Early return for text-only. - if not audio_list: - prompt_ids = self.info.get_tokenizer().encode(prompt) - prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) - return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") - - # Get processor for tokenizer and config - processor = self.info.get_hf_processor(**mm_kwargs) - hf_feature_extractor = processor.feature_extractor - tokenizer = processor.tokenizer - - # ===== Audio chunking (CPU, fast) ===== - sampling_rate = hf_feature_extractor.sampling_rate - chunk_length = hf_feature_extractor.chunk_length - max_audio_len = getattr(processor, "max_audio_len", DEFAULT_MAX_AUDIO_LEN_S) - window_size = int(sampling_rate * chunk_length) - max_windows = int(max_audio_len // chunk_length) - - per_sample_windows: list[int] = [] - flat_chunks: list[np.ndarray] = [] - - for audio_el in audio_list: - # Convert to numpy if needed - if isinstance(audio_el, torch.Tensor): - audio_el = audio_el.numpy() - elif isinstance(audio_el, list): - audio_el = np.array(audio_el, dtype=np.float32) - - n_samples = int(audio_el.shape[0]) - n_win = max(1, (n_samples + window_size - 1) // window_size) - if n_win > max_windows: - n_win = max_windows - - per_sample_windows.append(n_win) - time_cap = min(n_samples, n_win * window_size) - - for i in range(n_win): - start = i * window_size - end = min((i + 1) * window_size, time_cap) - flat_chunks.append(audio_el[start:end]) - - # ===== GPU Feature Extraction ===== - # Check if CUDA is available, fallback to CPU if not - use_gpu = torch.cuda.is_available() - device = "cuda" if use_gpu else "cpu" - - if use_gpu: - # Use GPU-accelerated feature extractor - gpu_extractor = self._get_gpu_feature_extractor( - hf_feature_extractor, device=device - ) - audio_inputs = gpu_extractor( - flat_chunks, - sampling_rate=sampling_rate, - return_attention_mask=True, - return_tensors="pt", - ) - else: - # Fallback to HF CPU implementation - audio_inputs = hf_feature_extractor( - flat_chunks, - sampling_rate=sampling_rate, - return_tensors="pt", - padding=True, - return_attention_mask=True, - ) - - # ===== Process attention mask ===== - padding_mask = audio_inputs.pop("attention_mask") - input_features_mask = padding_mask + config, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = config.hidden_size - # ===== Compute audio token lengths ===== - chunk_lengths = padding_mask.sum(-1) # [num_chunks] - audio_lengths = torch.stack( - [ - chunk_lengths[ - sum(per_sample_windows[:i]) : sum(per_sample_windows[: i + 1]) - ].sum() - for i in range(len(per_sample_windows)) - ] + self.self_attn = GlmAsrAttention( + config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", ) - # Apply convolution formula to get token counts - merge_factor = 4 - for padding, kernel_size, stride in [(1, 3, 1), (1, 3, 2)]: - audio_lengths = ( - audio_lengths + 2 * padding - (kernel_size - 1) - 1 - ) // stride + 1 - audio_tokens_lengths = (audio_lengths - merge_factor) // merge_factor + 1 - - # ===== Expand audio tokens in text ===== - import regex as re - - audio_token = getattr(processor, "audio_token", "<|pad|>") - text_list = [prompt] - - for i, audio_length in enumerate(audio_tokens_lengths): - if i < len(text_list): - expanded = re.sub( - re.escape(audio_token), - audio_token * int(audio_length), - text_list[i], - ) - text_list[i] = expanded - - # ===== Tokenize text ===== - text_inputs = tokenizer( - text_list, - return_tensors="pt", - padding=True, - **tok_kwargs, + self.mlp = GlmAsrMLP( + config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", ) - # ===== Combine outputs ===== - # Move input_features to CPU for compatibility - input_features = audio_inputs["input_features"] - if input_features.device.type != "cpu": - input_features = input_features.cpu() - if input_features_mask.device.type != "cpu": - input_features_mask = input_features_mask.cpu() - - outputs = BatchFeature( - data={ - **text_inputs, - "input_features": input_features, - "feature_attention_mask": input_features_mask, - }, - tensor_type="pt", + layer_norm_eps = getattr(config, "layer_norm_eps", 1e-5) + self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=layer_norm_eps) + self.post_attention_layernorm = nn.LayerNorm( + self.hidden_size, eps=layer_norm_eps ) - outputs["chunk_counts"] = torch.tensor(per_sample_windows, dtype=torch.long) - - return outputs - - def _get_mm_fields_config( - self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> Mapping[str, MultiModalFieldConfig]: - return _glmasr_field_config(hf_inputs) - - def _get_prompt_updates( + def forward( self, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - out_mm_kwargs: MultiModalKwargsItems, - ) -> Sequence[PromptUpdate]: - processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - tokenizer = self.info.get_tokenizer() - vocab = tokenizer.get_vocab() - config = self.info.get_hf_config() - - audio_token = getattr(processor, "audio_token", "<|pad|>") - audio_token_id = vocab.get(audio_token) - if audio_token_id is None: - audio_token_id = processor.audio_token_id - - merge_factor = getattr(config, "merge_factor", DEFAULT_MERGE_FACTOR) - out_mm_data = out_mm_kwargs.get_data() - feature_attention_mask = out_mm_data.get("feature_attention_mask") - chunk_counts = out_mm_data.get("chunk_counts") - - def get_replacement_glmasr(item_idx: int): - conv_params = getattr(config, "conv_params", DEFAULT_CONV_PARAMS) - audio_embeds = out_mm_data.get("audio_embeds") - num_features = _get_num_features_for_item( - feature_attention_mask, - chunk_counts, - item_idx, - audio_embeds, - merge_factor, - conv_params, - ) + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + """ + Args: + hidden_states: [batch_size, seq_len, hidden_size] + position_embeddings: Tuple of (cos, sin) for RoPE - if num_features == 0: - raise ValueError("Audio is too short") + Returns: + [batch_size, seq_len, hidden_size] + """ + # Self-attention with residual + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states - audio_tokens = [audio_token_id] * int(num_features) - return PromptUpdateDetails.select_token_id( - audio_tokens, - embed_token_id=audio_token_id, - ) + # MLP with residual + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states - return [ - PromptReplacement( - modality="audio", - target=audio_token, - replacement=get_replacement_glmasr, - ) - ] + return hidden_states -@MULTIMODAL_REGISTRY.register_processor( - GlmAsrMultiModalProcessor, - info=GlmAsrProcessingInfo, - dummy_inputs=GlmAsrDummyInputsBuilder, -) -class GlmAsrForConditionalGeneration( - nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsTranscription -): - supported_languages = ISO639_1_SUPPORTED_LANGS +class GlmAsrEncoder(nn.Module): + """ + Optimized GLM-ASR Audio Encoder with vLLM native implementation. + + This encoder processes audio features through convolutional layers + followed by transformer layers with rotary position embeddings. + Optimized for performance with: + - QKVParallelLinear for fused attention projections + - Tensor parallelism support via ColumnParallelLinear/RowParallelLinear + - Quantization support + - Flash Attention (SDPA) + """ + # Mapping for weight loading: transformers uses separate q/k/v, we use fused qkv packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"], } - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + def __init__( + self, + config, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - multimodal_config = vllm_config.model_config.multimodal_config self.config = config - self.multimodal_config = multimodal_config - - # Use optimized vLLM native encoder - self.audio_tower = GlmAsrEncoder( - config.audio_config, - quant_config=quant_config, - prefix=maybe_prefix(prefix, "audio_tower"), - ) - self.multi_modal_projector = GlmAsrMultiModalProjector( - config, - quant_config=quant_config, - prefix=maybe_prefix(prefix, "multi_modal_projector"), - ) - self.quant_config = quant_config - self.language_model = init_vllm_registered_model( - vllm_config=vllm_config, - hf_config=config.text_config, - prefix=maybe_prefix(prefix, "language_model"), - architectures=["LlamaForCausalLM"], + # Convolutional feature extraction layers + self.conv1 = nn.Conv1d( + config.num_mel_bins, + config.hidden_size, + kernel_size=3, + padding=1, ) - - self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors + self.conv2 = nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size=3, + stride=2, + padding=1, ) - @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> str | None: - if modality.startswith("audio"): - return "<|begin_of_audio|><|pad|><|end_of_audio|>" - - raise ValueError("Only audio modality is supported") - - def get_mm_mapping(self) -> MultiModelKeys: - return MultiModelKeys.from_string_field( - language_model="language_model.", - connector="multi_modal_projector.", - tower_model="audio_tower.", + # Transformer encoder layers + self.layers = nn.ModuleList( + [ + GlmAsrEncoderLayer( + config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}", + ) + for layer_idx in range(config.num_hidden_layers) + ] ) - def _parse_and_validate_audio_input(self, **kwargs: object) -> GlmAsrInputs | None: - audio_embeds = kwargs.pop("audio_embeds", None) - if audio_embeds is not None: - return GlmAsrEmbeddingInputs(type="audio_embeds", audio_embeds=audio_embeds) + # Final layer norm + layer_norm_eps = getattr(config, "layer_norm_eps", 1e-5) + self.norm = nn.LayerNorm(config.hidden_size, eps=layer_norm_eps) - input_features = kwargs.pop("input_features", None) - if input_features is None: - return None + # Rotary position embeddings + self.rotary_emb = GlmAsrRotaryEmbedding(config) - return GlmAsrFeatureInputs( - type="audio_features", - input_features=input_features, - feature_attention_mask=kwargs.pop("feature_attention_mask", None), - chunk_counts=kwargs.pop("chunk_counts", None), + # Pre-register position_ids buffer for efficiency + # This avoids creating a new tensor on every forward pass + self.register_buffer( + "position_ids", + torch.arange(config.max_position_embeddings, dtype=torch.long).unsqueeze(0), + persistent=False, ) - def _process_audio_input( - self, audio_input: GlmAsrInputs - ) -> torch.Tensor | tuple[torch.Tensor, ...]: - if audio_input["type"] == "audio_embeds": - return tuple(audio_input["audio_embeds"]) - - input_features = audio_input["input_features"] - feature_attention_mask = audio_input["feature_attention_mask"] - - if isinstance(input_features, list): - input_features = torch.cat(input_features, dim=0) - feature_attention_mask = torch.cat(feature_attention_mask, dim=0) - - num_chunks = input_features.shape[0] - chunk_counts = _normalize_chunk_counts( - audio_input.get("chunk_counts"), num_chunks=num_chunks - ) + def _get_feat_extract_output_lengths( + self, input_lengths: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute the output length after convolutions. - # Convert input_features to model dtype (e.g., bfloat16) to match model weights - input_features = input_features.to(dtype=self.audio_tower.conv1.weight.dtype) - - # audio_tower returns [batch_size, seq_len, hidden_size] where hidden_size=1280 - audio_hidden_states = self.audio_tower(input_features).last_hidden_state - - # GLM-ASR merges consecutive frames: 4 frames with hidden_size=1280 - # -> 1 frame with intermediate_size=5120 - hidden_size = self.config.audio_config.hidden_size - intermediate_size = self.config.audio_config.intermediate_size - merge_ratio = intermediate_size // hidden_size - - # Truncate sequence length to be divisible by merge_ratio - seq_len = audio_hidden_states.shape[1] - seq_len_truncated = (seq_len // merge_ratio) * merge_ratio - if seq_len_truncated < seq_len: - audio_hidden_states = audio_hidden_states[:, :seq_len_truncated, :] - - # Reshape to merge consecutive frames - audio_hidden_states = audio_hidden_states.reshape( - num_chunks, - -1, - intermediate_size, - ) + Args: + input_lengths: Input sequence lengths [batch_size] - audio_features = self.multi_modal_projector(audio_hidden_states) + Returns: + Tuple of (output after conv1, output after conv2) + """ + # Conv1: kernel=3, stride=1, padding=1 + output_lengths = (input_lengths + 2 * 1 - 3) // 1 + 1 - merge_factor = getattr(self.config, "merge_factor", DEFAULT_MERGE_FACTOR) - conv_params = getattr(self.config, "conv_params", DEFAULT_CONV_PARAMS) + # Conv2: kernel=3, stride=2, padding=1 + output_lengths = (output_lengths + 2 * 1 - 3) // 2 + 1 - audio_output_lengths = _get_audio_output_lengths_for_tower( - self.audio_tower, - feature_attention_mask.sum(-1), - merge_factor, - conv_params, - ) + return input_lengths, output_lengths - masked_audio_features = _flatten_audio_features_by_length( - audio_features, audio_output_lengths - ) + def forward(self, input_features: torch.Tensor): + """ + Forward pass through the encoder. - chunk_embeddings = torch.split( - masked_audio_features, audio_output_lengths.flatten().tolist() - ) - result = _group_audio_embeddings(chunk_embeddings, chunk_counts) + Args: + input_features: [batch_size, num_mel_bins, seq_len] - return result + Returns: + Object with .last_hidden_state attribute containing + [batch_size, seq_len', hidden_size] where seq_len' is + the sequence length after convolutions + """ + # Apply convolutional layers with GELU activation + hidden_states = torch.nn.functional.gelu(self.conv1(input_features)) + hidden_states = torch.nn.functional.gelu(self.conv2(hidden_states)) - def get_language_model(self) -> torch.nn.Module: - return self.language_model + # Transpose to [batch_size, seq_len, hidden_size] + hidden_states = hidden_states.transpose(1, 2) + output_seq_len = hidden_states.shape[1] - def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: - audio_input = self._parse_and_validate_audio_input(**kwargs) - if audio_input is None: - return [] + # Use pre-registered position_ids buffer (slice to actual seq_len) + position_ids = self.position_ids[:, :output_seq_len] - masked_audio_features = self._process_audio_input(audio_input) + # Get position embeddings - uses pre-computed cache + position_embeddings = self.rotary_emb(hidden_states, position_ids) - return masked_audio_features + # Apply transformer layers + for encoder_layer in self.layers: + hidden_states = encoder_layer(hidden_states, position_embeddings) - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: IntermediateTensors | None = None, - inputs_embeds: torch.Tensor | None = None, - **kwargs: object, - ) -> torch.Tensor | IntermediateTensors: - if intermediate_tensors is not None: - inputs_embeds = None - - hidden_states = self.language_model.model( - input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds, - ) - return hidden_states + # Final layer norm + hidden_states = self.norm(hidden_states) - def compute_logits( - self, - hidden_states: torch.Tensor, - ) -> torch.Tensor | None: - return self.language_model.compute_logits(hidden_states) + # Return in a format compatible with transformers' BaseModelOutput + return _GlmAsrEncoderOutput(last_hidden_state=hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - skip_prefixes = ["audio_tower.embed_positions"] - loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) - return loader.load_weights(weights) - - @classmethod - def _get_audio_token(cls, model_config: ModelConfig) -> str: - """Get the audio token from processor. - - Similar to get_placeholder_str but returns single token. - """ - processor = cached_processor_from_config(model_config) - return getattr(processor, "audio_token", "<|pad|>") - - @classmethod - def get_speech_to_text_config( - cls, model_config: ModelConfig, task_type: str - ) -> SpeechToTextConfig: - processor = cached_processor_from_config(model_config) - feature_extractor = processor.feature_extractor - max_audio_clip_s = getattr(processor, "max_audio_len", DEFAULT_MAX_AUDIO_LEN_S) - return SpeechToTextConfig( - max_audio_clip_s=max_audio_clip_s, - sample_rate=feature_extractor.sampling_rate, - ) - - @classmethod - def get_generation_prompt( - cls, - audio: np.ndarray, - model_config: ModelConfig, - stt_config: SpeechToTextConfig, - language: str | None, - task_type: Literal["transcribe", "translate"], - request_prompt: str, - to_language: str | None, - ) -> PromptType: - """Get the generation prompt to be used for transcription requests.""" - tokenizer = cached_tokenizer_from_config(model_config) - audio_token = cls._get_audio_token(model_config) - - if task_type == "translate": - full_lang_name_to = cls.supported_languages.get(to_language, to_language) - user_content = f"{audio_token}translate the speech to {full_lang_name_to}" - elif task_type == "transcribe": - user_content = ( - f"{audio_token}can you transcribe the speech into a written format?" - ) - else: - raise ValueError(f"Unsupported task type {task_type}") - - messages = [{"role": "user", "content": user_content}] - prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) - - prompt_token_ids = tokenizer.encode(prompt) - prompt_dict = { - "prompt_token_ids": prompt_token_ids, - "multi_modal_data": {"audio": audio}, - } - return cast(PromptType, prompt_dict) + """Custom weight loading to handle q_proj/k_proj/v_proj -> qkv_proj mapping.""" + from vllm.model_executor.model_loader.weight_utils import default_weight_loader + + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Default weight loading for non-stacked params + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params From d3fb8d4fecbd9e2e8da45dce590b1657d53d7895 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 6 Jan 2026 10:21:23 +0800 Subject: [PATCH 03/24] use hf_feature_extractor.mel_filter to ensure accuracy Signed-off-by: JaredforReal --- vllm/model_executor/models/glmasr.py | 112 ++++++++------------------- 1 file changed, 32 insertions(+), 80 deletions(-) diff --git a/vllm/model_executor/models/glmasr.py b/vllm/model_executor/models/glmasr.py index 681603bec482..fc096dd1b43b 100644 --- a/vllm/model_executor/models/glmasr.py +++ b/vllm/model_executor/models/glmasr.py @@ -3,7 +3,6 @@ import logging from collections.abc import Iterable, Mapping, Sequence -from functools import lru_cache from typing import Annotated, Any, Literal, TypeAlias, cast import numpy as np @@ -84,58 +83,16 @@ # ============================================================================= -@lru_cache(maxsize=1) -def _get_mel_filters( - n_fft: int = 400, - n_mels: int = 80, - sampling_rate: int = 16000, - device: torch.device | None = None, -) -> torch.Tensor: - """ - Compute mel filterbank matrix (cached). - Matches WhisperFeatureExtractor's mel_filter_bank with slaney norm/scale. - """ - if device is None: - device = torch.device("cpu") - # Frequency bins - n_freqs = n_fft // 2 + 1 - all_freqs = torch.linspace(0, sampling_rate // 2, n_freqs, device=device) - - # Mel scale conversion (slaney) - min_mel = 0.0 - max_mel = 2595.0 * np.log10(1.0 + (sampling_rate / 2) / 700.0) - mels = torch.linspace(min_mel, max_mel, n_mels + 2, device=device) - mel_freqs = 700.0 * (10.0 ** (mels / 2595.0) - 1.0) - - # Create filterbank - mel_filters = torch.zeros(n_mels, n_freqs, device=device) - for i in range(n_mels): - lower = mel_freqs[i] - center = mel_freqs[i + 1] - upper = mel_freqs[i + 2] - - # Lower slope - lower_slope = (all_freqs - lower) / (center - lower + 1e-10) - # Upper slope - upper_slope = (upper - all_freqs) / (upper - center + 1e-10) - - mel_filters[i] = torch.maximum( - torch.zeros_like(all_freqs), - torch.minimum(lower_slope, upper_slope), - ) - - # Slaney normalization - enorm = 2.0 / (mel_freqs[2 : n_mels + 2] - mel_freqs[:n_mels]) - mel_filters *= enorm.unsqueeze(1) - - return mel_filters - - class GPUWhisperFeatureExtractor: """ GPU-accelerated Whisper feature extractor using PyTorch. Computes log-mel spectrogram matching WhisperFeatureExtractor output. + This implementation reuses the mel filterbank from HuggingFace's + WhisperFeatureExtractor to ensure numerical precision (1e-5 tolerance). + The key optimization is caching the window and mel_filters tensors on GPU + to avoid repeated CPU->GPU transfers. + Key parameters (Whisper defaults): - n_fft: 400 (25ms window at 16kHz) - hop_length: 160 (10ms hop at 16kHz) @@ -146,42 +103,42 @@ class GPUWhisperFeatureExtractor: def __init__( self, - feature_size: int = 80, - sampling_rate: int = 16000, - hop_length: int = 160, - chunk_length: int = 30, - n_fft: int = 400, - padding_value: float = 0.0, + hf_feature_extractor: WhisperFeatureExtractor, device: str | torch.device = "cuda", ): - self.feature_size = feature_size - self.sampling_rate = sampling_rate - self.hop_length = hop_length - self.chunk_length = chunk_length - self.n_fft = n_fft - self.padding_value = padding_value + # Copy parameters from HF feature extractor + self.feature_size = hf_feature_extractor.feature_size + self.sampling_rate = hf_feature_extractor.sampling_rate + self.hop_length = hf_feature_extractor.hop_length + self.chunk_length = hf_feature_extractor.chunk_length + self.n_fft = hf_feature_extractor.n_fft + self.padding_value = hf_feature_extractor.padding_value self.device = torch.device(device) if isinstance(device, str) else device # Derived parameters - self.n_samples = chunk_length * sampling_rate # 480000 for 30s - self.nb_max_frames = self.n_samples // hop_length # 3000 frames + self.n_samples = self.chunk_length * self.sampling_rate # 480000 for 30s + self.nb_max_frames = self.n_samples // self.hop_length # 3000 frames + + # Store HF's mel_filters (numpy float64) for precision + # This is precomputed by HF using librosa-compatible mel_filter_bank + self._mel_filters_np: np.ndarray = hf_feature_extractor.mel_filters - # Pre-compute window and mel filters on device + # Cached GPU tensors (lazily initialized) self._window: torch.Tensor | None = None self._mel_filters: torch.Tensor | None = None + self._current_device: torch.device | None = None def _ensure_buffers(self, device: torch.device) -> None: - """Lazily initialize buffers on the target device.""" - if self._window is None or self._window.device != device: - self._window = torch.hann_window(self.n_fft, device=device) - - if self._mel_filters is None or self._mel_filters.device != device: - self._mel_filters = _get_mel_filters( - n_fft=self.n_fft, - n_mels=self.feature_size, - sampling_rate=self.sampling_rate, - device=device, - ) + """Lazily initialize and cache buffers on the target device.""" + if self._current_device == device: + return + + self._window = torch.hann_window(self.n_fft, device=device) + # Convert from numpy float64 to preserve precision during transfer + self._mel_filters = torch.from_numpy(self._mel_filters_np).to( + device=device, dtype=torch.float32 + ) + self._current_device = device def __call__( self, @@ -459,12 +416,7 @@ def _get_gpu_feature_extractor( """Get or create GPU feature extractor matching HF config.""" if cls._gpu_feature_extractor is None: cls._gpu_feature_extractor = GPUWhisperFeatureExtractor( - feature_size=hf_feature_extractor.feature_size, - sampling_rate=hf_feature_extractor.sampling_rate, - hop_length=hf_feature_extractor.hop_length, - chunk_length=hf_feature_extractor.chunk_length, - n_fft=hf_feature_extractor.n_fft, - padding_value=hf_feature_extractor.padding_value, + hf_feature_extractor=hf_feature_extractor, device=device, ) return cls._gpu_feature_extractor From 15509b824a92195468314d7256ee0b0afafeff15 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 6 Jan 2026 10:43:34 +0800 Subject: [PATCH 04/24] fix shape problem Signed-off-by: JaredforReal --- vllm/model_executor/models/glmasr.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/glmasr.py b/vllm/model_executor/models/glmasr.py index fc096dd1b43b..1f0259b76d6d 100644 --- a/vllm/model_executor/models/glmasr.py +++ b/vllm/model_executor/models/glmasr.py @@ -253,9 +253,11 @@ def _extract_fbank_features(self, waveforms: torch.Tensor) -> torch.Tensor: # Power spectrogram, drop last frame (matching HF implementation) magnitudes = stft[..., :-1].abs() ** 2 # [batch, n_freqs, frames] - # Apply mel filterbank: [n_mels, n_freqs] @ [batch, n_freqs, frames] + # Apply mel filterbank: [n_freqs, n_mels].T @ [batch, n_freqs, frames] # -> [batch, n_mels, frames] - mel_spec = torch.matmul(self._mel_filters, magnitudes) + # HF uses mel_filters.T @ magnitudes, where mel_filters is [n_mels, n_freqs] + # So we transpose to get [n_freqs, n_mels], then use matmul with broadcasting + mel_spec = torch.matmul(self._mel_filters.T, magnitudes) # Log scale with floor log_spec = torch.clamp(mel_spec, min=1e-10).log10() From b3e5e5164741f216cdb639ef1b846c3a4e8b89a9 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 6 Jan 2026 11:07:02 +0800 Subject: [PATCH 05/24] perf Signed-off-by: JaredforReal --- vllm/model_executor/models/glmasr.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/models/glmasr.py b/vllm/model_executor/models/glmasr.py index 1f0259b76d6d..41bdfe479a2c 100644 --- a/vllm/model_executor/models/glmasr.py +++ b/vllm/model_executor/models/glmasr.py @@ -125,20 +125,26 @@ def __init__( # Cached GPU tensors (lazily initialized) self._window: torch.Tensor | None = None - self._mel_filters: torch.Tensor | None = None - self._current_device: torch.device | None = None + self._mel_filters_T: torch.Tensor | None = None # Transposed & contiguous + self._current_device_str: str | None = None def _ensure_buffers(self, device: torch.device) -> None: """Lazily initialize and cache buffers on the target device.""" - if self._current_device == device: + # Use string comparison for stable device checking + device_str = str(device) + if self._current_device_str == device_str: return self._window = torch.hann_window(self.n_fft, device=device) - # Convert from numpy float64 to preserve precision during transfer - self._mel_filters = torch.from_numpy(self._mel_filters_np).to( - device=device, dtype=torch.float32 + # Convert from numpy float64, transpose, and make contiguous for optimal matmul + # HF mel_filters is [n_mels, n_freqs], we need [n_freqs, n_mels] for matmul + # Using .contiguous() ensures optimal memory layout for GPU matmul + self._mel_filters_T = ( + torch.from_numpy(self._mel_filters_np) + .to(device=device, dtype=torch.float32) + .T.contiguous() ) - self._current_device = device + self._current_device_str = device_str def __call__( self, @@ -253,11 +259,10 @@ def _extract_fbank_features(self, waveforms: torch.Tensor) -> torch.Tensor: # Power spectrogram, drop last frame (matching HF implementation) magnitudes = stft[..., :-1].abs() ** 2 # [batch, n_freqs, frames] - # Apply mel filterbank: [n_freqs, n_mels].T @ [batch, n_freqs, frames] + # Apply mel filterbank: [n_freqs, n_mels] @ [batch, n_freqs, frames] # -> [batch, n_mels, frames] - # HF uses mel_filters.T @ magnitudes, where mel_filters is [n_mels, n_freqs] - # So we transpose to get [n_freqs, n_mels], then use matmul with broadcasting - mel_spec = torch.matmul(self._mel_filters.T, magnitudes) + # _mel_filters_T is pre-transposed and contiguous for optimal performance + mel_spec = torch.matmul(self._mel_filters_T, magnitudes) # Log scale with floor log_spec = torch.clamp(mel_spec, min=1e-10).log10() From 58ad75d4c3e260c6feb8f688733b57ed7a423354 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 6 Jan 2026 13:35:00 +0800 Subject: [PATCH 06/24] move GlmAsrEncoder from utils to glmasr Signed-off-by: JaredforReal --- vllm/model_executor/models/glmasr.py | 569 +++++++++++++++++++- vllm/model_executor/models/glmasr_utils.py | 582 +-------------------- 2 files changed, 569 insertions(+), 582 deletions(-) diff --git a/vllm/model_executor/models/glmasr.py b/vllm/model_executor/models/glmasr.py index 41bdfe479a2c..b70fb9db1702 100644 --- a/vllm/model_executor/models/glmasr.py +++ b/vllm/model_executor/models/glmasr.py @@ -14,10 +14,12 @@ from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size from vllm.inputs.data import PromptType from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import ( ColumnParallelLinear, + QKVParallelLinear, RowParallelLinear, ) from vllm.model_executor.layers.quantization import QuantizationConfig @@ -58,7 +60,6 @@ DEFAULT_CONV_PARAMS, DEFAULT_MAX_AUDIO_LEN_S, DEFAULT_MERGE_FACTOR, - GlmAsrEncoder, _flatten_audio_features_by_length, _get_audio_output_lengths_for_tower, _get_num_features_for_item, @@ -77,6 +78,572 @@ logger = logging.getLogger(__name__) +# Optimized vLLM Native GlmAsrEncoder Implementation + + +def _rotate_half(x: torch.Tensor) -> torch.Tensor: + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def _apply_rotary_pos_emb( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary position embeddings to query and key tensors. + + Follows transformers' apply_rotary_pos_emb exactly. + Supports partial rotary where only the first rotary_dim of head_dim is rotated. + + Args: + q: [batch, num_heads, seq_len, head_dim] + k: [batch, num_kv_heads, seq_len, head_dim] + cos: [batch, seq_len, rotary_dim] + sin: [batch, seq_len, rotary_dim] + """ + # unsqueeze_dim=1 to add head dimension: [batch, 1, seq_len, rotary_dim] + cos = cos.unsqueeze(1) + sin = sin.unsqueeze(1) + + # Get the rotary dimension from cos/sin + rotary_dim = cos.shape[-1] + + # Split into rotary and pass-through parts + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + # Apply rotary embeddings on the first half or full tensor + q_embed = (q_rot * cos) + (_rotate_half(q_rot) * sin) + k_embed = (k_rot * cos) + (_rotate_half(k_rot) * sin) + + # Concatenate back to full shape + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) + + return q_embed, k_embed + + +class GlmAsrRotaryEmbedding(nn.Module): + """ + Rotary Position Embedding for GLM-ASR encoder. + + Optimized with pre-computed cos/sin cache for better performance. + Falls back to dynamic computation only when sequence length exceeds cache. + """ + + def __init__(self, config, device: torch.device | None = None): + super().__init__() + self.config = config + self.max_seq_len_cached = config.max_position_embeddings + + # Compute inverse frequencies following transformers implementation + head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + + # Handle rope_parameters if present (for compatibility with transformers config) + if hasattr(config, "rope_parameters") and config.rope_parameters: + base = config.rope_parameters.get("rope_theta", 10000.0) + partial_rotary_factor = config.rope_parameters.get( + "partial_rotary_factor", 1.0 + ) + dim = int(head_dim * partial_rotary_factor) + self.attention_scaling = config.rope_parameters.get( + "attention_scaling", 1.0 + ) + else: + base = getattr(config, "rope_theta", 10000.0) + dim = head_dim + self.attention_scaling = 1.0 + + self.dim = dim + self.base = base + + # Compute the inverse frequencies exactly as transformers does + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, dim, 2, dtype=torch.int64).to( + device=device, dtype=torch.float + ) + / dim + ) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Pre-compute cos/sin cache for efficiency + self._set_cos_sin_cache(self.max_seq_len_cached, device) + + def _set_cos_sin_cache( + self, seq_len: int, device: torch.device | None = None + ) -> None: + """Pre-compute cos and sin cache for given sequence length.""" + self.max_seq_len_cached = seq_len + + # Create position indices + t = torch.arange(seq_len, device=device, dtype=torch.float32) + # Compute frequencies: [seq_len, dim/2] + freqs = torch.outer(t, self.inv_freq.to(device=device, dtype=torch.float32)) + # Double the frequencies: [seq_len, dim] + emb = torch.cat((freqs, freqs), dim=-1) + + # Compute and cache cos/sin + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + self.register_buffer("cos_cached", cos, persistent=False) + self.register_buffer("sin_cached", sin, persistent=False) + + def forward( + self, x: torch.Tensor, position_ids: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute rotary embeddings with caching optimization. + + Args: + x: Input tensor [batch_size, seq_len, hidden_size] + position_ids: Position indices [batch_size, seq_len] + + Returns: + Tuple of (cos, sin) tensors with shape [batch_size, seq_len, rotary_dim] + """ + seq_len = position_ids.shape[-1] + + # Extend cache if needed + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len, device=x.device) + + # Use cached values - index with position_ids for correctness + # For encoder, position_ids is typically [0, 1, 2, ..., seq_len-1] + # so we can directly slice the cache + cos = self.cos_cached[:seq_len].unsqueeze(0) # [1, seq_len, dim] + sin = self.sin_cached[:seq_len].unsqueeze(0) # [1, seq_len, dim] + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + Repeat key/value tensors for Grouped Query Attention. + + Args: + hidden_states: [batch, num_kv_heads, seq_len, head_dim] + n_rep: Number of repetitions + + Returns: + [batch, num_kv_heads * n_rep, seq_len, head_dim] + """ + if n_rep == 1: + return hidden_states + + batch, num_kv_heads, slen, head_dim = hidden_states.shape + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_kv_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim) + + +class GlmAsrAttention(nn.Module): + """ + Optimized Multi-headed Grouped Query Attention for GLM-ASR. + Uses vLLM's QKVParallelLinear for better performance. + """ + + def __init__( + self, + config, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_kv_groups = self.num_heads // self.num_kv_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_rank = self.num_heads // self.tp_size + self.num_kv_heads_per_rank = max(1, self.num_kv_heads // self.tp_size) + + # Use QKVParallelLinear for fused QKV projection + # Note: GLM-ASR uses bias on Q and V, but not K + # For simplicity with QKVParallelLinear, we use bias=True for all + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.num_heads, + self.num_kv_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + """ + Args: + hidden_states: [batch_size, seq_len, hidden_size] + position_embeddings: Tuple of (cos, sin) for RoPE + + Returns: + [batch_size, seq_len, hidden_size] + """ + batch_size, seq_len, _ = hidden_states.shape + + # QKV projection - fused for efficiency + qkv, _ = self.qkv_proj(hidden_states) + + # Split into q, k, v + q_size = self.num_heads_per_rank * self.head_dim + kv_size = self.num_kv_heads_per_rank * self.head_dim + q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1) + + # Reshape and transpose + # [batch, seq, num_heads * head_dim] -> [batch, num_heads, seq, head_dim] + q = q.view( + batch_size, seq_len, self.num_heads_per_rank, self.head_dim + ).transpose(1, 2) + k = k.view( + batch_size, seq_len, self.num_kv_heads_per_rank, self.head_dim + ).transpose(1, 2) + # v doesn't go through RoPE, so make it contiguous now for SDPA + v = ( + v.view(batch_size, seq_len, self.num_kv_heads_per_rank, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + + # Apply rotary position embeddings + cos, sin = position_embeddings + q, k = _apply_rotary_pos_emb(q, k, cos, sin) + + # Handle GQA: repeat k/v if needed + if self.num_kv_groups > 1: + k = _repeat_kv(k, self.num_kv_groups) + v = _repeat_kv(v, self.num_kv_groups) + + # Ensure contiguous for optimal SDPA/Flash Attention performance + # Non-contiguous tensors can cause fallback to slower implementations + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + + # Scaled dot-product attention (uses Flash Attention when available) + attn_output = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=False, + ) + + # Reshape back + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, seq_len, -1) + + # Output projection + output, _ = self.o_proj(attn_output) + return output + + +class GlmAsrMLP(nn.Module): + """ + Optimized MLP for GLM-ASR encoder. + Uses vLLM's parallel linear layers for better performance. + """ + + def __init__( + self, + config, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + self.fc1 = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + + self.act_fn = get_act_fn(config.hidden_act) + + self.fc2 = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + return hidden_states + + +class GlmAsrEncoderLayer(nn.Module): + """ + Optimized Transformer encoder layer for GLM-ASR. + Combines attention and MLP with residual connections and layer norms. + """ + + def __init__( + self, + config, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = GlmAsrAttention( + config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + self.mlp = GlmAsrMLP( + config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + + layer_norm_eps = getattr(config, "layer_norm_eps", 1e-5) + self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=layer_norm_eps) + self.post_attention_layernorm = nn.LayerNorm( + self.hidden_size, eps=layer_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + """ + Args: + hidden_states: [batch_size, seq_len, hidden_size] + position_embeddings: Tuple of (cos, sin) for RoPE + + Returns: + [batch_size, seq_len, hidden_size] + """ + # Self-attention with residual + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # MLP with residual + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class _GlmAsrEncoderOutput: + """Simple output container compatible with transformers' BaseModelOutput.""" + + __slots__ = ("last_hidden_state",) + + def __init__(self, last_hidden_state: torch.Tensor): + self.last_hidden_state = last_hidden_state + + +class GlmAsrEncoder(nn.Module): + """ + Optimized GLM-ASR Audio Encoder with vLLM native implementation. + + This encoder processes audio features through convolutional layers + followed by transformer layers with rotary position embeddings. + Optimized for performance with: + - QKVParallelLinear for fused attention projections + - Tensor parallelism support via ColumnParallelLinear/RowParallelLinear + - Quantization support + - Flash Attention (SDPA) + """ + + # Mapping for weight loading: transformers uses separate q/k/v, we use fused qkv + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + } + + def __init__( + self, + config, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__() + self.config = config + + # Convolutional feature extraction layers + self.conv1 = nn.Conv1d( + config.num_mel_bins, + config.hidden_size, + kernel_size=3, + padding=1, + ) + self.conv2 = nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size=3, + stride=2, + padding=1, + ) + + # Transformer encoder layers + self.layers = nn.ModuleList( + [ + GlmAsrEncoderLayer( + config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}", + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + # Final layer norm + layer_norm_eps = getattr(config, "layer_norm_eps", 1e-5) + self.norm = nn.LayerNorm(config.hidden_size, eps=layer_norm_eps) + + # Rotary position embeddings + self.rotary_emb = GlmAsrRotaryEmbedding(config) + + # Pre-register position_ids buffer for efficiency + # This avoids creating a new tensor on every forward pass + self.register_buffer( + "position_ids", + torch.arange(config.max_position_embeddings, dtype=torch.long).unsqueeze(0), + persistent=False, + ) + + def _get_feat_extract_output_lengths( + self, input_lengths: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute the output length after convolutions. + + Args: + input_lengths: Input sequence lengths [batch_size] + + Returns: + Tuple of (output after conv1, output after conv2) + """ + # Conv1: kernel=3, stride=1, padding=1 + output_lengths = (input_lengths + 2 * 1 - 3) // 1 + 1 + + # Conv2: kernel=3, stride=2, padding=1 + output_lengths = (output_lengths + 2 * 1 - 3) // 2 + 1 + + return input_lengths, output_lengths + + def forward(self, input_features: torch.Tensor): + """ + Forward pass through the encoder. + + Args: + input_features: [batch_size, num_mel_bins, seq_len] + + Returns: + Object with .last_hidden_state attribute containing + [batch_size, seq_len', hidden_size] where seq_len' is + the sequence length after convolutions + """ + # Apply convolutional layers with GELU activation + hidden_states = torch.nn.functional.gelu(self.conv1(input_features)) + hidden_states = torch.nn.functional.gelu(self.conv2(hidden_states)) + + # Transpose to [batch_size, seq_len, hidden_size] + hidden_states = hidden_states.transpose(1, 2) + output_seq_len = hidden_states.shape[1] + + # Use pre-registered position_ids buffer (slice to actual seq_len) + position_ids = self.position_ids[:, :output_seq_len] + + # Get position embeddings - uses pre-computed cache + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # Apply transformer layers + for encoder_layer in self.layers: + hidden_states = encoder_layer(hidden_states, position_embeddings) + + # Final layer norm + hidden_states = self.norm(hidden_states) + + # Return in a format compatible with transformers' BaseModelOutput + return _GlmAsrEncoderOutput(last_hidden_state=hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Custom weight loading to handle q_proj/k_proj/v_proj -> qkv_proj mapping.""" + from vllm.model_executor.model_loader.weight_utils import default_weight_loader + + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Default weight loading for non-stacked params + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + # ============================================================================= # GPU-accelerated Whisper Feature Extractor diff --git a/vllm/model_executor/models/glmasr_utils.py b/vllm/model_executor/models/glmasr_utils.py index 24d65ae54aec..a00aeaad3a3f 100644 --- a/vllm/model_executor/models/glmasr_utils.py +++ b/vllm/model_executor/models/glmasr_utils.py @@ -2,23 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import logging -from collections.abc import Iterable, Sequence +from collections.abc import Sequence from typing import cast import torch import torch.nn as nn -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import ( - ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear, -) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, -) - logger = logging.getLogger(__name__) DEFAULT_MAX_AUDIO_LEN_S = 655 @@ -28,15 +17,6 @@ DEFAULT_CONV_PARAMS = [(1, 3, 1), (1, 3, 2)] -class _GlmAsrEncoderOutput: - """Simple output container compatible with transformers' BaseModelOutput.""" - - __slots__ = ("last_hidden_state",) - - def __init__(self, last_hidden_state: torch.Tensor): - self.last_hidden_state = last_hidden_state - - def _calculate_conv_output_length( input_length: torch.Tensor, padding: int, kernel_size: int, stride: int ) -> torch.Tensor: @@ -209,563 +189,3 @@ def _get_num_features_for_item( if audio_embeds is not None: return audio_embeds[item_idx].shape[0] raise ValueError("Either feature_attention_mask or audio_embeds must be provided") - - -# ============================================================================ -# Optimized vLLM Native GlmAsrEncoder Implementation -# ============================================================================ - - -def _rotate_half(x: torch.Tensor) -> torch.Tensor: - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def _apply_rotary_pos_emb( - q: torch.Tensor, - k: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Apply rotary position embeddings to query and key tensors. - - Follows transformers' apply_rotary_pos_emb exactly. - Supports partial rotary where only the first rotary_dim of head_dim is rotated. - - Args: - q: [batch, num_heads, seq_len, head_dim] - k: [batch, num_kv_heads, seq_len, head_dim] - cos: [batch, seq_len, rotary_dim] - sin: [batch, seq_len, rotary_dim] - """ - # unsqueeze_dim=1 to add head dimension: [batch, 1, seq_len, rotary_dim] - cos = cos.unsqueeze(1) - sin = sin.unsqueeze(1) - - # Get the rotary dimension from cos/sin - rotary_dim = cos.shape[-1] - - # Split into rotary and pass-through parts - q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] - k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] - - # Apply rotary embeddings on the first half or full tensor - q_embed = (q_rot * cos) + (_rotate_half(q_rot) * sin) - k_embed = (k_rot * cos) + (_rotate_half(k_rot) * sin) - - # Concatenate back to full shape - q_embed = torch.cat([q_embed, q_pass], dim=-1) - k_embed = torch.cat([k_embed, k_pass], dim=-1) - - return q_embed, k_embed - - -class GlmAsrRotaryEmbedding(nn.Module): - """ - Rotary Position Embedding for GLM-ASR encoder. - - Optimized with pre-computed cos/sin cache for better performance. - Falls back to dynamic computation only when sequence length exceeds cache. - """ - - def __init__(self, config, device: torch.device | None = None): - super().__init__() - self.config = config - self.max_seq_len_cached = config.max_position_embeddings - - # Compute inverse frequencies following transformers implementation - head_dim = getattr( - config, "head_dim", config.hidden_size // config.num_attention_heads - ) - - # Handle rope_parameters if present (for compatibility with transformers config) - if hasattr(config, "rope_parameters") and config.rope_parameters: - base = config.rope_parameters.get("rope_theta", 10000.0) - partial_rotary_factor = config.rope_parameters.get( - "partial_rotary_factor", 1.0 - ) - dim = int(head_dim * partial_rotary_factor) - self.attention_scaling = config.rope_parameters.get( - "attention_scaling", 1.0 - ) - else: - base = getattr(config, "rope_theta", 10000.0) - dim = head_dim - self.attention_scaling = 1.0 - - self.dim = dim - self.base = base - - # Compute the inverse frequencies exactly as transformers does - inv_freq = 1.0 / ( - base - ** ( - torch.arange(0, dim, 2, dtype=torch.int64).to( - device=device, dtype=torch.float - ) - / dim - ) - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Pre-compute cos/sin cache for efficiency - self._set_cos_sin_cache(self.max_seq_len_cached, device) - - def _set_cos_sin_cache( - self, seq_len: int, device: torch.device | None = None - ) -> None: - """Pre-compute cos and sin cache for given sequence length.""" - self.max_seq_len_cached = seq_len - - # Create position indices - t = torch.arange(seq_len, device=device, dtype=torch.float32) - # Compute frequencies: [seq_len, dim/2] - freqs = torch.outer(t, self.inv_freq.to(device=device, dtype=torch.float32)) - # Double the frequencies: [seq_len, dim] - emb = torch.cat((freqs, freqs), dim=-1) - - # Compute and cache cos/sin - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * self.attention_scaling - - self.register_buffer("cos_cached", cos, persistent=False) - self.register_buffer("sin_cached", sin, persistent=False) - - def forward( - self, x: torch.Tensor, position_ids: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Compute rotary embeddings with caching optimization. - - Args: - x: Input tensor [batch_size, seq_len, hidden_size] - position_ids: Position indices [batch_size, seq_len] - - Returns: - Tuple of (cos, sin) tensors with shape [batch_size, seq_len, rotary_dim] - """ - seq_len = position_ids.shape[-1] - - # Extend cache if needed - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len, device=x.device) - - # Use cached values - index with position_ids for correctness - # For encoder, position_ids is typically [0, 1, 2, ..., seq_len-1] - # so we can directly slice the cache - cos = self.cos_cached[:seq_len].unsqueeze(0) # [1, seq_len, dim] - sin = self.sin_cached[:seq_len].unsqueeze(0) # [1, seq_len, dim] - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - Repeat key/value tensors for Grouped Query Attention. - - Args: - hidden_states: [batch, num_kv_heads, seq_len, head_dim] - n_rep: Number of repetitions - - Returns: - [batch, num_kv_heads * n_rep, seq_len, head_dim] - """ - if n_rep == 1: - return hidden_states - - batch, num_kv_heads, slen, head_dim = hidden_states.shape - hidden_states = hidden_states[:, :, None, :, :].expand( - batch, num_kv_heads, n_rep, slen, head_dim - ) - return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim) - - -class GlmAsrAttention(nn.Module): - """ - Optimized Multi-headed Grouped Query Attention for GLM-ASR. - Uses vLLM's QKVParallelLinear for better performance. - """ - - def __init__( - self, - config, - quant_config: QuantizationConfig | None = None, - prefix: str = "", - ): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.num_kv_heads = config.num_key_value_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_kv_groups = self.num_heads // self.num_kv_heads - self.scaling = self.head_dim**-0.5 - self.attention_dropout = config.attention_dropout - - self.tp_size = get_tensor_model_parallel_world_size() - self.num_heads_per_rank = self.num_heads // self.tp_size - self.num_kv_heads_per_rank = max(1, self.num_kv_heads // self.tp_size) - - # Use QKVParallelLinear for fused QKV projection - # Note: GLM-ASR uses bias on Q and V, but not K - # For simplicity with QKVParallelLinear, we use bias=True for all - self.qkv_proj = QKVParallelLinear( - self.hidden_size, - self.head_dim, - self.num_heads, - self.num_kv_heads, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - ) - - self.o_proj = RowParallelLinear( - self.hidden_size, - self.hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.o_proj", - ) - - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], - ) -> torch.Tensor: - """ - Args: - hidden_states: [batch_size, seq_len, hidden_size] - position_embeddings: Tuple of (cos, sin) for RoPE - - Returns: - [batch_size, seq_len, hidden_size] - """ - batch_size, seq_len, _ = hidden_states.shape - - # QKV projection - fused for efficiency - qkv, _ = self.qkv_proj(hidden_states) - - # Split into q, k, v - q_size = self.num_heads_per_rank * self.head_dim - kv_size = self.num_kv_heads_per_rank * self.head_dim - q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1) - - # Reshape and transpose - # [batch, seq, num_heads * head_dim] -> [batch, num_heads, seq, head_dim] - q = q.view( - batch_size, seq_len, self.num_heads_per_rank, self.head_dim - ).transpose(1, 2) - k = k.view( - batch_size, seq_len, self.num_kv_heads_per_rank, self.head_dim - ).transpose(1, 2) - # v doesn't go through RoPE, so make it contiguous now for SDPA - v = ( - v.view(batch_size, seq_len, self.num_kv_heads_per_rank, self.head_dim) - .transpose(1, 2) - .contiguous() - ) - - # Apply rotary position embeddings - cos, sin = position_embeddings - q, k = _apply_rotary_pos_emb(q, k, cos, sin) - - # Handle GQA: repeat k/v if needed - if self.num_kv_groups > 1: - k = _repeat_kv(k, self.num_kv_groups) - v = _repeat_kv(v, self.num_kv_groups) - - # Ensure contiguous for optimal SDPA/Flash Attention performance - # Non-contiguous tensors can cause fallback to slower implementations - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() - - # Scaled dot-product attention (uses Flash Attention when available) - attn_output = torch.nn.functional.scaled_dot_product_attention( - q, - k, - v, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=False, - ) - - # Reshape back - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(batch_size, seq_len, -1) - - # Output projection - output, _ = self.o_proj(attn_output) - return output - - -class GlmAsrMLP(nn.Module): - """ - Optimized MLP for GLM-ASR encoder. - Uses vLLM's parallel linear layers for better performance. - """ - - def __init__( - self, - config, - quant_config: QuantizationConfig | None = None, - prefix: str = "", - ): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - - self.fc1 = ColumnParallelLinear( - self.hidden_size, - self.intermediate_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.fc1", - ) - - self.act_fn = get_act_fn(config.hidden_act) - - self.fc2 = RowParallelLinear( - self.intermediate_size, - self.hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.fc2", - ) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states, _ = self.fc1(hidden_states) - hidden_states = self.act_fn(hidden_states) - hidden_states, _ = self.fc2(hidden_states) - return hidden_states - - -class GlmAsrEncoderLayer(nn.Module): - """ - Optimized Transformer encoder layer for GLM-ASR. - Combines attention and MLP with residual connections and layer norms. - """ - - def __init__( - self, - config, - quant_config: QuantizationConfig | None = None, - prefix: str = "", - ): - super().__init__() - self.hidden_size = config.hidden_size - - self.self_attn = GlmAsrAttention( - config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - ) - - self.mlp = GlmAsrMLP( - config, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - ) - - layer_norm_eps = getattr(config, "layer_norm_eps", 1e-5) - self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=layer_norm_eps) - self.post_attention_layernorm = nn.LayerNorm( - self.hidden_size, eps=layer_norm_eps - ) - - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], - ) -> torch.Tensor: - """ - Args: - hidden_states: [batch_size, seq_len, hidden_size] - position_embeddings: Tuple of (cos, sin) for RoPE - - Returns: - [batch_size, seq_len, hidden_size] - """ - # Self-attention with residual - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - hidden_states = self.self_attn( - hidden_states=hidden_states, - position_embeddings=position_embeddings, - ) - hidden_states = residual + hidden_states - - # MLP with residual - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - return hidden_states - - -class GlmAsrEncoder(nn.Module): - """ - Optimized GLM-ASR Audio Encoder with vLLM native implementation. - - This encoder processes audio features through convolutional layers - followed by transformer layers with rotary position embeddings. - Optimized for performance with: - - QKVParallelLinear for fused attention projections - - Tensor parallelism support via ColumnParallelLinear/RowParallelLinear - - Quantization support - - Flash Attention (SDPA) - """ - - # Mapping for weight loading: transformers uses separate q/k/v, we use fused qkv - packed_modules_mapping = { - "qkv_proj": ["q_proj", "k_proj", "v_proj"], - } - - def __init__( - self, - config, - quant_config: QuantizationConfig | None = None, - prefix: str = "", - ): - super().__init__() - self.config = config - - # Convolutional feature extraction layers - self.conv1 = nn.Conv1d( - config.num_mel_bins, - config.hidden_size, - kernel_size=3, - padding=1, - ) - self.conv2 = nn.Conv1d( - config.hidden_size, - config.hidden_size, - kernel_size=3, - stride=2, - padding=1, - ) - - # Transformer encoder layers - self.layers = nn.ModuleList( - [ - GlmAsrEncoderLayer( - config, - quant_config=quant_config, - prefix=f"{prefix}.layers.{layer_idx}", - ) - for layer_idx in range(config.num_hidden_layers) - ] - ) - - # Final layer norm - layer_norm_eps = getattr(config, "layer_norm_eps", 1e-5) - self.norm = nn.LayerNorm(config.hidden_size, eps=layer_norm_eps) - - # Rotary position embeddings - self.rotary_emb = GlmAsrRotaryEmbedding(config) - - # Pre-register position_ids buffer for efficiency - # This avoids creating a new tensor on every forward pass - self.register_buffer( - "position_ids", - torch.arange(config.max_position_embeddings, dtype=torch.long).unsqueeze(0), - persistent=False, - ) - - def _get_feat_extract_output_lengths( - self, input_lengths: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Compute the output length after convolutions. - - Args: - input_lengths: Input sequence lengths [batch_size] - - Returns: - Tuple of (output after conv1, output after conv2) - """ - # Conv1: kernel=3, stride=1, padding=1 - output_lengths = (input_lengths + 2 * 1 - 3) // 1 + 1 - - # Conv2: kernel=3, stride=2, padding=1 - output_lengths = (output_lengths + 2 * 1 - 3) // 2 + 1 - - return input_lengths, output_lengths - - def forward(self, input_features: torch.Tensor): - """ - Forward pass through the encoder. - - Args: - input_features: [batch_size, num_mel_bins, seq_len] - - Returns: - Object with .last_hidden_state attribute containing - [batch_size, seq_len', hidden_size] where seq_len' is - the sequence length after convolutions - """ - # Apply convolutional layers with GELU activation - hidden_states = torch.nn.functional.gelu(self.conv1(input_features)) - hidden_states = torch.nn.functional.gelu(self.conv2(hidden_states)) - - # Transpose to [batch_size, seq_len, hidden_size] - hidden_states = hidden_states.transpose(1, 2) - output_seq_len = hidden_states.shape[1] - - # Use pre-registered position_ids buffer (slice to actual seq_len) - position_ids = self.position_ids[:, :output_seq_len] - - # Get position embeddings - uses pre-computed cache - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - # Apply transformer layers - for encoder_layer in self.layers: - hidden_states = encoder_layer(hidden_states, position_embeddings) - - # Final layer norm - hidden_states = self.norm(hidden_states) - - # Return in a format compatible with transformers' BaseModelOutput - return _GlmAsrEncoderOutput(last_hidden_state=hidden_states) - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - """Custom weight loading to handle q_proj/k_proj/v_proj -> qkv_proj mapping.""" - from vllm.model_executor.model_loader.weight_utils import default_weight_loader - - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - - for name, loaded_weight in weights: - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Default weight loading for non-stacked params - if name.endswith(".bias") and name not in params_dict: - continue - if name not in params_dict: - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params From c3a81de3321b526e568c1d409b631749275d76a2 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 6 Jan 2026 13:49:02 +0800 Subject: [PATCH 07/24] get rid of audioflamingo3 dependency Signed-off-by: JaredforReal --- vllm/model_executor/models/glmasr.py | 116 ++++++--------------- vllm/model_executor/models/glmasr_utils.py | 68 ++++++++++++ 2 files changed, 100 insertions(+), 84 deletions(-) diff --git a/vllm/model_executor/models/glmasr.py b/vllm/model_executor/models/glmasr.py index b70fb9db1702..84cdf10b2296 100644 --- a/vllm/model_executor/models/glmasr.py +++ b/vllm/model_executor/models/glmasr.py @@ -39,6 +39,7 @@ ) from vllm.multimodal.processing import ( BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, PromptUpdate, PromptUpdateDetails, @@ -49,22 +50,17 @@ from vllm.transformers_utils.processor import cached_processor_from_config from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .audioflamingo3 import ( - AudioFlamingo3MultiModalDataParser, - AudioFlamingo3ProcessingInfo, -) -from .audioflamingo3 import ( - _audioflamingo3_field_config as _glmasr_field_config, -) from .glmasr_utils import ( DEFAULT_CONV_PARAMS, DEFAULT_MAX_AUDIO_LEN_S, DEFAULT_MERGE_FACTOR, + _apply_rotary_pos_emb, _flatten_audio_features_by_length, _get_audio_output_lengths_for_tower, _get_num_features_for_item, _group_audio_embeddings, _normalize_chunk_counts, + _repeat_kv, ) from .interfaces import ( MultiModalEmbeddings, @@ -78,56 +74,8 @@ logger = logging.getLogger(__name__) -# Optimized vLLM Native GlmAsrEncoder Implementation - - -def _rotate_half(x: torch.Tensor) -> torch.Tensor: - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def _apply_rotary_pos_emb( - q: torch.Tensor, - k: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Apply rotary position embeddings to query and key tensors. - - Follows transformers' apply_rotary_pos_emb exactly. - Supports partial rotary where only the first rotary_dim of head_dim is rotated. - - Args: - q: [batch, num_heads, seq_len, head_dim] - k: [batch, num_kv_heads, seq_len, head_dim] - cos: [batch, seq_len, rotary_dim] - sin: [batch, seq_len, rotary_dim] - """ - # unsqueeze_dim=1 to add head dimension: [batch, 1, seq_len, rotary_dim] - cos = cos.unsqueeze(1) - sin = sin.unsqueeze(1) - - # Get the rotary dimension from cos/sin - rotary_dim = cos.shape[-1] - - # Split into rotary and pass-through parts - q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] - k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] - - # Apply rotary embeddings on the first half or full tensor - q_embed = (q_rot * cos) + (_rotate_half(q_rot) * sin) - k_embed = (k_rot * cos) + (_rotate_half(k_rot) * sin) - - # Concatenate back to full shape - q_embed = torch.cat([q_embed, q_pass], dim=-1) - k_embed = torch.cat([k_embed, k_pass], dim=-1) - - return q_embed, k_embed - +# Optimized vLLM Native GlmAsrEncoder Implementation class GlmAsrRotaryEmbedding(nn.Module): """ Rotary Position Embedding for GLM-ASR encoder. @@ -227,27 +175,6 @@ def forward( return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - Repeat key/value tensors for Grouped Query Attention. - - Args: - hidden_states: [batch, num_kv_heads, seq_len, head_dim] - n_rep: Number of repetitions - - Returns: - [batch, num_kv_heads * n_rep, seq_len, head_dim] - """ - if n_rep == 1: - return hidden_states - - batch, num_kv_heads, slen, head_dim = hidden_states.shape - hidden_states = hidden_states[:, :, None, :, :].expand( - batch, num_kv_heads, n_rep, slen, head_dim - ) - return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim) - - class GlmAsrAttention(nn.Module): """ Optimized Multi-headed Grouped Query Attention for GLM-ASR. @@ -645,9 +572,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: return loaded_params -# ============================================================================= # GPU-accelerated Whisper Feature Extractor -# ============================================================================= class GPUWhisperFeatureExtractor: @@ -913,7 +838,7 @@ def forward(self, audio_features: torch.Tensor) -> torch.Tensor: return hidden_states -class GlmAsrProcessingInfo(AudioFlamingo3ProcessingInfo): +class GlmAsrProcessingInfo(BaseProcessingInfo): def get_hf_config(self) -> GlmAsrConfig: return self.ctx.get_hf_config(GlmAsrConfig) @@ -921,11 +846,13 @@ def get_hf_processor(self, **kwargs: object) -> GlmAsrProcessor: return self.ctx.get_hf_processor(GlmAsrProcessor, **kwargs) def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor: - # Reuse parent implementation, but add type annotation and assertion - feature_extractor = super().get_feature_extractor(**kwargs) - assert isinstance(feature_extractor, WhisperFeatureExtractor) + hf_processor = self.get_hf_processor(**kwargs) + feature_extractor = hf_processor.feature_extractor return feature_extractor + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + return {"audio": None} + class GlmAsrDummyInputsBuilder(BaseDummyInputsBuilder[GlmAsrProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: @@ -956,7 +883,28 @@ def get_dummy_mm_data( } -class GlmAsrMultiModalDataParser(AudioFlamingo3MultiModalDataParser): +def _glmasr_field_config(hf_inputs: Mapping[str, torch.Tensor]): + chunk_counts = hf_inputs.get("chunk_counts") + if chunk_counts is not None: + return dict( + audio_embeds=MultiModalFieldConfig.batched("audio"), + input_features=MultiModalFieldConfig.flat_from_sizes( + "audio", chunk_counts, dim=0 + ), + feature_attention_mask=MultiModalFieldConfig.flat_from_sizes( + "audio", chunk_counts, dim=0 + ), + chunk_counts=MultiModalFieldConfig.batched("audio"), + ) + return dict( + audio_embeds=MultiModalFieldConfig.batched("audio"), + input_features=MultiModalFieldConfig.batched("audio"), + feature_attention_mask=MultiModalFieldConfig.batched("audio"), + chunk_counts=MultiModalFieldConfig.batched("audio"), + ) + + +class GlmAsrMultiModalDataParser(MultiModalDataParser): def _parse_audio_data( self, data: dict[str, torch.Tensor] | ModalityData[Any], diff --git a/vllm/model_executor/models/glmasr_utils.py b/vllm/model_executor/models/glmasr_utils.py index a00aeaad3a3f..2a459e8420ef 100644 --- a/vllm/model_executor/models/glmasr_utils.py +++ b/vllm/model_executor/models/glmasr_utils.py @@ -17,6 +17,74 @@ DEFAULT_CONV_PARAMS = [(1, 3, 1), (1, 3, 2)] +def _rotate_half(x: torch.Tensor) -> torch.Tensor: + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def _apply_rotary_pos_emb( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary position embeddings to query and key tensors. + + Follows transformers' apply_rotary_pos_emb exactly. + Supports partial rotary where only the first rotary_dim of head_dim is rotated. + + Args: + q: [batch, num_heads, seq_len, head_dim] + k: [batch, num_kv_heads, seq_len, head_dim] + cos: [batch, seq_len, rotary_dim] + sin: [batch, seq_len, rotary_dim] + """ + # unsqueeze_dim=1 to add head dimension: [batch, 1, seq_len, rotary_dim] + cos = cos.unsqueeze(1) + sin = sin.unsqueeze(1) + + # Get the rotary dimension from cos/sin + rotary_dim = cos.shape[-1] + + # Split into rotary and pass-through parts + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + # Apply rotary embeddings on the first half or full tensor + q_embed = (q_rot * cos) + (_rotate_half(q_rot) * sin) + k_embed = (k_rot * cos) + (_rotate_half(k_rot) * sin) + + # Concatenate back to full shape + q_embed = torch.cat([q_embed, q_pass], dim=-1) + k_embed = torch.cat([k_embed, k_pass], dim=-1) + + return q_embed, k_embed + + +def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + Repeat key/value tensors for Grouped Query Attention. + + Args: + hidden_states: [batch, num_kv_heads, seq_len, head_dim] + n_rep: Number of repetitions + + Returns: + [batch, num_kv_heads * n_rep, seq_len, head_dim] + """ + if n_rep == 1: + return hidden_states + + batch, num_kv_heads, slen, head_dim = hidden_states.shape + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_kv_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim) + + def _calculate_conv_output_length( input_length: torch.Tensor, padding: int, kernel_size: int, stride: int ) -> torch.Tensor: From 3dd521756e79c3ee6f7a44a562944d13d25ae84c Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 6 Jan 2026 13:54:55 +0800 Subject: [PATCH 08/24] get rid of self-implemented GPUWhisperFeatureExtractor Signed-off-by: JaredforReal --- vllm/model_executor/models/glmasr.py | 245 +-------------------------- 1 file changed, 7 insertions(+), 238 deletions(-) diff --git a/vllm/model_executor/models/glmasr.py b/vllm/model_executor/models/glmasr.py index 84cdf10b2296..c9900a76185b 100644 --- a/vllm/model_executor/models/glmasr.py +++ b/vllm/model_executor/models/glmasr.py @@ -572,201 +572,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: return loaded_params -# GPU-accelerated Whisper Feature Extractor - - -class GPUWhisperFeatureExtractor: - """ - GPU-accelerated Whisper feature extractor using PyTorch. - Computes log-mel spectrogram matching WhisperFeatureExtractor output. - - This implementation reuses the mel filterbank from HuggingFace's - WhisperFeatureExtractor to ensure numerical precision (1e-5 tolerance). - The key optimization is caching the window and mel_filters tensors on GPU - to avoid repeated CPU->GPU transfers. - - Key parameters (Whisper defaults): - - n_fft: 400 (25ms window at 16kHz) - - hop_length: 160 (10ms hop at 16kHz) - - n_mels: 80 - - chunk_length: 30 seconds - - sampling_rate: 16000 - """ - - def __init__( - self, - hf_feature_extractor: WhisperFeatureExtractor, - device: str | torch.device = "cuda", - ): - # Copy parameters from HF feature extractor - self.feature_size = hf_feature_extractor.feature_size - self.sampling_rate = hf_feature_extractor.sampling_rate - self.hop_length = hf_feature_extractor.hop_length - self.chunk_length = hf_feature_extractor.chunk_length - self.n_fft = hf_feature_extractor.n_fft - self.padding_value = hf_feature_extractor.padding_value - self.device = torch.device(device) if isinstance(device, str) else device - - # Derived parameters - self.n_samples = self.chunk_length * self.sampling_rate # 480000 for 30s - self.nb_max_frames = self.n_samples // self.hop_length # 3000 frames - - # Store HF's mel_filters (numpy float64) for precision - # This is precomputed by HF using librosa-compatible mel_filter_bank - self._mel_filters_np: np.ndarray = hf_feature_extractor.mel_filters - - # Cached GPU tensors (lazily initialized) - self._window: torch.Tensor | None = None - self._mel_filters_T: torch.Tensor | None = None # Transposed & contiguous - self._current_device_str: str | None = None - - def _ensure_buffers(self, device: torch.device) -> None: - """Lazily initialize and cache buffers on the target device.""" - # Use string comparison for stable device checking - device_str = str(device) - if self._current_device_str == device_str: - return - - self._window = torch.hann_window(self.n_fft, device=device) - # Convert from numpy float64, transpose, and make contiguous for optimal matmul - # HF mel_filters is [n_mels, n_freqs], we need [n_freqs, n_mels] for matmul - # Using .contiguous() ensures optimal memory layout for GPU matmul - self._mel_filters_T = ( - torch.from_numpy(self._mel_filters_np) - .to(device=device, dtype=torch.float32) - .T.contiguous() - ) - self._current_device_str = device_str - - def __call__( - self, - raw_speech: list[np.ndarray] | np.ndarray | torch.Tensor, - sampling_rate: int | None = None, - padding: str = "max_length", - max_length: int | None = None, - return_attention_mask: bool = True, - return_tensors: str = "pt", - device: str | torch.device | None = None, - ) -> BatchFeature: - """ - Extract log-mel spectrogram features from audio. - - Args: - raw_speech: Audio waveform(s), can be list of arrays or batched - sampling_rate: Expected sample rate (must match self.sampling_rate) - padding: Padding strategy ('max_length' or 'longest') - max_length: Max samples (default: self.n_samples = 30s * 16kHz) - return_attention_mask: Whether to return attention mask - return_tensors: Output format ('pt' for PyTorch) - device: Device for computation (default: self.device) - - Returns: - BatchFeature with 'input_features' and optionally 'attention_mask' - """ - if sampling_rate is not None and sampling_rate != self.sampling_rate: - raise ValueError( - f"Expected sampling_rate={self.sampling_rate}, got {sampling_rate}" - ) - - device = torch.device(device) if device else self.device - max_length = max_length or self.n_samples - - # Convert inputs to list of 1D tensors - if isinstance(raw_speech, np.ndarray): - raw_speech = [raw_speech] if raw_speech.ndim == 1 else list(raw_speech) - elif isinstance(raw_speech, torch.Tensor): - raw_speech = ( - [raw_speech.numpy()] - if raw_speech.ndim == 1 - else [s.numpy() for s in raw_speech] - ) - - batch_size = len(raw_speech) - - # Get actual lengths before padding - lengths = [len(s) for s in raw_speech] - - # Pad/truncate to max_length - if padding == "max_length": - target_length = max_length - else: # 'longest' - target_length = min(max(lengths), max_length) - - # Create padded batch tensor - padded_waveforms = torch.zeros( - batch_size, target_length, dtype=torch.float32, device=device - ) - attention_mask = torch.zeros( - batch_size, target_length, dtype=torch.int32, device=device - ) - - for i, waveform in enumerate(raw_speech): - if isinstance(waveform, np.ndarray): - waveform = torch.from_numpy(waveform) - waveform = waveform.to(device=device, dtype=torch.float32) - - # Truncate if needed - actual_len = min(len(waveform), target_length) - padded_waveforms[i, :actual_len] = waveform[:actual_len] - attention_mask[i, :actual_len] = 1 - - # Extract features on GPU - input_features = self._extract_fbank_features(padded_waveforms) - - # Rescale attention mask from samples to frames - # STFT produces L//hop_length + 1 frames, but we drop the last one - frame_attention_mask = attention_mask[:, :: self.hop_length] - # Trim to match actual frame count (we drop last frame in _extract) - if attention_mask.shape[1] % self.hop_length != 0: - frame_attention_mask = frame_attention_mask[:, :-1] - - result: dict[str, Any] = {"input_features": input_features} - if return_attention_mask: - result["attention_mask"] = frame_attention_mask - - return BatchFeature(data=result, tensor_type=return_tensors) - - def _extract_fbank_features(self, waveforms: torch.Tensor) -> torch.Tensor: - """ - Compute log-mel spectrogram for batched waveforms. - - Args: - waveforms: [batch, samples] float32 tensor on target device - - Returns: - [batch, n_mels, frames] float32 tensor (log-mel spectrogram) - """ - device = waveforms.device - self._ensure_buffers(device) - - # STFT: [batch, samples] -> [batch, n_fft//2+1, frames] complex - stft = torch.stft( - waveforms, - n_fft=self.n_fft, - hop_length=self.hop_length, - window=self._window, - return_complex=True, - ) - - # Power spectrogram, drop last frame (matching HF implementation) - magnitudes = stft[..., :-1].abs() ** 2 # [batch, n_freqs, frames] - - # Apply mel filterbank: [n_freqs, n_mels] @ [batch, n_freqs, frames] - # -> [batch, n_mels, frames] - # _mel_filters_T is pre-transposed and contiguous for optimal performance - mel_spec = torch.matmul(self._mel_filters_T, magnitudes) - - # Log scale with floor - log_spec = torch.clamp(mel_spec, min=1e-10).log10() - - # Per-sample normalization (max - 8.0 floor, then scale) - max_val = log_spec.amax(dim=(1, 2), keepdim=True) - log_spec = torch.maximum(log_spec, max_val - 8.0) - log_spec = (log_spec + 4.0) / 4.0 - - return log_spec - - class GlmAsrFeatureInputs(TensorSchema): """ Dimensions: @@ -923,26 +728,8 @@ class GlmAsrMultiModalProcessor(BaseMultiModalProcessor["GlmAsrProcessingInfo"]) """ GLM-ASR processor that inherits directly from BaseMultiModalProcessor for better performance and cleaner implementation. - Uses GPU-accelerated feature extraction for improved throughput. """ - # Shared GPU feature extractor instance (lazy initialized) - _gpu_feature_extractor: GPUWhisperFeatureExtractor | None = None - - @classmethod - def _get_gpu_feature_extractor( - cls, - hf_feature_extractor: WhisperFeatureExtractor, - device: str = "cuda", - ) -> GPUWhisperFeatureExtractor: - """Get or create GPU feature extractor matching HF config.""" - if cls._gpu_feature_extractor is None: - cls._gpu_feature_extractor = GPUWhisperFeatureExtractor( - hf_feature_extractor=hf_feature_extractor, - device=device, - ) - return cls._gpu_feature_extractor - def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self.info.get_feature_extractor() return GlmAsrMultiModalDataParser(target_sr=feature_extractor.sampling_rate) @@ -1026,31 +813,13 @@ def _call_hf_processor( end = min((i + 1) * window_size, time_cap) flat_chunks.append(audio_el[start:end]) - # ===== GPU Feature Extraction ===== - # Check if CUDA is available, fallback to CPU if not - use_gpu = torch.cuda.is_available() - device = "cuda" if use_gpu else "cpu" - - if use_gpu: - # Use GPU-accelerated feature extractor - gpu_extractor = self._get_gpu_feature_extractor( - hf_feature_extractor, device=device - ) - audio_inputs = gpu_extractor( - flat_chunks, - sampling_rate=sampling_rate, - return_attention_mask=True, - return_tensors="pt", - ) - else: - # Fallback to HF CPU implementation - audio_inputs = hf_feature_extractor( - flat_chunks, - sampling_rate=sampling_rate, - return_tensors="pt", - padding=True, - return_attention_mask=True, - ) + audio_inputs = hf_feature_extractor( + flat_chunks, + sampling_rate=sampling_rate, + return_tensors="pt", + padding=True, + return_attention_mask=True, + ) # ===== Process attention mask ===== padding_mask = audio_inputs.pop("attention_mask") From e90a9fb659e715b2fd38de290591ee11af135ea8 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 6 Jan 2026 14:49:53 +0800 Subject: [PATCH 09/24] remove logger and GPU-related comment Signed-off-by: JaredforReal --- vllm/model_executor/models/glmasr.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/glmasr.py b/vllm/model_executor/models/glmasr.py index c9900a76185b..78077b044248 100644 --- a/vllm/model_executor/models/glmasr.py +++ b/vllm/model_executor/models/glmasr.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import logging from collections.abc import Iterable, Mapping, Sequence from typing import Annotated, Any, Literal, TypeAlias, cast @@ -72,8 +71,6 @@ from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix from .whisper import ISO639_1_SUPPORTED_LANGS -logger = logging.getLogger(__name__) - # Optimized vLLM Native GlmAsrEncoder Implementation class GlmAsrRotaryEmbedding(nn.Module): @@ -762,9 +759,6 @@ def _call_hf_processor( mm_kwargs: Mapping[str, Any], tok_kwargs: Mapping[str, object], ) -> BatchFeature: - """ - Call processor with GPU-accelerated feature extraction. - """ # Normalize input: handle deprecated key and list conversion. if "audios" in mm_data: mm_data["audio"] = mm_data.pop("audios") @@ -778,12 +772,12 @@ def _call_hf_processor( prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") - # Get processor for tokenizer and config + # ===== Initialize HF processor, feature extractor, tokenizer ===== processor = self.info.get_hf_processor(**mm_kwargs) hf_feature_extractor = processor.feature_extractor tokenizer = processor.tokenizer - # ===== Audio chunking (CPU, fast) ===== + # ===== Calculate chunk counts and prepare audio chunks ===== sampling_rate = hf_feature_extractor.sampling_rate chunk_length = hf_feature_extractor.chunk_length max_audio_len = getattr(processor, "max_audio_len", DEFAULT_MAX_AUDIO_LEN_S) @@ -813,6 +807,7 @@ def _call_hf_processor( end = min((i + 1) * window_size, time_cap) flat_chunks.append(audio_el[start:end]) + # ===== Extract audio features ===== audio_inputs = hf_feature_extractor( flat_chunks, sampling_rate=sampling_rate, From 063b4b5e7276e28574f8bca8e8a4f66c468bc5b9 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 6 Jan 2026 15:19:11 +0800 Subject: [PATCH 10/24] go back to original implmentation of GlmAsrMultiModalProcessor._call_hf_processor Signed-off-by: JaredforReal --- vllm/model_executor/models/glmasr.py | 114 +++------------------------ 1 file changed, 11 insertions(+), 103 deletions(-) diff --git a/vllm/model_executor/models/glmasr.py b/vllm/model_executor/models/glmasr.py index 78077b044248..108eefc51765 100644 --- a/vllm/model_executor/models/glmasr.py +++ b/vllm/model_executor/models/glmasr.py @@ -751,7 +751,6 @@ def _calculate_chunk_counts( chunk_counts.append(min(n_chunks, max_windows)) return chunk_counts - # @torch.compile(fullgraph=True) def _call_hf_processor( self, prompt: str, @@ -772,114 +771,23 @@ def _call_hf_processor( prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") - # ===== Initialize HF processor, feature extractor, tokenizer ===== + # Get processor for chunk counts calculation processor = self.info.get_hf_processor(**mm_kwargs) - hf_feature_extractor = processor.feature_extractor - tokenizer = processor.tokenizer - # ===== Calculate chunk counts and prepare audio chunks ===== - sampling_rate = hf_feature_extractor.sampling_rate - chunk_length = hf_feature_extractor.chunk_length - max_audio_len = getattr(processor, "max_audio_len", DEFAULT_MAX_AUDIO_LEN_S) - window_size = int(sampling_rate * chunk_length) - max_windows = int(max_audio_len // chunk_length) - - per_sample_windows: list[int] = [] - flat_chunks: list[np.ndarray] = [] - - for audio_el in audio_list: - # Convert to numpy if needed - if isinstance(audio_el, torch.Tensor): - audio_el = audio_el.numpy() - elif isinstance(audio_el, list): - audio_el = np.array(audio_el, dtype=np.float32) - - n_samples = int(audio_el.shape[0]) - n_win = max(1, (n_samples + window_size - 1) // window_size) - if n_win > max_windows: - n_win = max_windows - - per_sample_windows.append(n_win) - time_cap = min(n_samples, n_win * window_size) - - for i in range(n_win): - start = i * window_size - end = min((i + 1) * window_size, time_cap) - flat_chunks.append(audio_el[start:end]) - - # ===== Extract audio features ===== - audio_inputs = hf_feature_extractor( - flat_chunks, - sampling_rate=sampling_rate, - return_tensors="pt", - padding=True, - return_attention_mask=True, + # Call parent method + outputs = super()._call_hf_processor( + prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, tok_kwargs=tok_kwargs ) - # ===== Process attention mask ===== - padding_mask = audio_inputs.pop("attention_mask") - input_features_mask = padding_mask + # Postprocess: rename mask and add chunk counts + if "input_feature_mask" in outputs: + outputs["feature_attention_mask"] = outputs.pop("input_feature_mask") - # ===== Compute audio token lengths ===== - chunk_lengths = padding_mask.sum(-1) # [num_chunks] - audio_lengths = torch.stack( - [ - chunk_lengths[ - sum(per_sample_windows[:i]) : sum(per_sample_windows[: i + 1]) - ].sum() - for i in range(len(per_sample_windows)) - ] + # Override chunk counts calculation with GLM-ASR specific logic + chunk_counts = self._calculate_chunk_counts( + audio_list, processor.feature_extractor, processor ) - - # Apply convolution formula to get token counts - merge_factor = 4 - for padding, kernel_size, stride in [(1, 3, 1), (1, 3, 2)]: - audio_lengths = ( - audio_lengths + 2 * padding - (kernel_size - 1) - 1 - ) // stride + 1 - audio_tokens_lengths = (audio_lengths - merge_factor) // merge_factor + 1 - - # ===== Expand audio tokens in text ===== - import regex as re - - audio_token = getattr(processor, "audio_token", "<|pad|>") - text_list = [prompt] - - for i, audio_length in enumerate(audio_tokens_lengths): - if i < len(text_list): - expanded = re.sub( - re.escape(audio_token), - audio_token * int(audio_length), - text_list[i], - ) - text_list[i] = expanded - - # ===== Tokenize text ===== - text_inputs = tokenizer( - text_list, - return_tensors="pt", - padding=True, - **tok_kwargs, - ) - - # ===== Combine outputs ===== - # Move input_features to CPU for compatibility - input_features = audio_inputs["input_features"] - if input_features.device.type != "cpu": - input_features = input_features.cpu() - if input_features_mask.device.type != "cpu": - input_features_mask = input_features_mask.cpu() - - outputs = BatchFeature( - data={ - **text_inputs, - "input_features": input_features, - "feature_attention_mask": input_features_mask, - }, - tensor_type="pt", - ) - - outputs["chunk_counts"] = torch.tensor(per_sample_windows, dtype=torch.long) + outputs["chunk_counts"] = torch.tensor(chunk_counts, dtype=torch.long) return outputs From d1ea079e91ec91a41691a6a8cc0c7f0d778eb535 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 6 Jan 2026 15:30:58 +0800 Subject: [PATCH 11/24] handle sampling_rate Signed-off-by: JaredforReal --- vllm/model_executor/models/glmasr.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/glmasr.py b/vllm/model_executor/models/glmasr.py index 108eefc51765..d465bcfbcc5d 100644 --- a/vllm/model_executor/models/glmasr.py +++ b/vllm/model_executor/models/glmasr.py @@ -771,18 +771,28 @@ def _call_hf_processor( prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") - # Get processor for chunk counts calculation - processor = self.info.get_hf_processor(**mm_kwargs) + # Handle sampling_rate + feature_extractor = self.info.get_feature_extractor(**mm_kwargs) + mm_kwargs = dict( + **mm_kwargs, + sampling_rate=feature_extractor.sampling_rate, + ) # Call parent method outputs = super()._call_hf_processor( - prompt=prompt, mm_data=mm_data, mm_kwargs=mm_kwargs, tok_kwargs=tok_kwargs + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, ) # Postprocess: rename mask and add chunk counts if "input_feature_mask" in outputs: outputs["feature_attention_mask"] = outputs.pop("input_feature_mask") + # Get processor for chunk counts calculation + processor = self.info.get_hf_processor(**mm_kwargs) + # Override chunk counts calculation with GLM-ASR specific logic chunk_counts = self._calculate_chunk_counts( audio_list, processor.feature_extractor, processor From ae68aad4dd719631e34624d369173c6f5b09f249 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 6 Jan 2026 15:32:58 +0800 Subject: [PATCH 12/24] delete logger in utils Signed-off-by: JaredforReal --- vllm/model_executor/models/glmasr_utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/model_executor/models/glmasr_utils.py b/vllm/model_executor/models/glmasr_utils.py index 2a459e8420ef..05b358f0c879 100644 --- a/vllm/model_executor/models/glmasr_utils.py +++ b/vllm/model_executor/models/glmasr_utils.py @@ -1,15 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import logging from collections.abc import Sequence from typing import cast import torch import torch.nn as nn -logger = logging.getLogger(__name__) - DEFAULT_MAX_AUDIO_LEN_S = 655 DEFAULT_MERGE_FACTOR = 4 # Default convolution parameters: (padding, kernel_size, stride) From a73cf85b8ec6f8db9a2f60e85f2f6601e331bcfe Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 6 Jan 2026 16:51:30 +0800 Subject: [PATCH 13/24] try use vllm.applyrotaryemb Signed-off-by: JaredforReal --- vllm/model_executor/models/glmasr.py | 118 ++++++++++----------- vllm/model_executor/models/glmasr_utils.py | 47 -------- 2 files changed, 56 insertions(+), 109 deletions(-) diff --git a/vllm/model_executor/models/glmasr.py b/vllm/model_executor/models/glmasr.py index d465bcfbcc5d..42155511bdd7 100644 --- a/vllm/model_executor/models/glmasr.py +++ b/vllm/model_executor/models/glmasr.py @@ -22,6 +22,7 @@ RowParallelLinear, ) from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import ApplyRotaryEmb from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( @@ -53,7 +54,6 @@ DEFAULT_CONV_PARAMS, DEFAULT_MAX_AUDIO_LEN_S, DEFAULT_MERGE_FACTOR, - _apply_rotary_pos_emb, _flatten_audio_features_by_length, _get_audio_output_lengths_for_tower, _get_num_features_for_item, @@ -77,8 +77,8 @@ class GlmAsrRotaryEmbedding(nn.Module): """ Rotary Position Embedding for GLM-ASR encoder. - Optimized with pre-computed cos/sin cache for better performance. - Falls back to dynamic computation only when sequence length exceeds cache. + Pre-computes cos/sin cache and uses vLLM's ApplyRotaryEmb CustomOp + for efficient rotary embedding application. """ def __init__(self, config, device: torch.device | None = None): @@ -107,6 +107,7 @@ def __init__(self, config, device: torch.device | None = None): self.attention_scaling = 1.0 self.dim = dim + self.head_dim = head_dim self.base = base # Compute the inverse frequencies exactly as transformers does @@ -124,6 +125,13 @@ def __init__(self, config, device: torch.device | None = None): # Pre-compute cos/sin cache for efficiency self._set_cos_sin_cache(self.max_seq_len_cached, device) + # Use vLLM's ApplyRotaryEmb CustomOp for efficient rotary embedding + # enforce_enable=True ensures the op is always enabled (important for ViT) + self.apply_rotary_emb = ApplyRotaryEmb( + enforce_enable=True, + is_neox_style=True, + ) + def _set_cos_sin_cache( self, seq_len: int, device: torch.device | None = None ) -> None: @@ -134,48 +142,36 @@ def _set_cos_sin_cache( t = torch.arange(seq_len, device=device, dtype=torch.float32) # Compute frequencies: [seq_len, dim/2] freqs = torch.outer(t, self.inv_freq.to(device=device, dtype=torch.float32)) - # Double the frequencies: [seq_len, dim] - emb = torch.cat((freqs, freqs), dim=-1) - # Compute and cache cos/sin - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * self.attention_scaling + # Compute and cache cos/sin (shape: [seq_len, dim/2]) + # ApplyRotaryEmb expects cos/sin with shape [seq_len, rotary_dim/2] + cos = freqs.cos() * self.attention_scaling + sin = freqs.sin() * self.attention_scaling self.register_buffer("cos_cached", cos, persistent=False) self.register_buffer("sin_cached", sin, persistent=False) - def forward( - self, x: torch.Tensor, position_ids: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: + def get_cos_sin(self, seq_len: int) -> tuple[torch.Tensor, torch.Tensor]: """ - Compute rotary embeddings with caching optimization. + Get cos and sin tensors for a given sequence length. Args: - x: Input tensor [batch_size, seq_len, hidden_size] - position_ids: Position indices [batch_size, seq_len] + seq_len: The sequence length to get embeddings for. Returns: - Tuple of (cos, sin) tensors with shape [batch_size, seq_len, rotary_dim] + Tuple of (cos, sin) tensors with shape [seq_len, dim/2] """ - seq_len = position_ids.shape[-1] - # Extend cache if needed if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len, device=x.device) - - # Use cached values - index with position_ids for correctness - # For encoder, position_ids is typically [0, 1, 2, ..., seq_len-1] - # so we can directly slice the cache - cos = self.cos_cached[:seq_len].unsqueeze(0) # [1, seq_len, dim] - sin = self.sin_cached[:seq_len].unsqueeze(0) # [1, seq_len, dim] + self._set_cos_sin_cache(seq_len, device=self.cos_cached.device) - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + return self.cos_cached[:seq_len], self.sin_cached[:seq_len] class GlmAsrAttention(nn.Module): """ Optimized Multi-headed Grouped Query Attention for GLM-ASR. - Uses vLLM's QKVParallelLinear for better performance. + Uses vLLM's QKVParallelLinear and ApplyRotaryEmb for better performance. """ def __init__( @@ -219,15 +215,21 @@ def __init__( prefix=f"{prefix}.o_proj", ) + # Use vLLM's ApplyRotaryEmb CustomOp + # enforce_enable=True ensures the op is always enabled (important for ViT) + self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True) + def forward( self, hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], + cos: torch.Tensor, + sin: torch.Tensor, ) -> torch.Tensor: """ Args: hidden_states: [batch_size, seq_len, hidden_size] - position_embeddings: Tuple of (cos, sin) for RoPE + cos: [seq_len, rotary_dim/2] - cosine part of rotary embeddings + sin: [seq_len, rotary_dim/2] - sine part of rotary embeddings Returns: [batch_size, seq_len, hidden_size] @@ -242,24 +244,21 @@ def forward( kv_size = self.num_kv_heads_per_rank * self.head_dim q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1) - # Reshape and transpose - # [batch, seq, num_heads * head_dim] -> [batch, num_heads, seq, head_dim] - q = q.view( - batch_size, seq_len, self.num_heads_per_rank, self.head_dim - ).transpose(1, 2) - k = k.view( - batch_size, seq_len, self.num_kv_heads_per_rank, self.head_dim - ).transpose(1, 2) - # v doesn't go through RoPE, so make it contiguous now for SDPA - v = ( - v.view(batch_size, seq_len, self.num_kv_heads_per_rank, self.head_dim) - .transpose(1, 2) - .contiguous() - ) + # Reshape to [batch, seq, num_heads, head_dim] for ApplyRotaryEmb + q = q.view(batch_size, seq_len, self.num_heads_per_rank, self.head_dim) + k = k.view(batch_size, seq_len, self.num_kv_heads_per_rank, self.head_dim) + v = v.view(batch_size, seq_len, self.num_kv_heads_per_rank, self.head_dim) + + # Apply rotary position embeddings using vLLM's ApplyRotaryEmb + # ApplyRotaryEmb expects x: [batch, seq, heads, head_dim] + # cos/sin: [seq_len, rotary_dim/2] + q = self.apply_rotary_emb(q, cos, sin) + k = self.apply_rotary_emb(k, cos, sin) - # Apply rotary position embeddings - cos, sin = position_embeddings - q, k = _apply_rotary_pos_emb(q, k, cos, sin) + # Transpose to [batch, num_heads, seq, head_dim] for attention + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) # Handle GQA: repeat k/v if needed if self.num_kv_groups > 1: @@ -368,12 +367,14 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], + cos: torch.Tensor, + sin: torch.Tensor, ) -> torch.Tensor: """ Args: hidden_states: [batch_size, seq_len, hidden_size] - position_embeddings: Tuple of (cos, sin) for RoPE + cos: [seq_len, rotary_dim/2] - cosine part of rotary embeddings + sin: [seq_len, rotary_dim/2] - sine part of rotary embeddings Returns: [batch_size, seq_len, hidden_size] @@ -383,7 +384,8 @@ def forward( hidden_states = self.input_layernorm(hidden_states) hidden_states = self.self_attn( hidden_states=hidden_states, - position_embeddings=position_embeddings, + cos=cos, + sin=sin, ) hidden_states = residual + hidden_states @@ -466,14 +468,6 @@ def __init__( # Rotary position embeddings self.rotary_emb = GlmAsrRotaryEmbedding(config) - # Pre-register position_ids buffer for efficiency - # This avoids creating a new tensor on every forward pass - self.register_buffer( - "position_ids", - torch.arange(config.max_position_embeddings, dtype=torch.long).unsqueeze(0), - persistent=False, - ) - def _get_feat_extract_output_lengths( self, input_lengths: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: @@ -514,15 +508,15 @@ def forward(self, input_features: torch.Tensor): hidden_states = hidden_states.transpose(1, 2) output_seq_len = hidden_states.shape[1] - # Use pre-registered position_ids buffer (slice to actual seq_len) - position_ids = self.position_ids[:, :output_seq_len] - - # Get position embeddings - uses pre-computed cache - position_embeddings = self.rotary_emb(hidden_states, position_ids) + # Get cos/sin from rotary embedding cache + cos, sin = self.rotary_emb.get_cos_sin(output_seq_len) + # Match dtype with hidden states + cos = cos.to(dtype=hidden_states.dtype) + sin = sin.to(dtype=hidden_states.dtype) # Apply transformer layers for encoder_layer in self.layers: - hidden_states = encoder_layer(hidden_states, position_embeddings) + hidden_states = encoder_layer(hidden_states, cos, sin) # Final layer norm hidden_states = self.norm(hidden_states) diff --git a/vllm/model_executor/models/glmasr_utils.py b/vllm/model_executor/models/glmasr_utils.py index 05b358f0c879..098540878afe 100644 --- a/vllm/model_executor/models/glmasr_utils.py +++ b/vllm/model_executor/models/glmasr_utils.py @@ -14,53 +14,6 @@ DEFAULT_CONV_PARAMS = [(1, 3, 1), (1, 3, 2)] -def _rotate_half(x: torch.Tensor) -> torch.Tensor: - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def _apply_rotary_pos_emb( - q: torch.Tensor, - k: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Apply rotary position embeddings to query and key tensors. - - Follows transformers' apply_rotary_pos_emb exactly. - Supports partial rotary where only the first rotary_dim of head_dim is rotated. - - Args: - q: [batch, num_heads, seq_len, head_dim] - k: [batch, num_kv_heads, seq_len, head_dim] - cos: [batch, seq_len, rotary_dim] - sin: [batch, seq_len, rotary_dim] - """ - # unsqueeze_dim=1 to add head dimension: [batch, 1, seq_len, rotary_dim] - cos = cos.unsqueeze(1) - sin = sin.unsqueeze(1) - - # Get the rotary dimension from cos/sin - rotary_dim = cos.shape[-1] - - # Split into rotary and pass-through parts - q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] - k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] - - # Apply rotary embeddings on the first half or full tensor - q_embed = (q_rot * cos) + (_rotate_half(q_rot) * sin) - k_embed = (k_rot * cos) + (_rotate_half(k_rot) * sin) - - # Concatenate back to full shape - q_embed = torch.cat([q_embed, q_pass], dim=-1) - k_embed = torch.cat([k_embed, k_pass], dim=-1) - - return q_embed, k_embed - - def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ Repeat key/value tensors for Grouped Query Attention. From 4c2b5a6c768ba1f25b25567464db6ce99ade2a58 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 6 Jan 2026 18:47:12 +0800 Subject: [PATCH 14/24] fix ApplyRotaryEmb import error Signed-off-by: JaredforReal --- vllm/model_executor/models/glmasr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/glmasr.py b/vllm/model_executor/models/glmasr.py index 42155511bdd7..7ae97d11e439 100644 --- a/vllm/model_executor/models/glmasr.py +++ b/vllm/model_executor/models/glmasr.py @@ -22,7 +22,7 @@ RowParallelLinear, ) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.rotary_embedding import ApplyRotaryEmb +from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( From d5d930e54406fcbc02e0db18029b9202a0dae03e Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 6 Jan 2026 18:59:35 +0800 Subject: [PATCH 15/24] fix Signed-off-by: JaredforReal --- vllm/model_executor/models/glmasr.py | 76 ++++++++++++++++++++++++---- 1 file changed, 65 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/models/glmasr.py b/vllm/model_executor/models/glmasr.py index 7ae97d11e439..f9ea6f01c99e 100644 --- a/vllm/model_executor/models/glmasr.py +++ b/vllm/model_executor/models/glmasr.py @@ -56,7 +56,6 @@ DEFAULT_MERGE_FACTOR, _flatten_audio_features_by_length, _get_audio_output_lengths_for_tower, - _get_num_features_for_item, _group_audio_embeddings, _normalize_chunk_counts, _repeat_kv, @@ -781,8 +780,20 @@ def _call_hf_processor( ) # Postprocess: rename mask and add chunk counts + # Handle different key names from different transformers versions if "input_feature_mask" in outputs: outputs["feature_attention_mask"] = outputs.pop("input_feature_mask") + elif "feature_attention_mask" not in outputs and "input_features" in outputs: + # If no mask is provided, create one from input_features + input_features = outputs["input_features"] + if isinstance(input_features, torch.Tensor): + # Create a mask of all ones matching the sequence length + mask = torch.ones( + input_features.shape[0], + input_features.shape[-1], + dtype=torch.long, + ) + outputs["feature_attention_mask"] = mask # Get processor for chunk counts calculation processor = self.info.get_hf_processor(**mm_kwargs) @@ -819,22 +830,65 @@ def _get_prompt_updates( audio_token_id = processor.audio_token_id merge_factor = getattr(config, "merge_factor", DEFAULT_MERGE_FACTOR) + conv_params = getattr(config, "conv_params", DEFAULT_CONV_PARAMS) out_mm_data = out_mm_kwargs.get_data() feature_attention_mask = out_mm_data.get("feature_attention_mask") chunk_counts = out_mm_data.get("chunk_counts") - def get_replacement_glmasr(item_idx: int): - conv_params = getattr(config, "conv_params", DEFAULT_CONV_PARAMS) - audio_embeds = out_mm_data.get("audio_embeds") - num_features = _get_num_features_for_item( - feature_attention_mask, - chunk_counts, - item_idx, - audio_embeds, - merge_factor, - conv_params, + # Pre-compute audio output lengths if feature_attention_mask is available + audio_output_lengths: list[int] = [] + if feature_attention_mask is not None: + # Compute output lengths for all audio items + from .glmasr_utils import ( + _as_list_chunk_counts, + _get_audio_output_lengths_from_mask, ) + if chunk_counts is not None: + counts_list = _as_list_chunk_counts(chunk_counts) + start_idx = 0 + for count in counts_list: + end_idx = start_idx + count + if isinstance(feature_attention_mask, torch.Tensor): + mask = feature_attention_mask[start_idx:end_idx] + else: + mask = feature_attention_mask[start_idx:end_idx] + if isinstance(mask, list): + mask = torch.stack(mask) + lengths = _get_audio_output_lengths_from_mask( + mask, merge_factor, conv_params + ) + audio_output_lengths.append(int(lengths.sum().item())) + start_idx = end_idx + else: + # Single chunk per audio + for idx in range(len(feature_attention_mask)): + if isinstance(feature_attention_mask, torch.Tensor): + mask = feature_attention_mask[idx : idx + 1] + else: + mask = feature_attention_mask[idx] + if not isinstance(mask, torch.Tensor): + mask = torch.tensor(mask) + mask = mask.unsqueeze(0) + lengths = _get_audio_output_lengths_from_mask( + mask, merge_factor, conv_params + ) + audio_output_lengths.append(int(lengths.sum().item())) + + def get_replacement_glmasr(item_idx: int): + # Use pre-computed lengths if available, otherwise fall back to audio_embeds + if audio_output_lengths: + num_features = audio_output_lengths[item_idx] + else: + audio_embeds = out_mm_data.get("audio_embeds") + if audio_embeds is not None: + embed = audio_embeds[item_idx] + num_features = embed.shape[0] + else: + raise ValueError( + "Either feature_attention_mask or audio_embeds must be provided" + ) + if num_features == 0: raise ValueError("Audio is too short") From 4eeb067c83e4a7f4990e530f7ef407df1ac7c211 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 6 Jan 2026 19:34:13 +0800 Subject: [PATCH 16/24] fix ci error Signed-off-by: JaredforReal --- vllm/model_executor/models/glmasr.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/glmasr.py b/vllm/model_executor/models/glmasr.py index f9ea6f01c99e..6928e39216d3 100644 --- a/vllm/model_executor/models/glmasr.py +++ b/vllm/model_executor/models/glmasr.py @@ -495,9 +495,9 @@ def forward(self, input_features: torch.Tensor): input_features: [batch_size, num_mel_bins, seq_len] Returns: - Object with .last_hidden_state attribute containing - [batch_size, seq_len', hidden_size] where seq_len' is - the sequence length after convolutions + _GlmAsrEncoderOutput: Object with .last_hidden_state attribute \ + containing [batch_size, seq_len', hidden_size] where seq_len' \ + is the sequence length after convolutions """ # Apply convolutional layers with GELU activation hidden_states = torch.nn.functional.gelu(self.conv1(input_features)) From 2a6b538acde362377557741959393c1dfeb5222b Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 6 Jan 2026 20:00:30 +0800 Subject: [PATCH 17/24] rewrite RotaryEmbedding and add some docstring for readability Signed-off-by: JaredforReal --- vllm/model_executor/models/glmasr.py | 146 +++++++++++++++------------ 1 file changed, 84 insertions(+), 62 deletions(-) diff --git a/vllm/model_executor/models/glmasr.py b/vllm/model_executor/models/glmasr.py index 6928e39216d3..8d266987e93f 100644 --- a/vllm/model_executor/models/glmasr.py +++ b/vllm/model_executor/models/glmasr.py @@ -76,14 +76,14 @@ class GlmAsrRotaryEmbedding(nn.Module): """ Rotary Position Embedding for GLM-ASR encoder. - Pre-computes cos/sin cache and uses vLLM's ApplyRotaryEmb CustomOp - for efficient rotary embedding application. + Computes rotary position embeddings on-demand for efficiency. + Only caches inv_freq as a buffer; cos/sin are computed during forward + to avoid wasted computation during initialization and ensure correct + device placement. """ - def __init__(self, config, device: torch.device | None = None): + def __init__(self, config) -> None: super().__init__() - self.config = config - self.max_seq_len_cached = config.max_position_embeddings # Compute inverse frequencies following transformers implementation head_dim = getattr( @@ -107,64 +107,28 @@ def __init__(self, config, device: torch.device | None = None): self.dim = dim self.head_dim = head_dim - self.base = base - - # Compute the inverse frequencies exactly as transformers does - inv_freq = 1.0 / ( - base - ** ( - torch.arange(0, dim, 2, dtype=torch.int64).to( - device=device, dtype=torch.float - ) - / dim - ) - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Pre-compute cos/sin cache for efficiency - self._set_cos_sin_cache(self.max_seq_len_cached, device) - - # Use vLLM's ApplyRotaryEmb CustomOp for efficient rotary embedding - # enforce_enable=True ensures the op is always enabled (important for ViT) - self.apply_rotary_emb = ApplyRotaryEmb( - enforce_enable=True, - is_neox_style=True, - ) - - def _set_cos_sin_cache( - self, seq_len: int, device: torch.device | None = None - ) -> None: - """Pre-compute cos and sin cache for given sequence length.""" - self.max_seq_len_cached = seq_len - - # Create position indices - t = torch.arange(seq_len, device=device, dtype=torch.float32) - # Compute frequencies: [seq_len, dim/2] - freqs = torch.outer(t, self.inv_freq.to(device=device, dtype=torch.float32)) - # Compute and cache cos/sin (shape: [seq_len, dim/2]) - # ApplyRotaryEmb expects cos/sin with shape [seq_len, rotary_dim/2] - cos = freqs.cos() * self.attention_scaling - sin = freqs.sin() * self.attention_scaling - - self.register_buffer("cos_cached", cos, persistent=False) - self.register_buffer("sin_cached", sin, persistent=False) + # Only cache inv_freq; cos/sin computed on-demand in correct device + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) - def get_cos_sin(self, seq_len: int) -> tuple[torch.Tensor, torch.Tensor]: + def forward(self, seq_len: int) -> torch.Tensor: """ - Get cos and sin tensors for a given sequence length. + Compute rotary position frequencies for given sequence length. Args: - seq_len: The sequence length to get embeddings for. + seq_len: The sequence length to compute embeddings for. Returns: - Tuple of (cos, sin) tensors with shape [seq_len, dim/2] + Frequency tensor with shape [seq_len, dim/2]. Use .cos() and + .sin() to get the rotary embedding components. """ - # Extend cache if needed - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len, device=self.cos_cached.device) - - return self.cos_cached[:seq_len], self.sin_cached[:seq_len] + # Compute on the same device as inv_freq (automatically correct after .to()) + seq = torch.arange( + seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype + ) + freqs = torch.outer(seq, self.inv_freq) + return freqs * self.attention_scaling class GlmAsrAttention(nn.Module): @@ -398,7 +362,17 @@ def forward( class _GlmAsrEncoderOutput: - """Simple output container compatible with transformers' BaseModelOutput.""" + """ + Simple output container compatible with transformers' BaseModelOutput. + + This lightweight container holds the encoder output and is compatible + with the transformers library's output format while being more efficient + than a full dataclass. + + Attributes: + last_hidden_state: Final layer hidden states from the encoder. + Shape: [batch_size, seq_len, hidden_size] + """ __slots__ = ("last_hidden_state",) @@ -507,11 +481,10 @@ def forward(self, input_features: torch.Tensor): hidden_states = hidden_states.transpose(1, 2) output_seq_len = hidden_states.shape[1] - # Get cos/sin from rotary embedding cache - cos, sin = self.rotary_emb.get_cos_sin(output_seq_len) - # Match dtype with hidden states - cos = cos.to(dtype=hidden_states.dtype) - sin = sin.to(dtype=hidden_states.dtype) + # Compute rotary position embeddings on-demand + freqs = self.rotary_emb(output_seq_len) + cos = freqs.cos().to(dtype=hidden_states.dtype) + sin = freqs.sin().to(dtype=hidden_states.dtype) # Apply transformer layers for encoder_layer in self.layers: @@ -605,6 +578,19 @@ class GlmAsrEmbeddingInputs(TensorSchema): class GlmAsrMultiModalProjector(nn.Module): + """ + Projects audio encoder outputs to language model hidden space. + + This projector uses a two-layer MLP to map audio features from the + encoder's intermediate size to the language model's hidden size. + Uses vLLM's parallel linear layers for tensor parallelism support. + + Architecture: + - Linear layer: intermediate_size -> hidden_size * 2 + - Activation function (e.g., GELU) + - Linear layer: hidden_size * 2 -> hidden_size + """ + def __init__( self, config: GlmAsrConfig, @@ -634,6 +620,13 @@ def forward(self, audio_features: torch.Tensor) -> torch.Tensor: class GlmAsrProcessingInfo(BaseProcessingInfo): + """ + Processing information provider for GLM-ASR model. + + Provides access to model configuration, processor, and feature extractor + needed for audio preprocessing and multimodal integration. + """ + def get_hf_config(self) -> GlmAsrConfig: return self.ctx.get_hf_config(GlmAsrConfig) @@ -650,6 +643,14 @@ def get_supported_mm_limits(self) -> Mapping[str, int | None]: class GlmAsrDummyInputsBuilder(BaseDummyInputsBuilder[GlmAsrProcessingInfo]): + """ + Builder for dummy inputs used in profiling and testing. + + Generates dummy text prompts and audio data that match the expected + format for GLM-ASR model inputs. Used for memory profiling and + performance benchmarking. + """ + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_audios = mm_counts.get("audio", 0) hf_processor = self.info.get_hf_processor() @@ -679,6 +680,20 @@ def get_dummy_mm_data( def _glmasr_field_config(hf_inputs: Mapping[str, torch.Tensor]): + """ + Configure multimodal field batching strategy for GLM-ASR. + + Determines how to batch audio inputs based on whether chunking is used. + When chunk_counts is present, features are flattened across chunks; + otherwise, they are batched normally. + + Args: + hf_inputs: Dictionary of preprocessed inputs from HuggingFace processor. + + Returns: + Dictionary mapping field names to MultiModalFieldConfig objects + that specify batching behavior. + """ chunk_counts = hf_inputs.get("chunk_counts") if chunk_counts is not None: return dict( @@ -700,6 +715,13 @@ def _glmasr_field_config(hf_inputs: Mapping[str, torch.Tensor]): class GlmAsrMultiModalDataParser(MultiModalDataParser): + """ + Custom parser for GLM-ASR multimodal data. + + Extends the base parser to handle GLM-ASR specific audio data formats, + including both pre-computed audio embeddings and raw audio features. + """ + def _parse_audio_data( self, data: dict[str, torch.Tensor] | ModalityData[Any], @@ -730,7 +752,6 @@ def _calculate_chunk_counts( feature_extractor: WhisperFeatureExtractor, processor: GlmAsrProcessor, ) -> list[int]: - """Calculate chunk counts for each audio.""" sampling_rate = feature_extractor.sampling_rate chunk_length = feature_extractor.chunk_length max_audio_len = getattr(processor, "max_audio_len", DEFAULT_MAX_AUDIO_LEN_S) @@ -1054,6 +1075,7 @@ def get_language_model(self) -> torch.nn.Module: return self.language_model def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: + "" audio_input = self._parse_and_validate_audio_input(**kwargs) if audio_input is None: return [] From 1e6355b6ed04d0dc38d1169c163da3064d51da28 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 6 Jan 2026 20:10:05 +0800 Subject: [PATCH 18/24] remove unnecessary comment && rename var name for cos to rotary_pos_emb_cos Signed-off-by: JaredforReal --- vllm/model_executor/models/glmasr.py | 35 ++++++++++++++-------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/vllm/model_executor/models/glmasr.py b/vllm/model_executor/models/glmasr.py index 8d266987e93f..2b4cc33c3101 100644 --- a/vllm/model_executor/models/glmasr.py +++ b/vllm/model_executor/models/glmasr.py @@ -71,7 +71,6 @@ from .whisper import ISO639_1_SUPPORTED_LANGS -# Optimized vLLM Native GlmAsrEncoder Implementation class GlmAsrRotaryEmbedding(nn.Module): """ Rotary Position Embedding for GLM-ASR encoder. @@ -185,14 +184,14 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, + rotary_pos_emb_cos: torch.Tensor, + rotary_pos_emb_sin: torch.Tensor, ) -> torch.Tensor: """ Args: hidden_states: [batch_size, seq_len, hidden_size] - cos: [seq_len, rotary_dim/2] - cosine part of rotary embeddings - sin: [seq_len, rotary_dim/2] - sine part of rotary embeddings + rotary_pos_emb_cos: [seq_len, rotary_dim/2] - cosine of rotary embeddings + rotary_pos_emb_sin: [seq_len, rotary_dim/2] - sine of rotary embeddings Returns: [batch_size, seq_len, hidden_size] @@ -215,8 +214,8 @@ def forward( # Apply rotary position embeddings using vLLM's ApplyRotaryEmb # ApplyRotaryEmb expects x: [batch, seq, heads, head_dim] # cos/sin: [seq_len, rotary_dim/2] - q = self.apply_rotary_emb(q, cos, sin) - k = self.apply_rotary_emb(k, cos, sin) + q = self.apply_rotary_emb(q, rotary_pos_emb_cos, rotary_pos_emb_sin) + k = self.apply_rotary_emb(k, rotary_pos_emb_cos, rotary_pos_emb_sin) # Transpose to [batch, num_heads, seq, head_dim] for attention q = q.transpose(1, 2) @@ -330,14 +329,14 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, + rotary_pos_emb_cos: torch.Tensor, + rotary_pos_emb_sin: torch.Tensor, ) -> torch.Tensor: """ Args: hidden_states: [batch_size, seq_len, hidden_size] - cos: [seq_len, rotary_dim/2] - cosine part of rotary embeddings - sin: [seq_len, rotary_dim/2] - sine part of rotary embeddings + rotary_pos_emb_cos: [seq_len, rotary_dim/2] - cosine of rotary embeddings + rotary_pos_emb_sin: [seq_len, rotary_dim/2] - sine of rotary embeddings Returns: [batch_size, seq_len, hidden_size] @@ -347,8 +346,8 @@ def forward( hidden_states = self.input_layernorm(hidden_states) hidden_states = self.self_attn( hidden_states=hidden_states, - cos=cos, - sin=sin, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, ) hidden_states = residual + hidden_states @@ -482,13 +481,15 @@ def forward(self, input_features: torch.Tensor): output_seq_len = hidden_states.shape[1] # Compute rotary position embeddings on-demand - freqs = self.rotary_emb(output_seq_len) - cos = freqs.cos().to(dtype=hidden_states.dtype) - sin = freqs.sin().to(dtype=hidden_states.dtype) + rotary_pos_emb = self.rotary_emb(output_seq_len) + rotary_pos_emb_cos = rotary_pos_emb.cos().to(dtype=hidden_states.dtype) + rotary_pos_emb_sin = rotary_pos_emb.sin().to(dtype=hidden_states.dtype) # Apply transformer layers for encoder_layer in self.layers: - hidden_states = encoder_layer(hidden_states, cos, sin) + hidden_states = encoder_layer( + hidden_states, rotary_pos_emb_cos, rotary_pos_emb_sin + ) # Final layer norm hidden_states = self.norm(hidden_states) From 55407c9c4784be74b439d0714b2c6229d3eb7daa Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Tue, 6 Jan 2026 20:30:32 +0800 Subject: [PATCH 19/24] fix CI error Signed-off-by: JaredforReal --- vllm/model_executor/models/glmasr.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/glmasr.py b/vllm/model_executor/models/glmasr.py index 2b4cc33c3101..b5be42892ec3 100644 --- a/vllm/model_executor/models/glmasr.py +++ b/vllm/model_executor/models/glmasr.py @@ -460,7 +460,7 @@ def _get_feat_extract_output_lengths( return input_lengths, output_lengths - def forward(self, input_features: torch.Tensor): + def forward(self, input_features: torch.Tensor) -> _GlmAsrEncoderOutput: """ Forward pass through the encoder. @@ -680,7 +680,9 @@ def get_dummy_mm_data( } -def _glmasr_field_config(hf_inputs: Mapping[str, torch.Tensor]): +def _glmasr_field_config( + hf_inputs: Mapping[str, torch.Tensor], +) -> dict[str, MultiModalFieldConfig]: """ Configure multimodal field batching strategy for GLM-ASR. @@ -692,8 +694,8 @@ def _glmasr_field_config(hf_inputs: Mapping[str, torch.Tensor]): hf_inputs: Dictionary of preprocessed inputs from HuggingFace processor. Returns: - Dictionary mapping field names to MultiModalFieldConfig objects - that specify batching behavior. + Dictionary mapping field names to MultiModalFieldConfig objects \ + that specify batching behavior. """ chunk_counts = hf_inputs.get("chunk_counts") if chunk_counts is not None: From 312db98d23980281e2746950a927379cd628d3ba Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Wed, 7 Jan 2026 10:16:05 +0800 Subject: [PATCH 20/24] use vllm.attention.layers.mm_encoder_attention for GlmAsrEncoder's Attention Signed-off-by: JaredforReal --- vllm/model_executor/models/glmasr.py | 53 +++++++++------------- vllm/model_executor/models/glmasr_utils.py | 21 --------- 2 files changed, 21 insertions(+), 53 deletions(-) diff --git a/vllm/model_executor/models/glmasr.py b/vllm/model_executor/models/glmasr.py index b5be42892ec3..701a6164593b 100644 --- a/vllm/model_executor/models/glmasr.py +++ b/vllm/model_executor/models/glmasr.py @@ -11,6 +11,7 @@ from transformers.models.glmasr import GlmAsrConfig, GlmAsrProcessor from transformers.models.whisper import WhisperFeatureExtractor +from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size @@ -58,7 +59,6 @@ _get_audio_output_lengths_for_tower, _group_audio_embeddings, _normalize_chunk_counts, - _repeat_kv, ) from .interfaces import ( MultiModalEmbeddings, @@ -133,7 +133,10 @@ def forward(self, seq_len: int) -> torch.Tensor: class GlmAsrAttention(nn.Module): """ Optimized Multi-headed Grouped Query Attention for GLM-ASR. - Uses vLLM's QKVParallelLinear and ApplyRotaryEmb for better performance. + + Uses vLLM's QKVParallelLinear for fused projections, ApplyRotaryEmb for + rotary position embeddings, and MMEncoderAttention for hardware-optimized + attention computation with automatic backend selection. """ def __init__( @@ -146,11 +149,10 @@ def __init__( self.config = config self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads - self.num_kv_heads = config.num_key_value_heads + self.num_kv_heads = getattr( + config, "num_key_value_heads", config.num_attention_heads + ) self.head_dim = self.hidden_size // self.num_heads - self.num_kv_groups = self.num_heads // self.num_kv_heads - self.scaling = self.head_dim**-0.5 - self.attention_dropout = config.attention_dropout self.tp_size = get_tensor_model_parallel_world_size() self.num_heads_per_rank = self.num_heads // self.tp_size @@ -181,6 +183,15 @@ def __init__( # enforce_enable=True ensures the op is always enabled (important for ViT) self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True) + # Use vLLM's MMEncoderAttention for hardware-optimized attention + # Automatically selects Flash Attention, SDPA, or Pallas based on device + self.attn = MMEncoderAttention( + num_heads=self.num_heads_per_rank, + head_size=self.head_dim, + num_kv_heads=self.num_kv_heads_per_rank, + prefix=f"{prefix}.attn", + ) + def forward( self, hidden_states: torch.Tensor, @@ -217,33 +228,11 @@ def forward( q = self.apply_rotary_emb(q, rotary_pos_emb_cos, rotary_pos_emb_sin) k = self.apply_rotary_emb(k, rotary_pos_emb_cos, rotary_pos_emb_sin) - # Transpose to [batch, num_heads, seq, head_dim] for attention - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - # Handle GQA: repeat k/v if needed - if self.num_kv_groups > 1: - k = _repeat_kv(k, self.num_kv_groups) - v = _repeat_kv(v, self.num_kv_groups) - - # Ensure contiguous for optimal SDPA/Flash Attention performance - # Non-contiguous tensors can cause fallback to slower implementations - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() - - # Scaled dot-product attention (uses Flash Attention when available) - attn_output = torch.nn.functional.scaled_dot_product_attention( - q, - k, - v, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=False, - ) + # MMEncoderAttention expects [batch, seq, num_heads, head_dim] + # It handles GQA internally via repeat_interleave + attn_output = self.attn(q, k, v) - # Reshape back - attn_output = attn_output.transpose(1, 2).contiguous() + # Reshape back to [batch, seq, hidden_size] attn_output = attn_output.view(batch_size, seq_len, -1) # Output projection diff --git a/vllm/model_executor/models/glmasr_utils.py b/vllm/model_executor/models/glmasr_utils.py index 098540878afe..492e4b354b5e 100644 --- a/vllm/model_executor/models/glmasr_utils.py +++ b/vllm/model_executor/models/glmasr_utils.py @@ -14,27 +14,6 @@ DEFAULT_CONV_PARAMS = [(1, 3, 1), (1, 3, 2)] -def _repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - Repeat key/value tensors for Grouped Query Attention. - - Args: - hidden_states: [batch, num_kv_heads, seq_len, head_dim] - n_rep: Number of repetitions - - Returns: - [batch, num_kv_heads * n_rep, seq_len, head_dim] - """ - if n_rep == 1: - return hidden_states - - batch, num_kv_heads, slen, head_dim = hidden_states.shape - hidden_states = hidden_states[:, :, None, :, :].expand( - batch, num_kv_heads, n_rep, slen, head_dim - ) - return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim) - - def _calculate_conv_output_length( input_length: torch.Tensor, padding: int, kernel_size: int, stride: int ) -> torch.Tensor: From 53deab353525c5428c642755b7778cfa9aeb524b Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Wed, 7 Jan 2026 10:17:53 +0800 Subject: [PATCH 21/24] rename rot_embedding, attention, MLP Signed-off-by: JaredforReal --- vllm/model_executor/models/glmasr.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/glmasr.py b/vllm/model_executor/models/glmasr.py index 701a6164593b..5b0eb1a71f37 100644 --- a/vllm/model_executor/models/glmasr.py +++ b/vllm/model_executor/models/glmasr.py @@ -71,7 +71,7 @@ from .whisper import ISO639_1_SUPPORTED_LANGS -class GlmAsrRotaryEmbedding(nn.Module): +class GlmAsrEncoderRotaryEmbedding(nn.Module): """ Rotary Position Embedding for GLM-ASR encoder. @@ -130,9 +130,9 @@ def forward(self, seq_len: int) -> torch.Tensor: return freqs * self.attention_scaling -class GlmAsrAttention(nn.Module): +class GlmAsrEncoderAttention(nn.Module): """ - Optimized Multi-headed Grouped Query Attention for GLM-ASR. + Optimized Multi-headed Grouped Query Attention for GLM-ASR encoder. Uses vLLM's QKVParallelLinear for fused projections, ApplyRotaryEmb for rotary position embeddings, and MMEncoderAttention for hardware-optimized @@ -240,7 +240,7 @@ def forward( return output -class GlmAsrMLP(nn.Module): +class GlmAsrEncoderMLP(nn.Module): """ Optimized MLP for GLM-ASR encoder. Uses vLLM's parallel linear layers for better performance. @@ -297,13 +297,13 @@ def __init__( super().__init__() self.hidden_size = config.hidden_size - self.self_attn = GlmAsrAttention( + self.self_attn = GlmAsrEncoderAttention( config, quant_config=quant_config, prefix=f"{prefix}.self_attn", ) - self.mlp = GlmAsrMLP( + self.mlp = GlmAsrEncoderMLP( config, quant_config=quant_config, prefix=f"{prefix}.mlp", @@ -427,7 +427,7 @@ def __init__( self.norm = nn.LayerNorm(config.hidden_size, eps=layer_norm_eps) # Rotary position embeddings - self.rotary_emb = GlmAsrRotaryEmbedding(config) + self.rotary_emb = GlmAsrEncoderRotaryEmbedding(config) def _get_feat_extract_output_lengths( self, input_lengths: torch.Tensor From 0f81e2506823aec4d75d51d6a590d455863ba65c Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Wed, 7 Jan 2026 11:08:06 +0800 Subject: [PATCH 22/24] clean code Signed-off-by: JaredforReal --- vllm/model_executor/models/glmasr.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/vllm/model_executor/models/glmasr.py b/vllm/model_executor/models/glmasr.py index 5b0eb1a71f37..ccf98181d231 100644 --- a/vllm/model_executor/models/glmasr.py +++ b/vllm/model_executor/models/glmasr.py @@ -1059,15 +1059,12 @@ def _process_audio_input( chunk_embeddings = torch.split( masked_audio_features, audio_output_lengths.flatten().tolist() ) - result = _group_audio_embeddings(chunk_embeddings, chunk_counts) - - return result + return _group_audio_embeddings(chunk_embeddings, chunk_counts) def get_language_model(self) -> torch.nn.Module: return self.language_model def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: - "" audio_input = self._parse_and_validate_audio_input(**kwargs) if audio_input is None: return [] From 492080c5d512451b9377a1093e461c54f2a0faf1 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Wed, 7 Jan 2026 11:39:07 +0800 Subject: [PATCH 23/24] accept reivews Signed-off-by: JaredforReal --- vllm/model_executor/models/glmasr.py | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/models/glmasr.py b/vllm/model_executor/models/glmasr.py index ccf98181d231..2a1a872ef000 100644 --- a/vllm/model_executor/models/glmasr.py +++ b/vllm/model_executor/models/glmasr.py @@ -624,9 +624,7 @@ def get_hf_processor(self, **kwargs: object) -> GlmAsrProcessor: return self.ctx.get_hf_processor(GlmAsrProcessor, **kwargs) def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor: - hf_processor = self.get_hf_processor(**kwargs) - feature_extractor = hf_processor.feature_extractor - return feature_extractor + return self.get_hf_processor(**kwargs).feature_extractor def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"audio": None} @@ -858,16 +856,13 @@ def _get_prompt_updates( ) if chunk_counts is not None: - counts_list = _as_list_chunk_counts(chunk_counts) start_idx = 0 - for count in counts_list: + for count in _as_list_chunk_counts(chunk_counts): end_idx = start_idx + count - if isinstance(feature_attention_mask, torch.Tensor): - mask = feature_attention_mask[start_idx:end_idx] - else: - mask = feature_attention_mask[start_idx:end_idx] - if isinstance(mask, list): - mask = torch.stack(mask) + mask = feature_attention_mask[start_idx:end_idx] + if isinstance(mask, list): + mask = torch.stack(mask) + lengths = _get_audio_output_lengths_from_mask( mask, merge_factor, conv_params ) @@ -876,13 +871,9 @@ def _get_prompt_updates( else: # Single chunk per audio for idx in range(len(feature_attention_mask)): - if isinstance(feature_attention_mask, torch.Tensor): - mask = feature_attention_mask[idx : idx + 1] - else: - mask = feature_attention_mask[idx] - if not isinstance(mask, torch.Tensor): - mask = torch.tensor(mask) - mask = mask.unsqueeze(0) + mask = feature_attention_mask[idx : idx + 1] + if isinstance(mask, list): + mask = torch.tensor(mask).unsqueeze(0) lengths = _get_audio_output_lengths_from_mask( mask, merge_factor, conv_params ) From 1944742120e1692819f1ae544fb92abfa3108f58 Mon Sep 17 00:00:00 2001 From: JaredforReal Date: Wed, 7 Jan 2026 11:55:51 +0800 Subject: [PATCH 24/24] accept reivew Signed-off-by: JaredforReal --- vllm/model_executor/models/glmasr.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/glmasr.py b/vllm/model_executor/models/glmasr.py index 2a1a872ef000..42b8d54aaa26 100644 --- a/vllm/model_executor/models/glmasr.py +++ b/vllm/model_executor/models/glmasr.py @@ -442,12 +442,12 @@ def _get_feat_extract_output_lengths( Tuple of (output after conv1, output after conv2) """ # Conv1: kernel=3, stride=1, padding=1 - output_lengths = (input_lengths + 2 * 1 - 3) // 1 + 1 + output_lengths_conv1 = (input_lengths + 2 * 1 - 3) // 1 + 1 # Conv2: kernel=3, stride=2, padding=1 - output_lengths = (output_lengths + 2 * 1 - 3) // 2 + 1 + output_lengths_conv2 = (output_lengths_conv1 + 2 * 1 - 3) // 2 + 1 - return input_lengths, output_lengths + return output_lengths_conv1, output_lengths_conv2 def forward(self, input_features: torch.Tensor) -> _GlmAsrEncoderOutput: """