126 changes: 123 additions & 3 deletions vllm/model_executor/models/qwen2_5_omni_thinker.py
@@ -113,6 +113,95 @@
logger = init_logger(__name__)


def check_interleaved_audio_video(
is_video: torch.Tensor,
is_audio: torch.Tensor,
num_video: int,
num_audio: int,
) -> bool:
"""
Check if video and audio positions are interleaved in the multimodal region.

Returns:
True if video and audio tokens are interleaved, False otherwise.
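
Example (toy boolean masks; in the first call the video and audio
position ranges overlap, so the check returns True)::

    >>> import torch
    >>> check_interleaved_audio_video(
    ...     torch.tensor([True, False, True, False]),
    ...     torch.tensor([False, True, False, True]),
    ...     2, 2,
    ... )
    True
    >>> check_interleaved_audio_video(
    ...     torch.tensor([True, True, False, False]),
    ...     torch.tensor([False, False, True, True]),
    ...     2, 2,
    ... )
    False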
"""
if num_video == 0 or num_audio == 0:
return False

video_pos = is_video.nonzero(as_tuple=True)[0]
audio_pos = is_audio.nonzero(as_tuple=True)[0]

return (
video_pos[0].item() < audio_pos[-1].item()
and audio_pos[0].item() < video_pos[-1].item()
)


def merge_interleaved_embeddings(
inputs_embeds: torch.Tensor,
multimodal_embeddings: "MultiModalEmbeddings",
is_video: torch.Tensor,
is_audio: torch.Tensor,
is_multimodal: torch.Tensor,
num_video: int,
num_audio: int,
) -> torch.Tensor:
"""
Merge embeddings for interleaved audio-in-video sequences.

When use_audio_in_video=True, video and audio tokens are interleaved in
the token sequence, but their embeddings arrive as separate contiguous
per-modality tensors (typically video first, then audio). This function
reorders the video and audio embeddings to match sequence-position order
and scatters them into the input embeddings efficiently.

Args:
inputs_embeds: The input embeddings tensor to merge into.
multimodal_embeddings: Embedding tensors grouped by modality
(video, audio, and possibly other modalities such as image).
is_video: Boolean mask for video token positions.
is_audio: Boolean mask for audio token positions.
is_multimodal: Boolean mask for all multimodal token positions.
num_video: Total count of video tokens.
num_audio: Total count of audio tokens.

Returns:
The merged inputs_embeds tensor with multimodal embeddings scattered
to their correct positions.
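
Example (a toy 4-token sequence with 1-dim embeddings and
hypothetical values; position 3 is a text token and keeps its
original embedding)::

    >>> import torch
    >>> inputs_embeds = torch.zeros(4, 1)
    >>> video = torch.tensor([[1.0], [2.0]])  # 2 video tokens
    >>> audio = torch.tensor([[9.0]])         # 1 audio token
    >>> is_video = torch.tensor([True, False, True, False])
    >>> is_audio = torch.tensor([False, True, False, False])
    >>> merge_interleaved_embeddings(
    ...     inputs_embeds, [video, audio], is_video, is_audio,
    ...     is_video | is_audio, num_video=2, num_audio=1,
    ... ).squeeze(-1)
    tensor([1., 9., 2., 0.])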
"""
# Categorize embeddings by modality based on token counts. Embeddings
# arrive grouped by modality, but the group order varies with the order
# of the multimodal input kwargs (e.g. image, video, audio or video,
# audio), so each tensor is matched greedily against the remaining
# per-modality token budget.
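# Example (hypothetical counts): with num_video=6 and num_audio=4,
# tensors of shapes (6, H) and (4, H) are assigned to video and audio
# respectively, while an 8-token image tensor would exceed both
# remaining budgets and fall through to other_embeds.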
video_embeds: list[torch.Tensor] = []
audio_embeds: list[torch.Tensor] = []
other_embeds: list[torch.Tensor] = []
video_remaining = num_video
audio_remaining = num_audio

for emb in multimodal_embeddings:
n = emb.shape[0]
if video_remaining > 0 and n <= video_remaining:
video_embeds.append(emb)
video_remaining -= n
elif audio_remaining > 0 and n <= audio_remaining:
audio_embeds.append(emb)
audio_remaining -= n
else:
other_embeds.append(emb)

# Scatter each modality to its positions
if video_embeds:
video_positions = is_video.nonzero(as_tuple=True)[0]
inputs_embeds[video_positions] = torch.cat(video_embeds, dim=0)
if audio_embeds:
audio_positions = is_audio.nonzero(as_tuple=True)[0]
inputs_embeds[audio_positions] = torch.cat(audio_embeds, dim=0)
if other_embeds:
other_mask = is_multimodal & ~is_video & ~is_audio
other_positions = other_mask.nonzero(as_tuple=True)[0]
inputs_embeds[other_positions] = torch.cat(other_embeds, dim=0)

return inputs_embeds


class Qwen2_5OmniAudioFeatureInputs(TensorSchema):
"""
Dimensions:
@@ -1286,17 +1375,48 @@ def embed_input_ids(
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
from .utils import _merge_multimodal_embeddings

# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().embed_input_ids(input_ids)

inputs_embeds = self._embed_text_input_ids(
input_ids,
self.get_language_model().embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)

if len(multimodal_embeddings) == 0:
return inputs_embeds

# Check for audio-in-video: interleaved video and audio tokens
# in the multimodal region.
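# For example, with use_audio_in_video the prompt may lay out tokens
# as <vid><vid><aud><aud><vid><vid><aud>... (hypothetical chunking),
# so the video and audio position ranges overlap.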
video_token_id = self.config.video_token_index
audio_token_id = self.config.audio_token_index

is_video = is_multimodal & (input_ids == video_token_id)
is_audio = is_multimodal & (input_ids == audio_token_id)

num_video = is_video.sum().item()
num_audio = is_audio.sum().item()

if check_interleaved_audio_video(is_video, is_audio, num_video, num_audio):
return merge_interleaved_embeddings(
inputs_embeds,
multimodal_embeddings,
is_video,
is_audio,
is_multimodal,
num_video,
num_audio,
)

# Default: standard merge (no interleaving)
return _merge_multimodal_embeddings(
inputs_embeds, multimodal_embeddings, is_multimodal
)

def forward(
self,
input_ids: torch.Tensor | None,
58 changes: 49 additions & 9 deletions vllm/model_executor/models/qwen3_omni_moe_thinker.py
@@ -92,6 +92,8 @@
Qwen2_5OmniConditionalGenerationMixin,
Qwen2_5OmniThinkerDummyInputsBuilder,
Qwen2_5OmniThinkerMultiModalProcessor,
check_interleaved_audio_video,
merge_interleaved_embeddings,
)
from .qwen2_5_vl import (
Qwen2_5_VisionAttention,
@@ -1780,6 +1782,19 @@ def embed_input_ids(
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds

# Detect interleaved audio-in-video early, since it affects
# both the deepstack path and the final embedding merge.
video_token_id = self.config.video_token_id
audio_token_id = self.config.audio_token_id
is_video = is_multimodal & (input_ids == video_token_id)
is_audio = is_multimodal & (input_ids == audio_token_id)
num_video = is_video.sum().item()
num_audio = is_audio.sum().item()

is_interleaved = check_interleaved_audio_video(
is_video, is_audio, num_video, num_audio
)

deepstack_input_embeds = None
# split the feat dim to obtain multi-scale visual feature
has_vision_embeddings = [
@@ -1791,14 +1806,18 @@
):
multiscale_len = len(self.visual.deepstack_visual_indexes)
multimodal_embeddings_multiscale = []

if is_interleaved:
# Use input_ids-based mask for correct vision positions
# when audio and video tokens are interleaved.
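# The sequential mm_positions walk below assumes each embedding's
# tokens occupy contiguous positions, which no longer holds here.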
is_vision = is_video.clone()
else:
is_vision = torch.zeros_like(is_multimodal)
mm_positions = torch.nonzero(is_multimodal, as_tuple=True)[0]
mm_position_idx = 0

for index, embeddings in enumerate(multimodal_embeddings):
num_tokens = embeddings.shape[0]

# Vision embeddings
if embeddings.shape[-1] != self.config.text_config.hidden_size:
@@ -1809,13 +1828,22 @@
)
multimodal_embeddings[index] = embeddings_main
multimodal_embeddings_multiscale.append(embeddings_multiscale)
if not is_interleaved:
current_positions = mm_positions[
mm_position_idx : mm_position_idx + num_tokens
]
is_vision[current_positions] = True

# Audio embeddings
else:
if not is_interleaved:
current_positions = mm_positions[
mm_position_idx : mm_position_idx + num_tokens
]
is_vision[current_positions] = False

if not is_interleaved:
mm_position_idx += num_tokens

deepstack_input_embeds = inputs_embeds.new_zeros(
inputs_embeds.size(0), multiscale_len * inputs_embeds.size(1)
@@ -1834,6 +1862,18 @@
)
self._set_deepstack_input_embeds(deepstack_input_embeds)

if is_interleaved:
return merge_interleaved_embeddings(
inputs_embeds,
multimodal_embeddings,
is_video,
is_audio,
is_multimodal,
num_video,
num_audio,
)

# Default: standard merge (no interleaving)
inputs_embeds = _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,