[Bugfix][Model] Fix audio-in-video support for Qwen2.5-Omni and Qwen3-Omni #33605

Merged
ywang96 merged 8 commits into vllm-project:main from linyueqian:fix-qwen25-omni-audio-in-video
Feb 4, 2026

Conversation

@linyueqian
Contributor

@linyueqian linyueqian commented Feb 2, 2026

Purpose

Fix bugs preventing use_audio_in_video=True from working correctly with Qwen2.5-Omni and Qwen3-Omni.

Bug 1: KeyError: 'audio' in MultiModalBudget

Both Qwen2_5OmniThinkerProcessingInfo and Qwen3OmniMoeThinkerProcessingInfo inherit get_mm_max_tokens_per_item from Qwen2VLProcessingInfo, which only returns {"image": ..., "video": ...} without an "audio" key. This causes a KeyError when MultiModalBudget.__init__ tries to look up the audio budget. Originally fixed here by overriding get_mm_max_tokens_per_item in both models to include an audio token budget computed from the WhisperFeatureExtractor config; this has since been superseded by the more general fix in #33634.
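A hedged sketch of the shape of the override described above (not the PR's actual code): the class name, the placeholder parent budgets, and the budget formula below are hypothetical stand-ins; the real implementation derives the audio budget from the model's own WhisperFeatureExtractor config and pooling factors.

```python
# Hedged sketch, not vLLM source: illustrates adding the missing "audio" key
# so a MultiModalBudget-style lookup does not raise KeyError.

def _audio_budget_from_feature_extractor(
    chunk_length_s: int, hop_length: int, sampling_rate: int
) -> int:
    # Hypothetical budget formula: number of mel frames for one max-length
    # audio clip, halved to mimic a Whisper-style encoder downsampling by 2.
    num_frames = chunk_length_s * sampling_rate // hop_length
    return num_frames // 2


class OmniThinkerProcessingInfo:  # hypothetical stand-in class
    def get_mm_max_tokens_per_item(self, seq_len: int, mm_counts):
        # Placeholder values standing in for the inherited image/video budgets.
        budgets = {"image": 1024, "video": 2048}
        # The fix: always include an "audio" entry alongside image/video.
        budgets["audio"] = _audio_budget_from_feature_extractor(
            chunk_length_s=30, hop_length=160, sampling_rate=16000
        )
        return budgets
```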

Bug 2: Embedding merge misalignment with interleaved audio-in-video tokens

When use_audio_in_video=True, the HF processor interleaves video and audio tokens in chunks: [video_chunk1, audio_chunk1, video_chunk2, audio_chunk2, ...]. However, _gather_mm_embeddings provides embeddings as separate contiguous tensors [all_video_embeds, all_audio_embeds] with an ORed is_multimodal mask. The default masked_scatter_ fills True positions sequentially, placing video embeddings at audio token positions and vice versa.

Fixed by overriding embed_input_ids in both models to detect interleaved video+audio tokens using input_ids, build per-modality masks (is_video, is_audio), and scatter each modality's embeddings separately. For Qwen3-Omni, the deepstack vision path also required fixing — its is_vision mask was built using sequential position tracking which similarly breaks under interleaving; now uses input_ids-based masks when interleaved.
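A toy, dependency-free illustration of the misalignment and the fix described above (plain Python, not vLLM code): the token markers and string "embeddings" are made up so the placement is easy to inspect.

```python
# Hedged toy model of the bug: embeddings arrive grouped per modality
# (all video, then all audio), but the prompt interleaves the tokens.

VIDEO, AUDIO, TEXT = "<vid>", "<aud>", "txt"

# Interleaved prompt produced with use_audio_in_video=True:
tokens = [TEXT, VIDEO, VIDEO, AUDIO, AUDIO, VIDEO, AUDIO, TEXT]

# Gathered embeddings: all video embeds first, then all audio embeds.
mm_embeds = ["v0", "v1", "v2", "a0", "a1", "a2"]


def sequential_fill(tokens, embeds):
    # Mimics masked_scatter_ over the ORed is_multimodal mask: consume
    # embeddings in order at every multimodal position, regardless of type.
    out, it = [], iter(embeds)
    for t in tokens:
        out.append(next(it) if t in (VIDEO, AUDIO) else t)
    return out


def per_modality_fill(tokens, embeds, n_video):
    # The fix: split the embeddings by modality and scatter each group into
    # its own input_ids-derived mask (is_video / is_audio).
    vid_it, aud_it = iter(embeds[:n_video]), iter(embeds[n_video:])
    out = []
    for t in tokens:
        if t == VIDEO:
            out.append(next(vid_it))
        elif t == AUDIO:
            out.append(next(aud_it))
        else:
            out.append(t)
    return out


# Buggy path: video embeds spill onto audio positions and vice versa.
print(sequential_fill(tokens, mm_embeds))
# Fixed path: each modality's embeds land only on its own token positions.
print(per_modality_fill(tokens, mm_embeds, n_video=3))
```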

This is the same underlying issue identified in #32994 but solved locally in the model without changing the shared multimodal interface.

Related PRs: #27721, #32772, #32994

Test Plan

# Qwen2.5-Omni
python -m vllm.entrypoints.openai.api_server \
  --model Qwen/Qwen2.5-Omni-7B \
  --max-model-len 5632 \
  --max-num-seqs 5 \
  --limit-mm-per-prompt '{"audio": 1, "video": 1}' \
  --trust-remote-code \
  --port 8000

# Qwen3-Omni
python -m vllm.entrypoints.openai.api_server \
  --model Qwen/Qwen3-Omni-30B-A3B-Instruct \
  --max-model-len 5632 \
  --max-num-seqs 5 \
  --limit-mm-per-prompt '{"audio": 1, "video": 1}' \
  --trust-remote-code \
  --port 8000

# Send video + audio request with use_audio_in_video=True via OpenAI API
# (use mm_processor_kwargs to enable interleaving)
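A hedged sketch of the request mentioned above. The content-part types and the passthrough of mm_processor_kwargs as an extra body field are assumptions about the OpenAI-compatible API (sent via extra_body when using the openai Python client); the video URL is a placeholder.

```python
# Hedged sketch: builds (but does not send) a chat-completions payload with
# use_audio_in_video enabled. POST it to http://localhost:8000/v1/chat/completions
# once one of the servers from the test plan is running.
import json

payload = {
    "model": "Qwen/Qwen2.5-Omni-7B",
    "messages": [
        {
            "role": "user",
            "content": [
                # Assumption: the server accepts a video_url content part and
                # extracts the audio track from it when interleaving is on.
                {"type": "video_url",
                 "video_url": {"url": "https://example.com/clip.mp4"}},
                {"type": "text", "text": "What do you see and hear?"},
            ],
        }
    ],
    # Assumption: mm_processor_kwargs is accepted as an extra request field.
    "mm_processor_kwargs": {"use_audio_in_video": True},
}

print(json.dumps(payload, indent=2))
```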

Offline inference with examples/offline_inference/qwen2_5_omni/only_thinker.py was also tested.

Test Results

Qwen2.5-Omni

Audio only — correctly transcribes "Mary had a little lamb":

"...Mary had a little lamb its fleece was white as snow and everywhere that Mary went the lamb was sure to go."

Video + Audio (use_audio_in_video=True) — correctly processes both modalities (3228 prompt tokens, confirming interleaving):

"The audio contains a speech in English, saying '...Mary had a little lamb its feet were white as snow and everywhere that Mary went the lamb was sure to go.'"

Video + Audio (use_audio_in_video=False) — also works correctly (2827 prompt tokens, no interleaving):

"The audio contains a speech with the words '...mary had a little lamb its fleece was white as snow and everywhere that mary went the lamb was sure to go.'"

Qwen3-Omni

Before fix (use_audio_in_video=True):

[Music] (garbled embeddings due to deepstack + merge misalignment)

After fix (use_audio_in_video=True) — correctly transcribes:

"the first words I spoke in the original phonograph. A little piece of practical poetry. Mary had a little lamb, its fleece was white as snow, and everywhere that Mary went, the lamb was sure to go."


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

…audio budget

Signed-off-by: linyueqian <linyueqian@outlook.com>
@mergify mergify bot added qwen Related to Qwen models bug Something isn't working labels Feb 2, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces two important bug fixes for Qwen2.5-Omni when using use_audio_in_video=True. The first fix correctly provides an audio budget by overriding get_mm_max_tokens_per_item, resolving a KeyError. The second fix addresses an embedding misalignment for interleaved audio and video tokens by overriding embed_input_ids to scatter embeddings for each modality separately. While the approach is sound, I've identified a critical issue in the implementation of the embedding separation logic that could lead to incorrect model behavior. A detailed comment with a suggested fix is provided below.

Comment on lines +1362 to +1367
if video_remaining > 0 and n <= video_remaining:
    video_embeds.append(emb)
    video_remaining -= n
elif audio_remaining > 0 and n <= audio_remaining:
    audio_embeds.append(emb)
    audio_remaining -= n
Contributor


critical

The logic for separating multimodal embeddings into video_embeds and audio_embeds appears to be incorrect. It assumes that video embeddings appear before audio embeddings in the multimodal_embeddings list.

However, based on the implementation of embed_multimodal and _parse_and_validate_multimodal_inputs, the order of embeddings is determined by the order of modalities in mm_input_by_modality. This order is derived from the field order in the dictionary returned by create_qwen2_5_omni_thinker_field_factory, where audio-related fields appear before video-related fields. Consequently, audio embeddings will be placed before video embeddings in multimodal_embeddings.

The current greedy matching logic will incorrectly classify audio embeddings as video embeddings, leading to incorrect model behavior.

To fix this, the order of checks should be swapped to match the embedding order (audio then video).

Suggested change
-if video_remaining > 0 and n <= video_remaining:
-    video_embeds.append(emb)
-    video_remaining -= n
-elif audio_remaining > 0 and n <= audio_remaining:
-    audio_embeds.append(emb)
-    audio_remaining -= n
+if audio_remaining > 0 and n <= audio_remaining:
+    audio_embeds.append(emb)
+    audio_remaining -= n
+elif video_remaining > 0 and n <= video_remaining:
+    video_embeds.append(emb)
+    video_remaining -= n

Contributor Author


I don't think we need to fix it. The embedding order is actually determined by _gather_mm_embeddings in gpu_model_runner.py, which iterates over req_state.mm_features — where video is registered before audio (as confirmed by the scheduler output's mm_features=[MultiModalFeatureSpec(modality='video', ...), MultiModalFeatureSpec(modality='audio', ...)]). The field factory dict order affects HF processor kwargs, not the feature registration order, which is set by the multimodal processor's placeholder creation (video first, then audio derived via _derive_audio_from_video_placeholders).

…d missing audio budget

Signed-off-by: linyueqian <linyueqian@outlook.com>
@linyueqian linyueqian changed the title [Bugfix] Fix Qwen2.5-Omni audio-in-video embedding merge and missing audio budget [Bugfix][Model] Fix audio-in-video support for Qwen2.5-Omni and Qwen3-Omni Feb 2, 2026
@github-actions

github-actions bot commented Feb 2, 2026

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@hsliuustc0106
Contributor

@ywang96 @DarkLight1337 @gaohan PTAL

@hsliuustc0106
Contributor

Has vllm-omni fixed this part, or will it rely on this fix?

@DarkLight1337
Member

DarkLight1337 commented Feb 3, 2026

The first issue should be fixed in a more general way by #33634, can you focus this PR on the second issue?

@linyueqian
Contributor Author

The first issue should be fixed in a more general way by #33634, can you focus this PR on the second issue?

OK, I will revert the changes once the PR you mention is merged, and test again.

@ywang96
Member

ywang96 commented Feb 4, 2026

Thanks for the contribution - will take a look today

Member

@ywang96 ywang96 left a comment


Please take a look at my comment - thanks!

Signed-off-by: Roger Wang <hey@rogerw.io>
@ywang96
Member

ywang96 commented Feb 4, 2026

I've updated this PR with some cleanup. We can use the following example as the source of truth:

import os
import torch

from vllm import LLM, SamplingParams
from transformers import Qwen3OmniMoeProcessor
from qwen_omni_utils import process_mm_info

if __name__ == '__main__':
    MODEL_PATH = "Qwen/Qwen3-Omni-30B-A3B-Instruct"

    llm = LLM(
            model=MODEL_PATH, trust_remote_code=True, gpu_memory_utilization=0.95,
            tensor_parallel_size=2,
            limit_mm_per_prompt={'image': 3, 'video': 3, 'audio': 3},
            max_num_seqs=8,
            max_model_len=32768,
            seed=1234,
    )

    sampling_params = SamplingParams(
        temperature=0.6,
        top_p=0.95,
        top_k=20,
        max_tokens=16384,
    )

    processor = Qwen3OmniMoeProcessor.from_pretrained(MODEL_PATH)

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "video", "video": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/demo/draw.mp4"},
                {"type": "text", "text": "What can you see and hear? Answer in one sentence."}
            ], 
        }
    ]

    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    audios, images, videos = process_mm_info(messages, use_audio_in_video=True)

    inputs = {
        'prompt': text,
        'multi_modal_data': {},
        "mm_processor_kwargs": {
            "use_audio_in_video": True,
        },
    }

    if images is not None:
        inputs['multi_modal_data']['image'] = images
    if videos is not None:
        inputs['multi_modal_data']['video'] = videos
    if audios is not None:
        inputs['multi_modal_data']['audio'] = audios

    outputs = llm.generate([inputs], sampling_params=sampling_params)

    print(outputs[0].outputs[0].text)

Main branch:

A person is using a stylus to draw a guitar on a tablet, and a woman says, "Hello."

This PR:

A person is using a stylus on a tablet to draw a guitar, and they say, "Hello, take a look at what I'm drawing."

@ywang96 ywang96 added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 4, 2026
@ywang96 ywang96 enabled auto-merge (squash) February 4, 2026 10:19
@ywang96 ywang96 merged commit f8516a1 into vllm-project:main Feb 4, 2026
50 checks passed
gameofdimension pushed a commit to gameofdimension/vllm that referenced this pull request Feb 5, 2026
…-Omni (vllm-project#33605)

Signed-off-by: linyueqian <linyueqian@outlook.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
Signed-off-by: felix01.yu <felix01.yu@vipshop.com>
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
…-Omni (vllm-project#33605)

Signed-off-by: linyueqian <linyueqian@outlook.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
linyueqian added a commit to linyueqian/vllm that referenced this pull request Feb 26, 2026
…llm-project#34506)

PR vllm-project#33605 changed the non-interleaved path in
Qwen2_5OmniThinkerForConditionalGeneration.embed_input_ids to call
_merge_multimodal_embeddings directly instead of super().embed_input_ids.
This broke mixed modalities (audio + image + video): the embeddings list
is ordered by modality type (audio, image, video) but masked_scatter_
fills positions sequentially by token order, so audio embeddings were
incorrectly assigned to image/video positions.

Restore super().embed_input_ids() for the non-interleaved path to match
pre-vllm-project#33605 behaviour. The interleaved use_audio_in_video path is
unchanged and still uses merge_interleaved_embeddings.

Adds unit tests for the regression and related helpers.

Fixes: vllm-project#34506

Signed-off-by: linyueqian <linyueqian@outlook.com>
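The regression described in the commit message above can be shown with a toy, dependency-free illustration (plain Python, not vLLM code): the embeddings list is ordered by modality type (audio before image/video), but a masked_scatter_-style fill consumes it in token order. Token markers and string "embeddings" here are made up.

```python
# Hedged toy model of the non-interleaved mixed-modality regression:
# sequential fill over an ORed mask assigns embeddings by token order,
# ignoring the modality-type ordering of the gathered list.

AUD, IMG, TXT = "<aud>", "<img>", "txt"

# The prompt places the image before the audio clip:
tokens = [TXT, IMG, IMG, TXT, AUD, AUD]

# Embeddings grouped by modality type, audio first:
mm_embeds = ["a0", "a1", "i0", "i1"]

out, it = [], iter(mm_embeds)
for t in tokens:
    # Sequential fill: next embedding goes to the next multimodal position.
    out.append(next(it) if t in (AUD, IMG) else t)

# The image positions incorrectly received the audio embeddings.
print(out)
```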
tunglinwood pushed a commit to tunglinwood/vllm that referenced this pull request Mar 4, 2026
…-Omni (vllm-project#33605)

Signed-off-by: linyueqian <linyueqian@outlook.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
@ywang96 ywang96 mentioned this pull request Mar 8, 2026
