35 changes: 13 additions & 22 deletions vllm/model_executor/models/ernie45_vl.py
@@ -34,7 +34,7 @@
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from transformers import BatchFeature, PretrainedConfig
from transformers import BatchFeature

from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import (
@@ -58,6 +58,7 @@
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFeatureSpec,
MultiModalFieldConfig,
MultiModalKwargsItems,
)
@@ -1432,26 +1433,24 @@ def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
def get_mrope_input_positions(
self,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: list[list[int]] | torch.Tensor,
video_grid_thw: list[list[int]] | torch.Tensor,
second_per_grid_ts: list[float] | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value for Ernie VL."""
kwargs = MultiModalFeatureSpec.gather_kwargs(
mm_features,
{"image_grid_thw", "video_grid_thw"},
)
image_grid_thw = kwargs.get("image_grid_thw", [])
video_grid_thw = kwargs.get("video_grid_thw", [])

hf_config = self.config
image_token_id = hf_config.im_patch_id
video_start_token_id = hf_config.video_start_token_id
video_end_token_id = hf_config.video_end_token_id
spatial_conv_size = hf_config.spatial_conv_size
temporal_conv_size = hf_config.temporal_conv_size
llm_pos_ids_list: list = []

if not (image_grid_thw is None and video_grid_thw is None):
if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()

if len(image_grid_thw) or len(video_grid_thw):
input_token_type: list[str] = []
video_check_flg = False
for token in input_tokens:
@@ -1483,11 +1482,7 @@ def get_mrope_input_positions(
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
)
if modality_type == "image":
t, h, w = (
image_grid_thw[mm_data_idx][0],
image_grid_thw[mm_data_idx][1],
image_grid_thw[mm_data_idx][2],
)
t, h, w = image_grid_thw[mm_data_idx].tolist()
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_conv_size,
@@ -1518,11 +1513,7 @@ def get_mrope_input_positions(
mm_data_idx += 1

elif modality_type == "video":
t, h, w = (
video_grid_thw[mm_data_idx][0],
video_grid_thw[mm_data_idx][1],
video_grid_thw[mm_data_idx][2],
)
t, h, w = video_grid_thw[mm_data_idx].tolist()
llm_grid_t, llm_grid_h, llm_grid_w = (
t // temporal_conv_size,
h // spatial_conv_size,
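Every file in this diff follows the same signature change: the per-modality grid arguments are dropped and each model pulls `image_grid_thw` / `video_grid_thw` back out of `mm_features` via `MultiModalFeatureSpec.gather_kwargs`. Below is a minimal, self-contained sketch of that pattern; `FakeFeatureSpec` and the local `gather_kwargs` helper are stand-ins written only for illustration (the real class lives in `vllm.multimodal.inputs`), and the assumed semantics (one list entry per matching item, in item order) are inferred from how the results are indexed with `mm_data_idx` above.

```python
from dataclasses import dataclass, field

import torch


@dataclass
class FakeFeatureSpec:
    """Hypothetical stand-in for vLLM's MultiModalFeatureSpec (illustration only)."""

    modality: str
    # Processed kwargs for this item, e.g. {"image_grid_thw": tensor([t, h, w])}.
    mm_kwargs: dict[str, torch.Tensor] = field(default_factory=dict)


def gather_kwargs(
    features: list[FakeFeatureSpec], keys: set[str]
) -> dict[str, list[torch.Tensor]]:
    """Collect the requested keys across items, preserving item order."""
    gathered: dict[str, list[torch.Tensor]] = {}
    for feature in features:
        for key in keys:
            if key in feature.mm_kwargs:
                gathered.setdefault(key, []).append(feature.mm_kwargs[key])
    return gathered


# Two images followed by one video, mimicking what a processor might emit.
features = [
    FakeFeatureSpec("image", {"image_grid_thw": torch.tensor([1, 28, 42])}),
    FakeFeatureSpec("image", {"image_grid_thw": torch.tensor([1, 14, 14])}),
    FakeFeatureSpec("video", {"video_grid_thw": torch.tensor([4, 28, 28])}),
]

kwargs = gather_kwargs(features, {"image_grid_thw", "video_grid_thw"})
image_grid_thw = kwargs.get("image_grid_thw", [])
video_grid_thw = kwargs.get("video_grid_thw", [])

# Per-item access then matches the diff: grid[idx].tolist() -> [t, h, w].
t, h, w = image_grid_thw[0].tolist()
print(t, h, w)  # 1 28 42
```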
32 changes: 13 additions & 19 deletions vllm/model_executor/models/glm4_1v.py
@@ -37,7 +37,7 @@
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import BatchFeature, PretrainedConfig
from transformers import BatchFeature
from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig
from transformers.models.glm4v.image_processing_glm4v import (
Glm4vImageProcessor,
@@ -70,6 +70,7 @@
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFeatureSpec,
MultiModalFieldConfig,
MultiModalKwargsItems,
VideoItem,
@@ -1618,25 +1619,23 @@ def get_multimodal_embeddings(
def get_mrope_input_positions(
self,
input_tokens: list[int],
hf_config: "PretrainedConfig",
image_grid_thw: list[list[int]] | torch.Tensor | None,
video_grid_thw: list[list[int]] | torch.Tensor | None,
second_per_grid_ts: list[float] | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value for GLM4V."""
kwargs = MultiModalFeatureSpec.gather_kwargs(
mm_features,
{"image_grid_thw", "video_grid_thw"},
)
image_grid_thw = kwargs.get("image_grid_thw", [])
video_grid_thw = kwargs.get("video_grid_thw", [])

hf_config = self.config
image_token_id = hf_config.image_token_id
video_start_token_id = hf_config.video_start_token_id
video_end_token_id = hf_config.video_end_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
llm_pos_ids_list: list = []

if not (image_grid_thw is None and video_grid_thw is None):
if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()

if len(image_grid_thw) or len(video_grid_thw):
input_token_type: list[str] = []
video_check_flg = False
for token in input_tokens:
@@ -1668,11 +1667,7 @@ def get_mrope_input_positions(
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
)
if modality_type == "image":
t, h, w = (
image_grid_thw[mm_data_idx][0],
image_grid_thw[mm_data_idx][1],
image_grid_thw[mm_data_idx][2],
)
t, h, w = image_grid_thw[mm_data_idx].tolist()
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
@@ -1705,8 +1700,7 @@ def get_mrope_input_positions(
elif modality_type == "video":
t, h, w = (
video_frame_num,
image_grid_thw[mm_data_idx][1],
image_grid_thw[mm_data_idx][2],
*image_grid_thw[mm_data_idx][1:].tolist(),
)
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
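In the GLM-4.1V (and GLM4V) video branch above, `t` comes from `video_frame_num` while the spatial dims are reused from the stored image grid entry. The star-unpack idiom the diff switches to is shown stand-alone below, with made-up values.

```python
import torch

# A stored grid entry in (t, h, w) form; h and w are the post-patching sizes.
grid_entry = torch.tensor([1, 28, 42])
video_frame_num = 8  # hypothetical frame count for the current video

# Keep the frame count, reuse the spatial dims: same shape of tuple as in the diff.
t, h, w = (video_frame_num, *grid_entry[1:].tolist())
print(t, h, w)  # 8 28 42
```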
32 changes: 13 additions & 19 deletions vllm/model_executor/models/glm4v.py
@@ -15,7 +15,7 @@
from torch.nn import LayerNorm
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from transformers import BatchFeature, PretrainedConfig, PreTrainedTokenizer, TensorType
from transformers import BatchFeature, PreTrainedTokenizer, TensorType
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput

@@ -36,6 +36,7 @@
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFeatureSpec,
MultiModalFieldConfig,
MultiModalKwargsItems,
)
@@ -622,25 +623,23 @@ def _process_image_input(self, image_input: GLMVImagePixelInputs) -> torch.Tensor
def get_mrope_input_positions(
self,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: list[list[int]] | torch.Tensor,
video_grid_thw: list[list[int]] | torch.Tensor,
second_per_grid_ts: list[float] | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value for GLM4V."""
kwargs = MultiModalFeatureSpec.gather_kwargs(
mm_features,
{"image_grid_thw", "video_grid_thw"},
)
image_grid_thw = kwargs.get("image_grid_thw", [])
video_grid_thw = kwargs.get("video_grid_thw", [])

hf_config = self.config
image_token_id = hf_config.image_token_id
video_start_token_id = hf_config.video_start_token_id
video_end_token_id = hf_config.video_end_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
llm_pos_ids_list: list = []

if not (image_grid_thw is None and video_grid_thw is None):
if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()

if len(image_grid_thw) or len(video_grid_thw):
input_token_type: list[str] = []
video_check_flg = False
for token in input_tokens:
@@ -672,11 +671,7 @@ def get_mrope_input_positions(
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
)
if modality_type == "image":
t, h, w = (
image_grid_thw[mm_data_idx][0],
image_grid_thw[mm_data_idx][1],
image_grid_thw[mm_data_idx][2],
)
t, h, w = image_grid_thw[mm_data_idx].tolist()
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
@@ -709,8 +704,7 @@ def get_mrope_input_positions(
elif modality_type == "video":
t, h, w = (
video_frame_num,
image_grid_thw[mm_data_idx][1],
image_grid_thw[mm_data_idx][2],
*image_grid_thw[mm_data_idx][1:].tolist(),
)
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
22 changes: 6 additions & 16 deletions vllm/model_executor/models/interfaces.py
@@ -16,7 +16,6 @@
import torch
import torch.nn as nn
from torch import Tensor
from transformers import PretrainedConfig
from transformers.models.whisper.tokenization_whisper import LANGUAGES
from typing_extensions import Self, TypeIs

@@ -32,10 +31,12 @@
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.model_executor.models.utils import WeightsMapper
from vllm.multimodal.inputs import MultiModalFeatureSpec
from vllm.sequence import IntermediateTensors
else:
VllmConfig = object
WeightsMapper = object
MultiModalFeatureSpec = object
IntermediateTensors = object

logger = init_logger(__name__)
@@ -991,12 +992,7 @@ class SupportsMRoPE(Protocol):
def get_mrope_input_positions(
self,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: list[list[int]] | torch.Tensor | None,
video_grid_thw: list[list[int]] | torch.Tensor | None,
second_per_grid_ts: list[float] | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
mm_features: list["MultiModalFeatureSpec"],
) -> tuple[torch.Tensor, int]:
"""
Get M-RoPE input positions and delta value for this specific model.
@@ -1006,17 +1002,11 @@ def get_mrope_input_positions(

Args:
input_tokens: List of input token IDs
hf_config: HuggingFace model configuration
image_grid_thw: Image grid dimensions (t, h, w)
video_grid_thw: Video grid dimensions (t, h, w)
second_per_grid_ts: Seconds per grid timestep for videos
audio_feature_lengths: Audio feature lengths for multimodal models
use_audio_in_video: Whether to use audio in video for interleaving
mm_features: Information about each multi-modal data item

Returns:
Tuple of (llm_positions, mrope_position_delta)
- llm_positions: Tensor of shape [3, num_tokens]
with T/H/W positions
Tuple of `(llm_positions, mrope_position_delta)`
- llm_positions: Tensor of shape `[3, num_tokens]` with T/H/W positions
- mrope_position_delta: Delta for position calculations
"""
...
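For orientation, here is a sketch of what a model implementing the narrowed `SupportsMRoPE` protocol now has to provide. `SupportsMRoPELike` is a local mirror of the protocol and `ToyMRoPEModel` is a hypothetical implementer with the position math reduced to a text-only fallback; the real models in this diff derive `hf_config` from `self.config` and expand image/video spans using their grid sizes.

```python
from typing import Protocol

import torch


class SupportsMRoPELike(Protocol):
    """Local mirror of the narrowed protocol in interfaces.py (illustration only)."""

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list,  # vLLM passes list[MultiModalFeatureSpec]
    ) -> tuple[torch.Tensor, int]: ...


class ToyMRoPEModel:
    """Hypothetical implementer; not a real vLLM model."""

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list,
    ) -> tuple[torch.Tensor, int]:
        # Text-only fallback: identical T/H/W positions 0..N-1 and a zero delta.
        # Real implementations expand image/video spans using their grid sizes.
        num_tokens = len(input_tokens)
        positions = torch.arange(num_tokens).unsqueeze(0).expand(3, -1)
        return positions, 0


model: SupportsMRoPELike = ToyMRoPEModel()
positions, delta = model.get_mrope_input_positions([1, 2, 3], [])
print(positions.shape)  # torch.Size([3, 3])
```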
29 changes: 12 additions & 17 deletions vllm/model_executor/models/keye.py
@@ -40,6 +40,7 @@
ImageItem,
ModalityData,
MultiModalDataDict,
MultiModalFeatureSpec,
MultiModalFieldConfig,
MultiModalKwargsItems,
VideoItem,
@@ -1627,16 +1628,17 @@ def _process_video_input(
def get_mrope_input_positions(
self,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: list[list[int]] | torch.Tensor,
video_grid_thw: list[list[int]] | torch.Tensor,
second_per_grid_ts: list[float] | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]:
kwargs = MultiModalFeatureSpec.gather_kwargs(
mm_features,
{"image_grid_thw", "video_grid_thw"},
)
image_grid_thw = kwargs.get("image_grid_thw", [])
video_grid_thw = kwargs.get("video_grid_thw", [])

if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0:
video_grid_thw = video_grid_thw[0]
"""Get mrope input positions and delta value (Keye series)."""

def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]:
"""
@@ -1662,6 +1664,7 @@ def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]:

video_grid_thw = split_thw(video_grid_thw)

hf_config = self.config
image_token_id = hf_config.image_token_id
video_token_id = hf_config.video_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
@@ -1691,20 +1694,12 @@ def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]:
ed_video = len(input_tokens) + 1

if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
t, h, w = image_grid_thw[image_index].tolist()
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
t, h, w = video_grid_thw[video_index].tolist()
video_index += 1
remain_frames -= 1
ed = ed_video