11 changes: 3 additions & 8 deletions examples/offline_inference/qwen2_5_omni/only_thinker.py
@@ -7,7 +7,6 @@

from typing import NamedTuple

-import vllm.envs as envs
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
@@ -72,11 +71,7 @@ def get_use_audio_in_video_query() -> QueryResult:
)
asset = VideoAsset(name="baby_reading", num_frames=16)
audio = asset.get_audio(sampling_rate=16000)
-assert not envs.VLLM_USE_V1, (
-"V1 does not support use_audio_in_video. "
-"Please launch this example with "
-"`VLLM_USE_V1=0`."
-)

return QueryResult(
inputs={
"prompt": prompt,
@@ -125,7 +120,7 @@ def get_multi_audios_query() -> QueryResult:


def main(args):
model_name = "Qwen/Qwen2.5-Omni-7B"
model_name = "Qwen/Qwen2.5-Omni-3B"
query_result = query_map[args.query_type]()

llm = LLM(
@@ -138,7 +133,7 @@ def main(args):

# We set temperature to 0.2 so that outputs can be different
# even when all prompts are identical when running batch inference.
-sampling_params = SamplingParams(temperature=0.2, max_tokens=64)
+sampling_params = SamplingParams(temperature=0.2, max_tokens=128)

outputs = llm.generate(query_result.inputs, sampling_params=sampling_params)

236 changes: 218 additions & 18 deletions vllm/model_executor/models/qwen2_5_omni_thinker.py
@@ -53,7 +53,8 @@
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (ImageItem, ModalityData,
MultiModalDataDict, MultiModalFieldConfig,
-MultiModalKwargsItems, NestedTensors)
+MultiModalKwargsItems, MultiModalUUIDDict,
+NestedTensors)
from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems,
ModalityDataItems, MultiModalDataItems,
MultiModalDataParser)
@@ -121,11 +122,24 @@ def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str,

num_videos = len(video_grid_sizes)

# Check if audio is embedded in video
use_audio_in_video_tensor = hf_inputs.get("use_audio_in_video")
use_audio_in_video = False
if use_audio_in_video_tensor is not None:
use_audio_in_video = bool(use_audio_in_video_tensor.item())

# When use_audio_in_video=True, audio fields should be grouped
# with video modality
# This way audio and video data are in the same mm_kwargs item
audio_modality = "video" if use_audio_in_video else "audio"

return dict(
input_audio_features=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_feature_lengths, dim=1),
feature_attention_mask=MultiModalFieldConfig.batched("audio"),
audio_feature_lengths=MultiModalFieldConfig.batched("audio"),
audio_modality, audio_feature_lengths, dim=1),
feature_attention_mask=MultiModalFieldConfig.batched(
audio_modality),
audio_feature_lengths=MultiModalFieldConfig.batched(
audio_modality),
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", image_pixel_grid_sizes),
image_embeds=MultiModalFieldConfig.flat_from_sizes(
@@ -309,6 +323,44 @@ def _get_mm_fields_config(
self.info.get_hf_config().vision_config.spatial_merge_size)(
hf_inputs)

def _cached_apply_hf_processor(
self,
prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
):
"""
Override to bypass caching when use_audio_in_video=True.

When use_audio_in_video=True, audio and video are tightly coupled and
the caching logic doesn't handle this properly (hash/data mismatch).
For now, bypass caching in this case.
"""
use_audio_in_video = hf_processor_mm_kwargs.get(
"use_audio_in_video", False)

if use_audio_in_video:
# Bypass caching for use_audio_in_video case
return self._apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)

# Normal caching path for other cases
return super()._cached_apply_hf_processor(
prompt,
mm_data_items,
hf_processor_mm_kwargs,
tokenization_kwargs,
mm_uuids=mm_uuids,
)

def _maybe_apply_prompt_updates(
self,
mm_items: MultiModalDataItems,
@@ -321,8 +373,8 @@ def _maybe_apply_prompt_updates(
Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`.
"""
mm_item_counts = mm_items.get_all_counts()
-self._validate_mm_kwargs(mm_kwargs, mm_item_counts)

+# Detect use_audio_in_video first
use_audio_in_video = False
if "video" in mm_kwargs:
video_items = [
@@ -333,6 +385,18 @@
use_audio_in_video = all(item["use_audio_in_video"].data
for item in video_items)

# When use_audio_in_video=True, audio items are merged into video
# mm_kwargs
# So we need to adjust mm_item_counts to reflect this
adjusted_item_counts = mm_item_counts.copy()
if use_audio_in_video and "audio" in adjusted_item_counts \
and "video" in adjusted_item_counts:
# Audio items are now part of video mm_kwargs, so remove them
# from the count
del adjusted_item_counts["audio"]

self._validate_mm_kwargs(mm_kwargs, adjusted_item_counts)

if is_update_applied:
mm_placeholders = self._find_mm_placeholders(
prompt_ids,
@@ -840,6 +904,111 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
"audio"] = self._parse_and_validate_audio_input(**kwargs)
return mm_input_by_modality

def _interleave_audio_video_embeddings(
self,
audio_embeddings: tuple[torch.Tensor, ...],
video_embeddings: tuple[torch.Tensor, ...],
mm_input_by_modality: dict,
) -> torch.Tensor:
"""
Interleave audio and video embeddings based on temporal chunks.

When use_audio_in_video=True, audio and video are split into temporal
chunks and interleaved: [video_chunk1, audio_chunk1, video_chunk2,
audio_chunk2, ...].
"""
# Extract single tensors (we expect one video and one audio)
audio_emb = audio_embeddings[0] # [audio_len, hidden_size]
video_emb = video_embeddings[0] # [video_len, hidden_size]

# Get chunking parameters from config
thinker_config = self.config
seconds_per_chunk = thinker_config.seconds_per_chunk
spatial_merge_size = thinker_config.vision_config.spatial_merge_size
tokens_per_second = getattr(thinker_config.vision_config,
"tokens_per_second", 25)

# Get video grid dimensions from mm_input
video_input = mm_input_by_modality["video"]
video_grid_thw = video_input["video_grid_thw"]
video_second_per_grid_t = video_input.get("video_second_per_grid_t",
1.0)

grid_thw_list = video_grid_thw.tolist()
if isinstance(grid_thw_list, list) and len(grid_thw_list) == 1:
grid_thw_list = grid_thw_list[0]
Comment on lines +937 to +939 (Member):

The isinstance check is redundant: video_grid_thw.tolist() always returns a list here, so only the length check is needed.

Suggested change:
-grid_thw_list = video_grid_thw.tolist()
-if isinstance(grid_thw_list, list) and len(grid_thw_list) == 1:
-    grid_thw_list = grid_thw_list[0]
+grid_thw_list = video_grid_thw.tolist()
+if len(grid_thw_list) == 1:
+    grid_thw_list = grid_thw_list[0]

grid_t, grid_h, grid_w = grid_thw_list
grid_t, grid_h, grid_w = int(grid_t), int(grid_h), int(grid_w)
Comment (Member):

Maybe we can perform the int conversion and the removal of the batch dimension before converting to a list.
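
For instance, a minimal sketch of that refactor (hypothetical; assumes a single video, so video_grid_thw has shape [1, 3]):

# Hypothetical: drop the batch dimension and do the int conversion in one
# step, before any list handling.
grid_t, grid_h, grid_w = (int(x) for x in video_grid_thw[0])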


# Calculate temporal chunks
t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
t_index = (torch.arange(grid_t, device=video_emb.device) *
video_second_per_grid_t * tokens_per_second).long()

# Split t_index into chunks
t_index_split_chunk = MRotaryEmbedding._split_list_into_ranges(
t_index, t_ntoken_per_chunk)

# Interleave embeddings chunk by chunk
interleaved_chunks = []
audio_start_idx = 0
video_start_idx = 0

for chunk_idx, t_chunk in enumerate(t_index_split_chunk):
# Calculate video tokens for this chunk
vision_ntoken_per_chunk = len(t_chunk) * grid_h * grid_w // (
spatial_merge_size**2)

# Get video chunk
video_chunk = video_emb[video_start_idx:video_start_idx +
vision_ntoken_per_chunk]
interleaved_chunks.append(video_chunk)
video_start_idx += vision_ntoken_per_chunk

# Get audio chunk
audio_chunk_size = min(t_ntoken_per_chunk,
audio_emb.shape[0] - audio_start_idx)
if audio_chunk_size > 0:
audio_chunk = audio_emb[audio_start_idx:audio_start_idx +
audio_chunk_size]
interleaved_chunks.append(audio_chunk)
audio_start_idx += audio_chunk_size

# Add any remaining audio
if audio_start_idx < audio_emb.shape[0]:
remaining_audio = audio_emb[audio_start_idx:]
interleaved_chunks.append(remaining_audio)

# Concatenate all chunks
merged_embedding = torch.cat(interleaved_chunks, dim=0)

# TODO: this should be moved to the placeholder mechanism
# Get the text embeddings for audio_bos and audio_eos tokens
thinker_config = self.config
# <|audio_bos|>
audio_start_token_id = thinker_config.audio_start_token_id
# <|audio_eos|>
audio_end_token_id = thinker_config.audio_end_token_id

embed_layer = self.language_model.model.embed_tokens
device = merged_embedding.device
dtype = merged_embedding.dtype

# Get inferred embeddings for special tokens
audio_bos_ids = torch.tensor([audio_start_token_id], device=device)
audio_eos_ids = torch.tensor([audio_end_token_id], device=device)
audio_bos_emb = embed_layer(audio_bos_ids).to(
dtype) # [1, hidden_size]
audio_eos_emb = embed_layer(audio_eos_ids).to(
dtype) # [1, hidden_size]

# Add special token embeddings at beginning and end
# Structure: [audio_bos] + [interleaved] + [audio_eos]
merged_embedding = torch.cat(
[audio_bos_emb, merged_embedding, audio_eos_emb], dim=0)

Comment on lines +985 to +1009 (Contributor, severity: high):

The manual handling of the audio_bos and audio_eos token embeddings is a maintainability concern. This logic is inconsistent with how special tokens are typically handled through the placeholder mechanism in vLLM, which makes the code fragile to future changes in tokenization or embedding logic. As noted in the TODO comment, this should be refactored to use the placeholder mechanism, for example by modifying omni_get_updates_use_audio_in_video to include the special tokens and letting the standard machinery handle the embedding.
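
One possible shape of that refactor, as a rough sketch (hypothetical; the actual prompt-update API and the signature of omni_get_updates_use_audio_in_video may differ): emit the special tokens as part of the token-level update so the standard embedding machinery resolves them.

# Hypothetical: wrap the interleaved placeholder tokens with the special
# tokens at prompt-update time, instead of concatenating their embeddings
# manually inside the model's forward path.
def wrap_with_audio_markers(interleaved_token_ids: list[int],
                            audio_start_token_id: int,
                            audio_end_token_id: int) -> list[int]:
    return [audio_start_token_id, *interleaved_token_ids, audio_end_token_id]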

return merged_embedding

def get_language_model(self) -> torch.nn.Module:
return self.language_model

@@ -855,19 +1024,50 @@ def get_multimodal_embeddings(self,
# tensor corresponding to a multimodal data item (image or video).
multimodal_embeddings: tuple[torch.Tensor, ...] = ()

-# NOTE: It is important to iterate over the keys in this dictionary
-# to preserve the order of the modalities.
-for modality in mm_input_by_modality:
-multimodal_input = mm_input_by_modality[modality]
-if modality == "image":
-vision_embeddings = self._process_image_input(multimodal_input)
-multimodal_embeddings += vision_embeddings
-if modality == "video":
-video_embeddings = self._process_video_input(multimodal_input)
-multimodal_embeddings += video_embeddings
-if modality == "audio":
-audio_embeddings = self._process_audio_input(multimodal_input)
-multimodal_embeddings += audio_embeddings
# Check if use_audio_in_video is enabled - if so, we need to
# interleave audio and video embeddings into a single tensor
use_audio_in_video = False
if "video" in mm_input_by_modality and "audio" in mm_input_by_modality:
Comment (Member):

Tbh I think this will never be hit in the V1 engine, because we only call get_multimodal_embeddings on one modality at a time.

Reply (Contributor, Author):

Yeah, it is a bit confusing, as in this design we treat audio as "residing" in video, so it would be seen as video but has input_audio_features.
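
Illustratively, a single mm_kwargs "video" item then bundles both field groups (a sketch only; the field names follow _qwen2_5_omni_thinker_field_config above, and the shapes are made up):

import torch

# Hypothetical "video" item when use_audio_in_video=True: the audio fields
# ride along with the video fields in the same item.
video_item = {
    "video_grid_thw": torch.tensor([[4, 12, 16]]),      # (t, h, w) patch grid
    "input_audio_features": torch.zeros(128, 100),      # made-up feature shape
    "feature_attention_mask": torch.ones(1, 100, dtype=torch.bool),
    "audio_feature_lengths": torch.tensor([100]),
    "use_audio_in_video": torch.tensor(True),
}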

# use_audio_in_video comes from kwargs, not from the video input
# itself
use_audio_in_video = kwargs.get("use_audio_in_video", False)
use_audio_in_video = bool(use_audio_in_video.item())
Comment on lines +1033 to +1034 (Contributor, severity: critical):

The current logic for determining use_audio_in_video is incorrect and will fail for batched requests.

1. kwargs.get("use_audio_in_video", False) is problematic: if the key is not present, it returns False, and False.item() will raise an AttributeError.
2. If multiple requests are batched, kwargs.get("use_audio_in_video") will be a tensor with multiple elements, and calling .item() on it will raise a ValueError.

The logic should handle both the case where the key is missing and the case where the tensor is batched. Since the current architecture does not support mixed batches of use_audio_in_video, we should enforce that all requests in a batch have the same setting. A more robust implementation would check that all values in the tensor are the same before proceeding.

Suggested change:
-use_audio_in_video = kwargs.get("use_audio_in_video", False)
-use_audio_in_video = bool(use_audio_in_video.item())
+use_audio_in_video_tensor = kwargs.get("use_audio_in_video")
+use_audio_in_video = False
+if use_audio_in_video_tensor is not None:
+    # This assumes all requests in a batch have the same setting.
+    # We take the value from the first item.
+    val = use_audio_in_video_tensor[0] if use_audio_in_video_tensor.ndim > 0 else use_audio_in_video_tensor
+    use_audio_in_video = bool(val.item())
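
A sketch of that stricter variant (illustrative; assumes the flag arrives as a torch.Tensor whenever it is present):

use_audio_in_video_tensor = kwargs.get("use_audio_in_video")
use_audio_in_video = False
if isinstance(use_audio_in_video_tensor, torch.Tensor):
    flat = use_audio_in_video_tensor.reshape(-1)
    # Reject mixed batches explicitly rather than silently taking item 0.
    if not bool((flat == flat[0]).all()):
        raise ValueError(
            "Mixed use_audio_in_video settings in one batch are not supported")
    use_audio_in_video = bool(flat[0].item())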


if use_audio_in_video:
# Process audio and video separately
audio_embeddings = self._process_audio_input(
mm_input_by_modality["audio"])
video_embeddings = self._process_video_input(
mm_input_by_modality["video"])

# Interleave audio and video embeddings
merged_embedding = self._interleave_audio_video_embeddings(
audio_embeddings, video_embeddings, mm_input_by_modality)
multimodal_embeddings += (merged_embedding, )

# Process images if present
if "image" in mm_input_by_modality:
image_embeddings = self._process_image_input(
mm_input_by_modality["image"])
multimodal_embeddings += image_embeddings
else:
# Normal processing without interleaving
# NOTE: It is important to iterate over the keys in this dictionary
# to preserve the order of the modalities.
for modality in mm_input_by_modality:
multimodal_input = mm_input_by_modality[modality]
if modality == "image":
vision_embeddings = self._process_image_input(
multimodal_input)
multimodal_embeddings += vision_embeddings
if modality == "video":
video_embeddings = self._process_video_input(
multimodal_input)
multimodal_embeddings += video_embeddings
if modality == "audio":
audio_embeddings = self._process_audio_input(
multimodal_input)
multimodal_embeddings += audio_embeddings
return multimodal_embeddings

# TODO (ywang96): support overlapping modality embeddings so that
6 changes: 4 additions & 2 deletions vllm/v1/worker/gpu_model_runner.py
@@ -780,8 +780,10 @@ def _init_mrope_positions(self, req_state: CachedRequestState):
second_per_grid_ts.append(t)
if (t := mm_input.get("audio_feature_lengths")) is not None:
audio_feature_lengths.append(t)
if mm_input.get("use_audio_in_video") is True:
use_audio_in_video = True
# Check for use_audio_in_video
use_audio_in_video_value = mm_input.get("use_audio_in_video")
if use_audio_in_video_value is not None:
use_audio_in_video = bool(use_audio_in_video_value.item())

if supports_mrope(self.model):
req_state.mrope_positions, req_state.mrope_position_delta = \