From 70532f65d300495fc3271cf85e0be9bf7f004c2e Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Tue, 14 Oct 2025 23:53:31 +0800 Subject: [PATCH 1/4] fix qwen3-omni audio truncation issue Signed-off-by: Isotr0py --- vllm/model_executor/models/qwen3_omni_moe_thinker.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index d565a0108432..bcb0ccdcbb47 100755 --- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -711,11 +711,12 @@ def pad_to_hop_length(x: np.ndarray, hop_length: int) -> np.ndarray: return x # NOTE: WhisperFeatureExtractor cannot handle empty list of audios + feature_extractor = self.info.get_feature_extractor() + hop_length = feature_extractor.hop_length if audios: # NOTE: Qwen3-Omni processor accept "audio" # To make sure the cache works with padding=True, we pre-padded # the audio to multiple of hop_length. - hop_length = self.info.get_feature_extractor().hop_length mm_data["audio"] = [ pad_to_hop_length(audio, hop_length) if isinstance(audio, np.ndarray) @@ -738,7 +739,6 @@ def pad_to_hop_length(x: np.ndarray, hop_length: int) -> np.ndarray: and "feature_attention_mask" in hf_inputs and (audios := mm_data.get("audio", [])) ): - hop_length = self.info.get_feature_extractor().hop_length audio_num_frames = [] for _, audio in enumerate(audios): audio_length = len(audio[0]) if isinstance(audio, tuple) else len(audio) @@ -747,6 +747,10 @@ def pad_to_hop_length(x: np.ndarray, hop_length: int) -> np.ndarray: if audio_length % hop_length == 0 else (audio_length // hop_length - 1) ) + if mm_kwargs.get("truncation", True): + num_frame = min( + num_frame, feature_extractor.n_samples // hop_length + ) audio_num_frames.append(num_frame) hf_inputs["feature_attention_mask"] = [ torch.ones(num_frame) for num_frame in audio_num_frames From 9281723466a0b8ac29838ea781b1691d17acc455 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Wed, 15 Oct 2025 00:02:36 +0800 Subject: [PATCH 2/4] enforce truncation=False Signed-off-by: Isotr0py --- vllm/model_executor/models/qwen3_omni_moe_thinker.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index bcb0ccdcbb47..4f9ab7b432fb 100755 --- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -726,6 +726,7 @@ def pad_to_hop_length(x: np.ndarray, hop_length: int) -> np.ndarray: mm_kwargs = dict( **mm_kwargs, ) + mm_kwargs["truncation"] = True hf_inputs = super()._call_hf_processor( prompt=prompt, @@ -747,10 +748,6 @@ def pad_to_hop_length(x: np.ndarray, hop_length: int) -> np.ndarray: if audio_length % hop_length == 0 else (audio_length // hop_length - 1) ) - if mm_kwargs.get("truncation", True): - num_frame = min( - num_frame, feature_extractor.n_samples // hop_length - ) audio_num_frames.append(num_frame) hf_inputs["feature_attention_mask"] = [ torch.ones(num_frame) for num_frame in audio_num_frames From bd61ae5ec708f889c9dd7938da680e9d39c4ea07 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Wed, 15 Oct 2025 00:04:02 +0800 Subject: [PATCH 3/4] ooops Signed-off-by: Isotr0py --- vllm/model_executor/models/qwen3_omni_moe_thinker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index 4f9ab7b432fb..6e1c80e494ef 100755 --- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -726,7 +726,7 @@ def pad_to_hop_length(x: np.ndarray, hop_length: int) -> np.ndarray: mm_kwargs = dict( **mm_kwargs, ) - mm_kwargs["truncation"] = True + mm_kwargs["truncation"] = False hf_inputs = super()._call_hf_processor( prompt=prompt, From 7fd991acdb657daf24aadd0bc06e2b405689f2b2 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Wed, 15 Oct 2025 01:25:04 +0800 Subject: [PATCH 4/4] add comment and correct truncation case Signed-off-by: Isotr0py --- .../models/qwen3_omni_moe_thinker.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index 6e1c80e494ef..d5a75e75aa43 100755 --- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -30,7 +30,9 @@ import torch import torch.nn as nn import torch.nn.functional as F +from packaging.version import Version from transformers import PretrainedConfig +from transformers import __version__ as TRANSFORMERS_VERSION from transformers.feature_extraction_utils import BatchFeature from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import ( Qwen3OmniMoeConfig, @@ -726,7 +728,14 @@ def pad_to_hop_length(x: np.ndarray, hop_length: int) -> np.ndarray: mm_kwargs = dict( **mm_kwargs, ) - mm_kwargs["truncation"] = False + # TODO(Isotr0py): Remove this patch after upstream fix PR + # released and Transformers version update: + # https://github.com/huggingface/transformers/pull/41473 + if ( + Version(TRANSFORMERS_VERSION) < Version("4.58.0") + and "truncation" not in mm_kwargs + ): + mm_kwargs["truncation"] = False hf_inputs = super()._call_hf_processor( prompt=prompt, @@ -748,6 +757,10 @@ def pad_to_hop_length(x: np.ndarray, hop_length: int) -> np.ndarray: if audio_length % hop_length == 0 else (audio_length // hop_length - 1) ) + if mm_kwargs.get("truncation", False): + num_frame = min( + num_frame, feature_extractor.n_samples // hop_length + ) audio_num_frames.append(num_frame) hf_inputs["feature_attention_mask"] = [ torch.ones(num_frame) for num_frame in audio_num_frames