
[MM][Feat] Add support for audio in video in Qwen2.5-Omni#26156

Closed
wwl2755 wants to merge 1 commit into vllm-project:main from wwl2755:mm-omni

Conversation

@wwl2755 (Contributor) commented Oct 3, 2025

Partially fixes #23888.

Enables audio-in-video support for Qwen2.5-Omni in the V1 engine.

CC: @ywang96 @DarkLight1337 @Isotr0py

Test

```shell
python examples/offline_inference/qwen2_5_omni/only_thinker.py -q use_audio_in_video
```

Output:

```text
INFO 10-03 06:37:53 [llm.py:306] Supported_tasks: ['generate']
Adding requests: 100%|██████████| 1/1 [00:05<00:00,  5.96s/it]
Processed prompts: 100%|██████████| 1/1 [00:01<00:00,  1.19s/it, est. speed input: 2270.81 toks/s, output: 83.17 toks/s]
The video shows a baby sitting on a bed, wearing glasses and a light blue sleeveless top. The baby is holding and flipping through the pages of a book. There's a crib in the background with some clothes on it. The baby seems to be enjoying the book. The baby says "嗯" which means "um" in English. So, the baby is saying "嗯" while looking at the book. What do you think the baby might be thinking about in the book?
```

Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
@mergify mergify bot added the `documentation`, `qwen`, and `v1` labels on Oct 3, 2025
Comment on lines +937 to +939:

```python
grid_thw_list = video_grid_thw.tolist()
if isinstance(grid_thw_list, list) and len(grid_thw_list) == 1:
    grid_thw_list = grid_thw_list[0]
```
Member commented:
Suggested change:

```diff
 grid_thw_list = video_grid_thw.tolist()
-if isinstance(grid_thw_list, list) and len(grid_thw_list) == 1:
+if len(grid_thw_list) == 1:
     grid_thw_list = grid_thw_list[0]
```

The isinstance check is redundant
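For context on why the guard is redundant, a minimal sketch (the tensor value here is a made-up single-video example, not data from the PR):

```python
import torch

# Hypothetical grid for one video: shape [1, 3] = (t, h, w) per video.
video_grid_thw = torch.tensor([[4, 36, 64]])

# tolist() on a non-scalar tensor always returns a Python list,
# so isinstance(grid_thw_list, list) can never be False here.
grid_thw_list = video_grid_thw.tolist()
assert isinstance(grid_thw_list, list)

# The length check alone is enough to strip the batch dimension.
if len(grid_thw_list) == 1:
    grid_thw_list = grid_thw_list[0]
print(grid_thw_list)  # [4, 36, 64]
```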

```python
if isinstance(grid_thw_list, list) and len(grid_thw_list) == 1:
    grid_thw_list = grid_thw_list[0]
grid_t, grid_h, grid_w = grid_thw_list
grid_t, grid_h, grid_w = int(grid_t), int(grid_h), int(grid_w)
```
Member commented:
Maybe we can perform the int conversion and removal of batch dimension before converting to list
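The idea above could look roughly like this (a sketch with a made-up example tensor, not the PR's actual code):

```python
import torch

# Hypothetical batched grid, possibly floating-point after processing.
video_grid_thw = torch.tensor([[4.0, 36.0, 64.0]])

# Cast and drop the batch dimension on the tensor first, so tolist()
# already yields plain Python ints and no per-element int() calls are needed.
grid_t, grid_h, grid_w = video_grid_thw[0].int().tolist()
print(grid_t, grid_h, grid_w)  # 4 36 64
```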

@gemini-code-assist bot left a comment:
Code Review

This pull request adds support for audio in video in the Qwen2.5-Omni model for the V1 engine. The changes are mostly in qwen2_5_omni_thinker.py to handle the new use_audio_in_video flag, including logic for interleaving audio and video embeddings. The implementation is largely correct, but I've identified a critical issue with handling batched requests that could lead to a crash, and a high-severity maintainability concern regarding the manual handling of special token embeddings. My review includes suggestions to address these points.

Comment on lines +1033 to +1034:

```python
use_audio_in_video = kwargs.get("use_audio_in_video", False)
use_audio_in_video = bool(use_audio_in_video.item())
```
Critical:

The current logic for determining `use_audio_in_video` is incorrect and will fail for batched requests.

1. `kwargs.get("use_audio_in_video", False)` is problematic: if the key is not present, it returns `False`, and `False.item()` raises an `AttributeError`.
2. If multiple requests are batched, `kwargs.get("use_audio_in_video")` is a tensor with multiple elements, and calling `.item()` on it raises an error.

The logic should correctly handle both a missing key and a batched tensor. Since the current architecture does not support mixed batches of `use_audio_in_video`, we should enforce that all requests in a batch share the same setting. A more robust implementation would check that all values in the tensor are equal before proceeding.

Suggested change:

```diff
-use_audio_in_video = kwargs.get("use_audio_in_video", False)
-use_audio_in_video = bool(use_audio_in_video.item())
+use_audio_in_video_tensor = kwargs.get("use_audio_in_video")
+use_audio_in_video = False
+if use_audio_in_video_tensor is not None:
+    # This assumes all requests in a batch have the same setting.
+    # We take the value from the first item.
+    val = (use_audio_in_video_tensor[0]
+           if use_audio_in_video_tensor.ndim > 0
+           else use_audio_in_video_tensor)
+    use_audio_in_video = bool(val.item())
```
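One way to enforce the "same setting across the batch" rule is a small helper. This is only a sketch: the function name is made up, and plain Python values stand in for tensors to keep it self-contained.

```python
def resolve_use_audio_in_video(mm_kwargs):
    """Resolve a per-batch use_audio_in_video flag (hypothetical helper).

    Accepts a missing key, a single bool, or a batch of per-request values,
    and rejects mixed batches explicitly instead of crashing in .item().
    """
    flag = mm_kwargs.get("use_audio_in_video")
    if flag is None:
        return False
    # Normalize scalars and sequences (e.g. a tensor's tolist()) alike.
    values = list(flag) if isinstance(flag, (list, tuple)) else [flag]
    values = [bool(v) for v in values]
    if any(v != values[0] for v in values[1:]):
        raise ValueError(
            "mixed use_audio_in_video values in one batch are unsupported")
    return values[0]


print(resolve_use_audio_in_video({}))                                    # False
print(resolve_use_audio_in_video({"use_audio_in_video": [True, True]}))  # True
```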

Comment on lines +985 to +1009:

```python
# TODO: this should be moved to the placeholder mechanism
# Get the text embeddings for audio_bos and audio_eos tokens
thinker_config = self.config
# <|audio_bos|>
audio_start_token_id = thinker_config.audio_start_token_id
# <|audio_eos|>
audio_end_token_id = thinker_config.audio_end_token_id

embed_layer = self.language_model.model.embed_tokens
device = merged_embedding.device
dtype = merged_embedding.dtype

# Get inferred embeddings for special tokens
audio_bos_ids = torch.tensor([audio_start_token_id], device=device)
audio_eos_ids = torch.tensor([audio_end_token_id], device=device)
audio_bos_emb = embed_layer(audio_bos_ids).to(dtype)  # [1, hidden_size]
audio_eos_emb = embed_layer(audio_eos_ids).to(dtype)  # [1, hidden_size]

# Add special token embeddings at beginning and end
# Structure: [audio_bos] + [interleaved] + [audio_eos]
merged_embedding = torch.cat(
    [audio_bos_emb, merged_embedding, audio_eos_emb], dim=0)
```
High:

The manual handling of audio_bos and audio_eos token embeddings is a maintainability concern. This logic is inconsistent with how special tokens are typically handled through the placeholder mechanism in vLLM, which makes the code more fragile to future changes in tokenization or embedding logic. As noted in the TODO comment, this should be refactored to use the placeholder mechanism, for example, by modifying omni_get_updates_use_audio_in_video to include the special tokens and letting the standard machinery handle the embedding.

```python
# Check if use_audio_in_video is enabled - if so, we need to
# interleave audio and video embeddings into a single tensor
use_audio_in_video = False
if "video" in mm_input_by_modality and "audio" in mm_input_by_modality:
```
Member commented:
Tbh I think this will never be hit in the V1 engine because we only call get_multimodal_embeddings on one modality at a time.

Contributor (author) replied:

Yeah, it is a bit confusing: in this design we treat audio as "residing" in the video, so it is seen as video input but carries input_audio_features.

@wwl2755 (Contributor, author) commented Oct 3, 2025

Thank you @DarkLight1337 for the early review!

After syncing offline with @ywang96, he may have an alternative and cleaner way to support this. We will get back to this once we finalize the design.

mergify bot commented Oct 8, 2025

Documentation preview: https://vllm--26156.org.readthedocs.build/en/26156/

mergify bot commented Oct 8, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @wwl2755.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the `needs-rebase` label on Oct 8, 2025
@sleepy-dev-bin commented:
Please merge this to releases/v0.11.1.

@DarkLight1337 (Member) commented:

Closing as superseded by #27721

