Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions vllm/model_executor/models/qwen3_omni_moe_thinker.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from packaging.version import Version
from transformers import PretrainedConfig
from transformers import __version__ as TRANSFORMERS_VERSION
from transformers.feature_extraction_utils import BatchFeature
from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import (
Qwen3OmniMoeConfig,
Expand Down Expand Up @@ -711,11 +713,12 @@ def pad_to_hop_length(x: np.ndarray, hop_length: int) -> np.ndarray:
return x

# NOTE: WhisperFeatureExtractor cannot handle an empty list of audios
feature_extractor = self.info.get_feature_extractor()
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a comment so we can revert this once it's fixed on the Transformers side?

hop_length = feature_extractor.hop_length
if audios:
# NOTE: Qwen3-Omni processor accepts "audio"
# To make sure the cache works with padding=True, we pre-pad
# the audio to a multiple of hop_length.
hop_length = self.info.get_feature_extractor().hop_length
mm_data["audio"] = [
pad_to_hop_length(audio, hop_length)
if isinstance(audio, np.ndarray)
Expand All @@ -725,6 +728,14 @@ def pad_to_hop_length(x: np.ndarray, hop_length: int) -> np.ndarray:
mm_kwargs = dict(
**mm_kwargs,
)
# TODO(Isotr0py): Remove this patch once the upstream fix PR is
# released and the Transformers version is updated:
# https://github.com/huggingface/transformers/pull/41473
if (
Version(TRANSFORMERS_VERSION) < Version("4.58.0")
and "truncation" not in mm_kwargs
):
mm_kwargs["truncation"] = False

hf_inputs = super()._call_hf_processor(
prompt=prompt,
Expand All @@ -738,7 +749,6 @@ def pad_to_hop_length(x: np.ndarray, hop_length: int) -> np.ndarray:
and "feature_attention_mask" in hf_inputs
and (audios := mm_data.get("audio", []))
):
hop_length = self.info.get_feature_extractor().hop_length
audio_num_frames = []
for _, audio in enumerate(audios):
audio_length = len(audio[0]) if isinstance(audio, tuple) else len(audio)
Expand All @@ -747,6 +757,10 @@ def pad_to_hop_length(x: np.ndarray, hop_length: int) -> np.ndarray:
if audio_length % hop_length == 0
else (audio_length // hop_length - 1)
)
if mm_kwargs.get("truncation", False):
num_frame = min(
num_frame, feature_extractor.n_samples // hop_length
)
audio_num_frames.append(num_frame)
hf_inputs["feature_attention_mask"] = [
torch.ones(num_frame) for num_frame in audio_num_frames
Expand Down