Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 0 additions & 20 deletions docs/models/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -740,23 +740,6 @@ Some models are supported only via the [Transformers modeling backend](#transfor
<sup>E</sup> Pre-computed embeddings can be inputted for this modality.
<sup>+</sup> Multiple items can be inputted per text prompt for this modality.

!!! warning
Both V0 and V1 support `Gemma3ForConditionalGeneration` for text-only inputs.
However, there are differences in how they handle text + image inputs:

V0 correctly implements the model's attention pattern:
- Uses bidirectional attention between the image tokens corresponding to the same image
- Uses causal attention for other tokens
- Implemented via (naive) PyTorch SDPA with masking tensors
- Note: May use significant memory for long prompts with image

V1 currently uses a simplified attention pattern:
- Uses causal attention for all tokens, including image tokens
- Generates reasonable outputs but does not match the original model's attention for text + image inputs, especially when `{"do_pan_and_scan": true}`
- Will be updated in the future to support the correct behavior

This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.

!!! note
    `Gemma3nForConditionalGeneration` is only supported on V1 because it relies on shared KV caching; it also depends on `timm>=1.0.17` for its
    MobileNet-v5 vision backbone.
Expand All @@ -776,9 +759,6 @@ Some models are supported only via the [Transformers modeling backend](#transfor
The official `openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (`HwwwH/MiniCPM-V-2`) for now.
For more details, please see: <https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630>

!!! warning
Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1.

!!! note
For Qwen2.5-Omni and Qwen3-Omni, reading audio from video pre-processing (`--mm-processor-kwargs '{"use_audio_in_video": true}'`) is currently work in progress and not yet supported.

Expand Down
1 change: 0 additions & 1 deletion tests/models/multimodal/generation/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,6 @@
auto_cls=AutoModelForImageTextToText,
vllm_runner_kwargs={"mm_processor_kwargs": {"do_pan_and_scan": True}},
patch_hf_runner=model_utils.gemma3_patch_hf_runner,
num_logprobs=10,
),
"glm4v": VLMTestInfo(
models=["zai-org/glm-4v-9b"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,6 @@ def get_attn_backend_cls(
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
):
return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501
9 changes: 9 additions & 0 deletions vllm/attention/backends/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,10 @@ def is_mla(cls) -> bool:
def supports_sink(cls) -> bool:
return False

@classmethod
def supports_mm_prefix(cls) -> bool:
    """Whether the backend supports bidirectional ("prefix-LM") attention
    over multimodal token spans.

    Defaults to False; `validate_configuration` rejects a backend for
    which this returns False when `use_mm_prefix` is requested.
    Backends implementing this attention pattern are expected to
    override this classmethod.
    """
    return False

@classmethod
def is_sparse(cls) -> bool:
    """Whether this is a sparse-attention backend. Defaults to False."""
    return False
Expand Down Expand Up @@ -207,6 +211,7 @@ def validate_configuration(
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
device_capability: "DeviceCapability",
attn_type: str,
) -> list[str]:
Expand All @@ -219,6 +224,10 @@ def validate_configuration(
invalid_reasons.append("kv_cache_dtype not supported")
if not cls.supports_block_size(block_size):
invalid_reasons.append("block_size not supported")
if use_mm_prefix and not cls.supports_mm_prefix():
invalid_reasons.append(
"partial multimodal token full attention not supported"
)
if use_mla != cls.is_mla():
if use_mla:
invalid_reasons.append("MLA not supported")
Expand Down
5 changes: 5 additions & 0 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,10 @@ def __init__(
self.sliding_window = sliding_window
self.has_sink = extra_impl_args.get("sinks") is not None

# NOTE: model_config may be None during certain tests
model_config = vllm_config.model_config
self.use_mm_prefix = model_config is not None and model_config.is_mm_prefix_lm

# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
Expand All @@ -241,6 +245,7 @@ def __init__(
block_size,
use_mla=False,
has_sink=self.has_sink,
use_mm_prefix=self.use_mm_prefix,
attn_type=attn_type,
)
else:
Expand Down
5 changes: 5 additions & 0 deletions vllm/attention/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def get_attn_backend(
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
use_mm_prefix: bool = False,
attn_type: str | None = None,
) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
Expand All @@ -52,6 +53,7 @@ def get_attn_backend(
use_mla=use_mla,
has_sink=has_sink,
use_sparse=use_sparse,
use_mm_prefix=use_mm_prefix,
attn_type=attn_type,
)

Expand All @@ -66,6 +68,7 @@ def _cached_get_attn_backend(
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
use_mm_prefix: bool = False,
attn_type: str | None = None,
) -> type[AttentionBackend]:
from vllm.platforms import current_platform
Expand All @@ -87,6 +90,7 @@ def _cached_get_attn_backend(
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
attn_type,
)
else:
Expand All @@ -99,6 +103,7 @@ def _cached_get_attn_backend(
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
attn_type,
)
if not attention_cls:
Expand Down
14 changes: 14 additions & 0 deletions vllm/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import warnings
from collections.abc import Callable
from dataclasses import InitVar, field
from functools import cached_property
from typing import TYPE_CHECKING, Any, Literal, cast, get_args

import torch
Expand Down Expand Up @@ -1217,6 +1218,19 @@ def is_deepseek_mla(self) -> bool:
)
return False

@cached_property
def is_mm_prefix_lm(self) -> bool:
    """Whether to use bidirectional attention for mm positions."""
    # Model types whose multimodal tokens attend bidirectionally
    # (prefix-LM style) while text tokens stay causal.
    MM_PREFIX_LM_MODELS = (
        "gemma3",
        # TODO(Isotr0py): Disable paligemma for now before
        # we supports soft cap attention for FlexAttention
        # "paligemma",
    )
    # Configs without a model_type attribute (getattr -> None)
    # never match the tuple, so they report False.
    return getattr(self.hf_config, "model_type", None) in MM_PREFIX_LM_MODELS

def get_head_size(self) -> int:
# TODO remove hard code
if self.is_deepseek_mla:
Expand Down
25 changes: 25 additions & 0 deletions vllm/multimodal/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,31 @@ def get_num_embeds(self) -> int:

return int(self.is_embed.sum().item())

def extract_embeds_range(self) -> list[tuple[int, int]]:
    """Extract the start and end indices of the embedded regions in prompt.

    For example, given `PlaceholderRange(offset=2, length=5)` and
    `is_embed = [False, True, False, True, True]`, the output is
    `[(1 + offset, 1 + offset), (3 + offset, 4 + offset)]`.

    Returns:
        A list of `(start, end)` tuples, one per contiguous run of
        embedded tokens, with both indices inclusive.
        Returns the full placeholder range if `is_embed` is `None`.
    """
    if self.is_embed is None:
        # Inclusive end index, consistent with the masked path below.
        # (Bug fix: previously returned an exclusive end of
        # offset + length, contradicting the documented example.)
        return [(self.offset, self.offset + self.length - 1)]

    mask_i = self.is_embed.int()
    # Rising edges (0 -> 1) mark the first True index of each run.
    starts = torch.nonzero(
        torch.diff(mask_i, prepend=mask_i.new_zeros(1)) == 1
    ).flatten()
    # Falling edges (1 -> 0) are detected at the last True index of
    # each run, so `ends` is already inclusive.
    ends = torch.nonzero(
        torch.diff(mask_i, append=mask_i.new_zeros(1)) == -1
    ).flatten()
    ranges = torch.stack((starts, ends), dim=1) + self.offset
    return [tuple(x) for x in ranges.tolist()]

def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return False
Expand Down
1 change: 1 addition & 0 deletions vllm/platforms/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def get_attn_backend_cls(
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
) -> str:
if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
Expand Down
19 changes: 19 additions & 0 deletions vllm/platforms/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,20 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
"Forcing kv cache block size to 64 for FlashMLASparse backend."
)

scheduler_config = vllm_config.scheduler_config
# Note: model_config may be None during testing
if (
model_config is not None
and model_config.is_mm_prefix_lm
and scheduler_config.is_multimodal_model
and not scheduler_config.disable_chunked_mm_input
):
logger.warning(
"Forcing --disable_chunked_mm_input for models "
"with multimodal-bidirectional attention."
)
scheduler_config.disable_chunked_mm_input = True

@classmethod
def get_current_memory_usage(
cls, device: torch.types.Device | None = None
Expand Down Expand Up @@ -268,6 +282,7 @@ def get_valid_backends(
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
) -> tuple[
Expand All @@ -289,6 +304,7 @@ def get_valid_backends(
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
)
Expand All @@ -312,6 +328,7 @@ def get_attn_backend_cls(
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
) -> str:
if attn_type is None:
Expand All @@ -332,6 +349,7 @@ def get_attn_backend_cls(
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
)
Expand All @@ -356,6 +374,7 @@ def get_attn_backend_cls(
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
)
Expand Down
1 change: 1 addition & 0 deletions vllm/platforms/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ def get_attn_backend_cls(
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
) -> str:
"""Get the attention backend class of a device."""
Expand Down
1 change: 1 addition & 0 deletions vllm/platforms/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def get_attn_backend_cls(
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
attn_type: str | None = None,
) -> str:
from vllm._aiter_ops import rocm_aiter_ops
Expand Down
5 changes: 3 additions & 2 deletions vllm/platforms/tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ def get_attn_backend_cls(
kv_cache_dtype: str | None,
block_size: int,
use_mla: bool,
has_sink,
use_sparse,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
) -> str:
if use_sparse:
Expand Down
3 changes: 2 additions & 1 deletion vllm/platforms/xpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def get_attn_backend_cls(
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
) -> str:
from vllm.v1.attention.backends.utils import set_kv_cache_layout
Expand Down
Loading