diff --git a/examples/offline_inference/qwen2_5_omni/README.md b/examples/offline_inference/qwen2_5_omni/README.md
index 16d44cbadbc9..d8fb50d7fe55 100644
--- a/examples/offline_inference/qwen2_5_omni/README.md
+++ b/examples/offline_inference/qwen2_5_omni/README.md
@@ -11,12 +11,9 @@ python examples/offline_inference/qwen2_5_omni/only_thinker.py \
 
 # Read vision and audio inputs from a single video file
-# NOTE: V1 engine does not support interleaved modalities yet.
-VLLM_USE_V1=0 \
 python examples/offline_inference/qwen2_5_omni/only_thinker.py \
     -q use_audio_in_video
 
 # Multiple audios
-VLLM_USE_V1=0 \
 python examples/offline_inference/qwen2_5_omni/only_thinker.py \
     -q multi_audios
 ```
diff --git a/examples/offline_inference/qwen2_5_omni/only_thinker.py b/examples/offline_inference/qwen2_5_omni/only_thinker.py
index 6fbe1303f431..991e47023339 100644
--- a/examples/offline_inference/qwen2_5_omni/only_thinker.py
+++ b/examples/offline_inference/qwen2_5_omni/only_thinker.py
@@ -7,7 +7,6 @@
 
 from typing import NamedTuple
 
-import vllm.envs as envs
 from vllm import LLM, SamplingParams
 from vllm.assets.audio import AudioAsset
 from vllm.assets.image import ImageAsset
@@ -72,11 +71,7 @@ def get_use_audio_in_video_query() -> QueryResult:
     )
     asset = VideoAsset(name="baby_reading", num_frames=16)
     audio = asset.get_audio(sampling_rate=16000)
-    assert not envs.VLLM_USE_V1, (
-        "V1 does not support use_audio_in_video. "
-        "Please launch this example with "
-        "`VLLM_USE_V1=0`."
-    )
+
     return QueryResult(
         inputs={
             "prompt": prompt,
@@ -125,7 +120,7 @@ def get_multi_audios_query() -> QueryResult:
 
 
 def main(args):
-    model_name = "Qwen/Qwen2.5-Omni-7B"
+    model_name = "Qwen/Qwen2.5-Omni-3B"
     query_result = query_map[args.query_type]()
 
     llm = LLM(
@@ -138,7 +133,7 @@ def main(args):
 
     # We set temperature to 0.2 so that outputs can be different
     # even when all prompts are identical when running batch inference.
-    sampling_params = SamplingParams(temperature=0.2, max_tokens=64)
+    sampling_params = SamplingParams(temperature=0.2, max_tokens=128)
 
     outputs = llm.generate(query_result.inputs, sampling_params=sampling_params)
 
diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py
index a5d6004faf38..3f327634f4f4 100644
--- a/vllm/model_executor/models/qwen2_5_omni_thinker.py
+++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py
@@ -23,7 +23,6 @@
 """Inference-only Qwen2.5-Omni model (thinker part)."""
 
 from collections.abc import Callable, Iterable, Mapping, Sequence
-from copy import copy
 from functools import partial
 from typing import Annotated, Any, Literal
 
@@ -84,6 +83,7 @@
     PlaceholderFeaturesInfo,
     PromptReplacement,
     PromptUpdate,
+    PromptUpdateDetails,
 )
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
@@ -400,17 +400,31 @@ def _maybe_apply_prompt_updates(
             self._validate_mm_placeholders(
                 mm_placeholders,
                 mm_item_counts,
-                use_audio_in_video=use_audio_in_video,
             )
         else:
-            prompt_ids, mm_placeholders = self._apply_prompt_updates(
-                prompt_ids,
-                mm_prompt_updates,
-            )
+            # When use_audio_in_video=True, audio tokens are not in the prompt,
+            # so audio updates must be filtered out before applying replacements.
+            if use_audio_in_video and "audio" in mm_prompt_updates:
+                # Remove audio from the prompt updates (it would match nothing).
+                filtered_updates = {
+                    k: v for k, v in mm_prompt_updates.items() if k != "audio"
+                }
+                prompt_ids, mm_placeholders = self._apply_prompt_updates(
+                    prompt_ids,
+                    filtered_updates,
+                )
+                # Derive audio placeholders from the video placeholders.
+                mm_placeholders = self._derive_audio_from_video_placeholders(
+                    mm_placeholders, mm_prompt_updates
+                )
+            else:
+                prompt_ids, mm_placeholders = self._apply_prompt_updates(
+                    prompt_ids,
+                    mm_prompt_updates,
+                )
             self._validate_mm_placeholders(
                 mm_placeholders,
                 mm_item_counts,
-                use_audio_in_video=use_audio_in_video,
             )
 
         return prompt_ids, mm_placeholders
@@ -532,13 +546,17 @@ def get_replacement_qwen2_vision(item_idx: int, modality: str):
         thinker_config = self.info.get_hf_config()
 
         def get_replacement_qwen2_use_audio_in_video(item_idx: int):
+            """
+            Generate the interleaved placeholder for a video when
+            use_audio_in_video=True. Audio placeholders are derived from it
+            in _derive_audio_from_video_placeholders().
+            """
             nonlocal audio_in_video_item_idx
             audio_num_features = audio_output_lengths[
                 audio_in_video_item_idx + item_idx
             ]
             video_grid_thw = out_mm_data["video_grid_thw"][item_idx]
-            audio_in_video_item_idx += 1
 
             second_per_grid_ts = hf_processor_mm_kwargs.get(
                 "second_per_grid_ts", None
             )
@@ -547,13 +565,22 @@ def get_replacement_qwen2_use_audio_in_video(item_idx: int):
             else:
                 video_second_per_grid_t = 1.0
 
-            return self.omni_get_updates_use_audio_in_video(
+            # Generate the interleaved audio and video token sequence.
+            placeholder = self.omni_get_updates_use_audio_in_video(
                 thinker_config=thinker_config,
                 audio_len=audio_num_features,
                 video_grid_thw=video_grid_thw,
                 video_second_per_grid_t=video_second_per_grid_t,
             )
+            # Mark the video token positions via the is_embed mask.
+            return PromptUpdateDetails.select_token_id(
+                placeholder, embed_token_id=video_token_id
+            )
+
+        # Set up the replacement functions.
+        # Note: in use_audio_in_video mode, audio updates are filtered out
+        # before replacement, so audio_replacement_fn is never called.
 
         video_replacement_fn = (
             get_replacement_qwen2_use_audio_in_video
             if use_audio_in_video
@@ -578,6 +605,51 @@
             ),
         ]
 
+    def _derive_audio_from_video_placeholders(
+        self,
+        placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
+        mm_prompt_updates: MultiModalPromptUpdates,
+    ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
+        """
+        Derive audio placeholders from video placeholders when
+        use_audio_in_video=True.
+        """
+        if "video" not in placeholders:
+            return placeholders
+
+        # Validate that the audio and video counts match.
+        num_videos = len(placeholders["video"])
+        num_audios = len(mm_prompt_updates.get("audio", []))
+        if num_audios != num_videos:
+            raise ValueError(
+                f"use_audio_in_video requires an equal number of audio and "
+                f"video items, got {num_audios=}, {num_videos=}"
+            )
+
+        tokenizer = self.info.get_tokenizer()
+        processor = self.info.get_hf_processor()
+        audio_token_id = tokenizer.get_vocab()[processor.audio_token]
+
+        result_placeholders = dict(placeholders)
+        audio_placeholders = []
+
+        # Each video is paired with exactly one audio item.
+        for video_idx, video_placeholder in enumerate(placeholders["video"]):
+            # Build an is_embed mask that selects only the audio tokens.
+            audio_is_embed = torch.tensor(video_placeholder.tokens) == audio_token_id
+
+            audio_placeholder = PlaceholderFeaturesInfo(
+                modality="audio",
+                item_idx=video_idx,
+                start_idx=video_placeholder.start_idx,
+                tokens=video_placeholder.tokens,
+                is_embed=audio_is_embed,
+            )
+            audio_placeholders.append(audio_placeholder)
+
+        result_placeholders["audio"] = audio_placeholders
+        return result_placeholders
+
     def _apply_hf_processor_main(
         self,
         prompt: str | list[int],
@@ -636,19 +708,6 @@ def _apply_hf_processor_mm_only(
 
         return mm_processed_data
 
-    def _validate_mm_placeholders(
-        self,
-        mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
-        mm_item_counts: Mapping[str, int],
-        use_audio_in_video: bool = False,
-    ) -> None:
-        if use_audio_in_video:
-            mm_item_counts = copy(mm_item_counts)
-            if "video" in mm_item_counts:
-                assert "audio" in mm_item_counts
-                mm_item_counts["audio"] -= mm_item_counts["video"]
-        super()._validate_mm_placeholders(mm_placeholders, mm_item_counts)
-
 
 class Qwen2_5OmniConditionalGenerationMixin:
     def _validate_and_reshape_mm_tensor(
@@ -930,6 +989,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         else:
             self.audio_tower = None
 
+        attn_backend_override = (
+            multimodal_config.mm_encoder_attn_backend
+            if multimodal_config is not None
+            else None
+        )
+
         if multimodal_config.get_limit_per_prompt(
             "image"
         ) or multimodal_config.get_limit_per_prompt("video"):
@@ -938,6 +1003,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                 norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
                 quant_config=quant_config,
                 prefix=maybe_prefix(prefix, "visual"),
+                attn_backend_override=attn_backend_override,
             )
         else:
             self.visual = None
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 6759fe630e62..09b0b5a8cbe5 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -865,8 +865,10 @@ def _init_mrope_positions(self, req_state: CachedRequestState):
                 second_per_grid_ts.append(t)
             if (t := mm_input.get("audio_feature_lengths")) is not None:
                 audio_feature_lengths.append(t)
-            if mm_input.get("use_audio_in_video") is True:
-                use_audio_in_video = True
+            # use_audio_in_video arrives as a tensor; unwrap it to a Python bool.
+            use_audio_in_video_value = mm_input.get("use_audio_in_video")
+            if use_audio_in_video_value is not None:
+                use_audio_in_video = bool(use_audio_in_video_value.item())
 
         assert supports_mrope(self.get_model()), "M-RoPE support is not implemented."
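
Reviewer note: the core trick in `qwen2_5_omni_thinker.py` is that, with `use_audio_in_video=True`, audio and video share a single interleaved placeholder span, and each modality claims its own positions through an `is_embed` mask. The sketch below is a standalone illustration only, not part of the patch; the token ids and the alternating pattern are invented for the example.

```python
# Standalone sketch of the is_embed partitioning; AUDIO_TOKEN_ID and
# VIDEO_TOKEN_ID are invented ids, and real interleaving is chunked by time.
import torch

AUDIO_TOKEN_ID = 1
VIDEO_TOKEN_ID = 2

# One interleaved placeholder covering both modalities for a single video.
tokens = torch.tensor(
    [VIDEO_TOKEN_ID, VIDEO_TOKEN_ID, AUDIO_TOKEN_ID,
     VIDEO_TOKEN_ID, VIDEO_TOKEN_ID, AUDIO_TOKEN_ID]
)

# The video placeholder keeps only the video positions (what
# PromptUpdateDetails.select_token_id(..., embed_token_id=video_token_id)
# does in the patch) ...
video_is_embed = tokens == VIDEO_TOKEN_ID
# ... and the derived audio placeholder reuses the same tokens and start_idx,
# keeping only the audio positions.
audio_is_embed = tokens == AUDIO_TOKEN_ID

# The two masks partition the span: no position is claimed twice, none dropped.
assert not (video_is_embed & audio_is_embed).any()
assert (video_is_embed | audio_is_embed).all()
print(video_is_embed.tolist())  # [True, True, False, True, True, False]
print(audio_is_embed.tolist())  # [False, False, True, False, False, True]
```

Because the masks partition the span, filtering the audio entries out of `mm_prompt_updates` before `_apply_prompt_updates` loses nothing: the audio placeholders are fully reconstructible from the video ones, which is exactly what `_derive_audio_from_video_placeholders` does.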
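
A second small illustration, for the `gpu_model_runner.py` change: the `.item()` call in the patch implies the `use_audio_in_video` flag is now delivered as a 0-d tensor (that delivery format is an assumption of this sketch), and an identity check against `True` can never match a tensor, so the old check would silently stop firing.

```python
import torch

# Assumed shape of the flag in mm_input after this patch: a 0-d bool tensor.
flag = torch.tensor(True)

# `is` compares object identity, so a tensor is never `is True`; this is why
# the old `mm_input.get("use_audio_in_video") is True` check cannot match.
print(flag is True)       # False
print(bool(flag.item()))  # True -- what the new code extracts
```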