-
-
Notifications
You must be signed in to change notification settings - Fork 15.8k
[MM][Feat] Add support for audio in video in Qwen2.5-Omni
#26156
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -53,7 +53,8 @@ | |||||||||||||||||||
| from vllm.multimodal import MULTIMODAL_REGISTRY | ||||||||||||||||||||
| from vllm.multimodal.inputs import (ImageItem, ModalityData, | ||||||||||||||||||||
| MultiModalDataDict, MultiModalFieldConfig, | ||||||||||||||||||||
| MultiModalKwargsItems, NestedTensors) | ||||||||||||||||||||
| MultiModalKwargsItems, MultiModalUUIDDict, | ||||||||||||||||||||
| NestedTensors) | ||||||||||||||||||||
| from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems, | ||||||||||||||||||||
| ModalityDataItems, MultiModalDataItems, | ||||||||||||||||||||
| MultiModalDataParser) | ||||||||||||||||||||
|
|
@@ -121,11 +122,24 @@ def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, | |||||||||||||||||||
|
|
||||||||||||||||||||
| num_videos = len(video_grid_sizes) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Check if audio is embedded in video | ||||||||||||||||||||
| use_audio_in_video_tensor = hf_inputs.get("use_audio_in_video") | ||||||||||||||||||||
| use_audio_in_video = False | ||||||||||||||||||||
| if use_audio_in_video_tensor is not None: | ||||||||||||||||||||
| use_audio_in_video = bool(use_audio_in_video_tensor.item()) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # When use_audio_in_video=True, audio fields should be grouped | ||||||||||||||||||||
| # with video modality | ||||||||||||||||||||
| # This way audio and video data are in the same mm_kwargs item | ||||||||||||||||||||
| audio_modality = "video" if use_audio_in_video else "audio" | ||||||||||||||||||||
|
|
||||||||||||||||||||
| return dict( | ||||||||||||||||||||
| input_audio_features=MultiModalFieldConfig.flat_from_sizes( | ||||||||||||||||||||
| "audio", audio_feature_lengths, dim=1), | ||||||||||||||||||||
| feature_attention_mask=MultiModalFieldConfig.batched("audio"), | ||||||||||||||||||||
| audio_feature_lengths=MultiModalFieldConfig.batched("audio"), | ||||||||||||||||||||
| audio_modality, audio_feature_lengths, dim=1), | ||||||||||||||||||||
| feature_attention_mask=MultiModalFieldConfig.batched( | ||||||||||||||||||||
| audio_modality), | ||||||||||||||||||||
| audio_feature_lengths=MultiModalFieldConfig.batched( | ||||||||||||||||||||
| audio_modality), | ||||||||||||||||||||
| pixel_values=MultiModalFieldConfig.flat_from_sizes( | ||||||||||||||||||||
| "image", image_pixel_grid_sizes), | ||||||||||||||||||||
| image_embeds=MultiModalFieldConfig.flat_from_sizes( | ||||||||||||||||||||
|
|
@@ -309,6 +323,44 @@ def _get_mm_fields_config( | |||||||||||||||||||
| self.info.get_hf_config().vision_config.spatial_merge_size)( | ||||||||||||||||||||
| hf_inputs) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| def _cached_apply_hf_processor( | ||||||||||||||||||||
| self, | ||||||||||||||||||||
| prompt: Union[str, list[int]], | ||||||||||||||||||||
| mm_data_items: MultiModalDataItems, | ||||||||||||||||||||
| hf_processor_mm_kwargs: Mapping[str, object], | ||||||||||||||||||||
| tokenization_kwargs: Mapping[str, object], | ||||||||||||||||||||
| *, | ||||||||||||||||||||
| mm_uuids: Optional[MultiModalUUIDDict] = None, | ||||||||||||||||||||
| ): | ||||||||||||||||||||
| """ | ||||||||||||||||||||
| Override to bypass caching when use_audio_in_video=True. | ||||||||||||||||||||
|
|
||||||||||||||||||||
| When use_audio_in_video=True, audio and video are tightly coupled and | ||||||||||||||||||||
| the caching logic doesn't handle this properly (hash/data mismatch). | ||||||||||||||||||||
| For now, bypass caching in this case. | ||||||||||||||||||||
| """ | ||||||||||||||||||||
| use_audio_in_video = hf_processor_mm_kwargs.get( | ||||||||||||||||||||
| "use_audio_in_video", False) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if use_audio_in_video: | ||||||||||||||||||||
| # Bypass caching for use_audio_in_video case | ||||||||||||||||||||
| return self._apply_hf_processor( | ||||||||||||||||||||
| prompt=prompt, | ||||||||||||||||||||
| mm_data_items=mm_data_items, | ||||||||||||||||||||
| hf_processor_mm_kwargs=hf_processor_mm_kwargs, | ||||||||||||||||||||
| tokenization_kwargs=tokenization_kwargs, | ||||||||||||||||||||
| mm_uuids=mm_uuids, | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Normal caching path for other cases | ||||||||||||||||||||
| return super()._cached_apply_hf_processor( | ||||||||||||||||||||
| prompt, | ||||||||||||||||||||
| mm_data_items, | ||||||||||||||||||||
| hf_processor_mm_kwargs, | ||||||||||||||||||||
| tokenization_kwargs, | ||||||||||||||||||||
| mm_uuids=mm_uuids, | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| def _maybe_apply_prompt_updates( | ||||||||||||||||||||
| self, | ||||||||||||||||||||
| mm_items: MultiModalDataItems, | ||||||||||||||||||||
|
|
@@ -321,8 +373,8 @@ def _maybe_apply_prompt_updates( | |||||||||||||||||||
| Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`. | ||||||||||||||||||||
| """ | ||||||||||||||||||||
| mm_item_counts = mm_items.get_all_counts() | ||||||||||||||||||||
| self._validate_mm_kwargs(mm_kwargs, mm_item_counts) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Detect use_audio_in_video first | ||||||||||||||||||||
| use_audio_in_video = False | ||||||||||||||||||||
| if "video" in mm_kwargs: | ||||||||||||||||||||
| video_items = [ | ||||||||||||||||||||
|
|
@@ -333,6 +385,18 @@ def _maybe_apply_prompt_updates( | |||||||||||||||||||
| use_audio_in_video = all(item["use_audio_in_video"].data | ||||||||||||||||||||
| for item in video_items) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # When use_audio_in_video=True, audio items are merged into video | ||||||||||||||||||||
| # mm_kwargs | ||||||||||||||||||||
| # So we need to adjust mm_item_counts to reflect this | ||||||||||||||||||||
| adjusted_item_counts = mm_item_counts.copy() | ||||||||||||||||||||
| if use_audio_in_video and "audio" in adjusted_item_counts \ | ||||||||||||||||||||
| and "video" in adjusted_item_counts: | ||||||||||||||||||||
| # Audio items are now part of video mm_kwargs, so remove them | ||||||||||||||||||||
| # from the count | ||||||||||||||||||||
| del adjusted_item_counts["audio"] | ||||||||||||||||||||
|
|
||||||||||||||||||||
| self._validate_mm_kwargs(mm_kwargs, adjusted_item_counts) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if is_update_applied: | ||||||||||||||||||||
| mm_placeholders = self._find_mm_placeholders( | ||||||||||||||||||||
| prompt_ids, | ||||||||||||||||||||
|
|
@@ -840,6 +904,111 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: | |||||||||||||||||||
| "audio"] = self._parse_and_validate_audio_input(**kwargs) | ||||||||||||||||||||
| return mm_input_by_modality | ||||||||||||||||||||
|
|
||||||||||||||||||||
| def _interleave_audio_video_embeddings( | ||||||||||||||||||||
| self, | ||||||||||||||||||||
| audio_embeddings: tuple[torch.Tensor, ...], | ||||||||||||||||||||
| video_embeddings: tuple[torch.Tensor, ...], | ||||||||||||||||||||
| mm_input_by_modality: dict, | ||||||||||||||||||||
| ) -> torch.Tensor: | ||||||||||||||||||||
| """ | ||||||||||||||||||||
| Interleave audio and video embeddings based on temporal chunks. | ||||||||||||||||||||
|
|
||||||||||||||||||||
| When use_audio_in_video=True, audio and video are split into temporal | ||||||||||||||||||||
| chunks and interleaved: [video_chunk1, audio_chunk1, video_chunk2, | ||||||||||||||||||||
| audio_chunk2, ...]. | ||||||||||||||||||||
| """ | ||||||||||||||||||||
| # Extract single tensors (we expect one video and one audio) | ||||||||||||||||||||
| audio_emb = audio_embeddings[0] # [audio_len, hidden_size] | ||||||||||||||||||||
| video_emb = video_embeddings[0] # [video_len, hidden_size] | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Get chunking parameters from config | ||||||||||||||||||||
| thinker_config = self.config | ||||||||||||||||||||
| seconds_per_chunk = thinker_config.seconds_per_chunk | ||||||||||||||||||||
| spatial_merge_size = thinker_config.vision_config.spatial_merge_size | ||||||||||||||||||||
| tokens_per_second = getattr(thinker_config.vision_config, | ||||||||||||||||||||
| "tokens_per_second", 25) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Get video grid dimensions from mm_input | ||||||||||||||||||||
| video_input = mm_input_by_modality["video"] | ||||||||||||||||||||
| video_grid_thw = video_input["video_grid_thw"] | ||||||||||||||||||||
| video_second_per_grid_t = video_input.get("video_second_per_grid_t", | ||||||||||||||||||||
| 1.0) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| grid_thw_list = video_grid_thw.tolist() | ||||||||||||||||||||
| if isinstance(grid_thw_list, list) and len(grid_thw_list) == 1: | ||||||||||||||||||||
| grid_thw_list = grid_thw_list[0] | ||||||||||||||||||||
| grid_t, grid_h, grid_w = grid_thw_list | ||||||||||||||||||||
| grid_t, grid_h, grid_w = int(grid_t), int(grid_h), int(grid_w) | ||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we can perform the int conversion and removal of batch dimension before converting to list |
||||||||||||||||||||
|
|
||||||||||||||||||||
| # Calculate temporal chunks | ||||||||||||||||||||
| t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) | ||||||||||||||||||||
| t_index = (torch.arange(grid_t, device=video_emb.device) * | ||||||||||||||||||||
| video_second_per_grid_t * tokens_per_second).long() | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Split t_index into chunks | ||||||||||||||||||||
| t_index_split_chunk = MRotaryEmbedding._split_list_into_ranges( | ||||||||||||||||||||
| t_index, t_ntoken_per_chunk) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Interleave embeddings chunk by chunk | ||||||||||||||||||||
| interleaved_chunks = [] | ||||||||||||||||||||
| audio_start_idx = 0 | ||||||||||||||||||||
| video_start_idx = 0 | ||||||||||||||||||||
|
|
||||||||||||||||||||
| for chunk_idx, t_chunk in enumerate(t_index_split_chunk): | ||||||||||||||||||||
| # Calculate video tokens for this chunk | ||||||||||||||||||||
| vision_ntoken_per_chunk = len(t_chunk) * grid_h * grid_w // ( | ||||||||||||||||||||
| spatial_merge_size**2) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Get video chunk | ||||||||||||||||||||
| video_chunk = video_emb[video_start_idx:video_start_idx + | ||||||||||||||||||||
| vision_ntoken_per_chunk] | ||||||||||||||||||||
| interleaved_chunks.append(video_chunk) | ||||||||||||||||||||
| video_start_idx += vision_ntoken_per_chunk | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Get audio chunk | ||||||||||||||||||||
| audio_chunk_size = min(t_ntoken_per_chunk, | ||||||||||||||||||||
| audio_emb.shape[0] - audio_start_idx) | ||||||||||||||||||||
| if audio_chunk_size > 0: | ||||||||||||||||||||
| audio_chunk = audio_emb[audio_start_idx:audio_start_idx + | ||||||||||||||||||||
| audio_chunk_size] | ||||||||||||||||||||
| interleaved_chunks.append(audio_chunk) | ||||||||||||||||||||
| audio_start_idx += audio_chunk_size | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Add any remaining audio | ||||||||||||||||||||
| if audio_start_idx < audio_emb.shape[0]: | ||||||||||||||||||||
| remaining_audio = audio_emb[audio_start_idx:] | ||||||||||||||||||||
| interleaved_chunks.append(remaining_audio) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Concatenate all chunks | ||||||||||||||||||||
| merged_embedding = torch.cat(interleaved_chunks, dim=0) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # TODO: this should be moved to the placeholder mechanism | ||||||||||||||||||||
| # Get the text embeddings for audio_bos and audio_eos tokens | ||||||||||||||||||||
| thinker_config = self.config | ||||||||||||||||||||
| # <|audio_bos|> | ||||||||||||||||||||
| audio_start_token_id = thinker_config.audio_start_token_id | ||||||||||||||||||||
| # <|audio_eos|> | ||||||||||||||||||||
| audio_end_token_id = thinker_config.audio_end_token_id | ||||||||||||||||||||
|
|
||||||||||||||||||||
| embed_layer = self.language_model.model.embed_tokens | ||||||||||||||||||||
| device = merged_embedding.device | ||||||||||||||||||||
| dtype = merged_embedding.dtype | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Get inferred embeddings for special tokens | ||||||||||||||||||||
| audio_bos_ids = torch.tensor([audio_start_token_id], device=device) | ||||||||||||||||||||
| audio_eos_ids = torch.tensor([audio_end_token_id], device=device) | ||||||||||||||||||||
| audio_bos_emb = embed_layer(audio_bos_ids).to( | ||||||||||||||||||||
| dtype) # [1, hidden_size] | ||||||||||||||||||||
| audio_eos_emb = embed_layer(audio_eos_ids).to( | ||||||||||||||||||||
| dtype) # [1, hidden_size] | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Add special token embeddings at beginning and end | ||||||||||||||||||||
| # Structure: [audio_bos] + [interleaved] + [audio_eos] | ||||||||||||||||||||
| merged_embedding = torch.cat( | ||||||||||||||||||||
| [audio_bos_emb, merged_embedding, audio_eos_emb], dim=0) | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
Comment on lines
+985
to
+1009
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The manual handling of |
||||||||||||||||||||
| return merged_embedding | ||||||||||||||||||||
|
|
||||||||||||||||||||
| def get_language_model(self) -> torch.nn.Module: | ||||||||||||||||||||
| return self.language_model | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
@@ -855,19 +1024,50 @@ def get_multimodal_embeddings(self, | |||||||||||||||||||
| # tensor corresponding to a multimodal data item (image or video). | ||||||||||||||||||||
| multimodal_embeddings: tuple[torch.Tensor, ...] = () | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # NOTE: It is important to iterate over the keys in this dictionary | ||||||||||||||||||||
| # to preserve the order of the modalities. | ||||||||||||||||||||
| for modality in mm_input_by_modality: | ||||||||||||||||||||
| multimodal_input = mm_input_by_modality[modality] | ||||||||||||||||||||
| if modality == "image": | ||||||||||||||||||||
| vision_embeddings = self._process_image_input(multimodal_input) | ||||||||||||||||||||
| multimodal_embeddings += vision_embeddings | ||||||||||||||||||||
| if modality == "video": | ||||||||||||||||||||
| video_embeddings = self._process_video_input(multimodal_input) | ||||||||||||||||||||
| multimodal_embeddings += video_embeddings | ||||||||||||||||||||
| if modality == "audio": | ||||||||||||||||||||
| audio_embeddings = self._process_audio_input(multimodal_input) | ||||||||||||||||||||
| multimodal_embeddings += audio_embeddings | ||||||||||||||||||||
| # Check if use_audio_in_video is enabled - if so, we need to | ||||||||||||||||||||
| # interleave audio and video embeddings into a single tensor | ||||||||||||||||||||
| use_audio_in_video = False | ||||||||||||||||||||
| if "video" in mm_input_by_modality and "audio" in mm_input_by_modality: | ||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Tbh I think this will never be hit in V1 engine because we only call
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, it is a bit confusing as in this design we treat audio as "residing" in video, so it would be seen as video but has input_audio_features. |
||||||||||||||||||||
| # use_audio_in_video comes from kwargs, not from the video input | ||||||||||||||||||||
| # itself | ||||||||||||||||||||
| use_audio_in_video = kwargs.get("use_audio_in_video", False) | ||||||||||||||||||||
| use_audio_in_video = bool(use_audio_in_video.item()) | ||||||||||||||||||||
|
Comment on lines
+1033
to
+1034
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The current logic for determining
The logic should correctly handle the case where the key is missing and where the tensor is batched. Since the current architecture does not support mixed batches of
Suggested change
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| if use_audio_in_video: | ||||||||||||||||||||
| # Process audio and video separately | ||||||||||||||||||||
| audio_embeddings = self._process_audio_input( | ||||||||||||||||||||
| mm_input_by_modality["audio"]) | ||||||||||||||||||||
| video_embeddings = self._process_video_input( | ||||||||||||||||||||
| mm_input_by_modality["video"]) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Interleave audio and video embeddings | ||||||||||||||||||||
| merged_embedding = self._interleave_audio_video_embeddings( | ||||||||||||||||||||
| audio_embeddings, video_embeddings, mm_input_by_modality) | ||||||||||||||||||||
| multimodal_embeddings += (merged_embedding, ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Process images if present | ||||||||||||||||||||
| if "image" in mm_input_by_modality: | ||||||||||||||||||||
| image_embeddings = self._process_image_input( | ||||||||||||||||||||
| mm_input_by_modality["image"]) | ||||||||||||||||||||
| multimodal_embeddings += image_embeddings | ||||||||||||||||||||
| else: | ||||||||||||||||||||
| # Normal processing without interleaving | ||||||||||||||||||||
| # NOTE: It is important to iterate over the keys in this dictionary | ||||||||||||||||||||
| # to preserve the order of the modalities. | ||||||||||||||||||||
| for modality in mm_input_by_modality: | ||||||||||||||||||||
| multimodal_input = mm_input_by_modality[modality] | ||||||||||||||||||||
| if modality == "image": | ||||||||||||||||||||
| vision_embeddings = self._process_image_input( | ||||||||||||||||||||
| multimodal_input) | ||||||||||||||||||||
| multimodal_embeddings += vision_embeddings | ||||||||||||||||||||
| if modality == "video": | ||||||||||||||||||||
| video_embeddings = self._process_video_input( | ||||||||||||||||||||
| multimodal_input) | ||||||||||||||||||||
| multimodal_embeddings += video_embeddings | ||||||||||||||||||||
| if modality == "audio": | ||||||||||||||||||||
| audio_embeddings = self._process_audio_input( | ||||||||||||||||||||
| multimodal_input) | ||||||||||||||||||||
| multimodal_embeddings += audio_embeddings | ||||||||||||||||||||
| return multimodal_embeddings | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # TODO (ywang96): support overlapping modality embeddings so that | ||||||||||||||||||||
|
|
||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
isinstancecheck is redundant