diff --git a/docs/design/cuda_graphs_multimodal.md b/docs/design/cuda_graphs_multimodal.md
index f44ef359df38..f1ece3e0ffe5 100644
--- a/docs/design/cuda_graphs_multimodal.md
+++ b/docs/design/cuda_graphs_multimodal.md
@@ -87,6 +87,7 @@ Models opt-in to encoder CUDA Graphs by implementing the [SupportsEncoderCudaGra
 | ------------ | ------ | ------------ | ------------ |
 | `Qwen3VLForConditionalGeneration` | `Qwen3-VL` | ✅︎ | ✅︎ |
 | `Qwen2_5_VLForConditionalGeneration` | `Qwen2.5-VL` | ✅︎ | ✅︎ |
+| `Glm4vForConditionalGeneration` | `GLM-4.1V, GLM-4.6V-Flash` | ✅︎ | ✅︎ |
 
 !!! note
     Encoder CUDA Graphs have currently been tested with `--mm-encoder-attn-backend=FLASH_ATTN` and `--mm-encoder-attn-backend=FLASHINFER` on Blackwell GPUs.
diff --git a/examples/generate/multimodal/vision_language_offline.py b/examples/generate/multimodal/vision_language_offline.py
index 794f20dd0a52..41a21d462666 100644
--- a/examples/generate/multimodal/vision_language_offline.py
+++ b/examples/generate/multimodal/vision_language_offline.py
@@ -588,7 +588,6 @@ def run_glm4_1v(questions: list[str], modality: str) -> ModelRequestData:
             "fps": 1,
         },
         limit_mm_per_prompt=mm_limit,
-        enforce_eager=True,
     )
 
     image_placeholder = "<|begin_of_image|><|image|><|end_of_image|>"
@@ -2467,6 +2466,7 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData:
     "qwen3_vl",
     "qwen3_vl_moe",
     "qwen2_5_vl",
+    "glm4_1v",
 ]
diff --git a/tests/models/multimodal/generation/test_vit_cudagraph.py b/tests/models/multimodal/generation/test_vit_cudagraph.py
index fb7bdfc8625d..7afcd438142d 100644
--- a/tests/models/multimodal/generation/test_vit_cudagraph.py
+++ b/tests/models/multimodal/generation/test_vit_cudagraph.py
@@ -66,6 +66,21 @@ def qwen_vl_chat_template(content: str) -> str:
         needs_video_metadata=False,
         marks=[pytest.mark.core_model],
     ),
+    "glm4_1v": VitCudagraphTestConfig(
+        model="zai-org/GLM-4.1V-9B-Thinking",
+        image_prompt=(
+            "[gMASK]<|system|>\nYou are a helpful assistant.<|user|>\n"
+            "<|begin_of_image|><|image|><|end_of_image|>"
+            "What is in this image?<|assistant|>\n"
+        ),
+        video_prompt=(
+            "[gMASK]<|system|>\nYou are a helpful assistant.<|user|>\n"
+            "<|begin_of_video|><|video|><|end_of_video|>"
+            "Describe this video in one sentence<|assistant|>\n"
+        ),
+        needs_video_metadata=True,
+        marks=[pytest.mark.core_model],
+    ),
 }
diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py
index 85f422342a95..d341f6873644 100644
--- a/vllm/model_executor/models/glm4_1v.py
+++ b/vllm/model_executor/models/glm4_1v.py
@@ -29,7 +29,7 @@
 import math
 from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
-from functools import partial
+from functools import lru_cache, partial
 from typing import Annotated, Any, Literal, TypeAlias
 
 import numpy as np
@@ -95,10 +95,12 @@
 from vllm.sequence import IntermediateTensors
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from vllm.v1.attention.backends.registry import AttentionBackendEnum
+from vllm.v1.worker.encoder_cudagraph_defs import EncoderCudaGraphReplayBuffers
 
 from ..layers.activation import SiluAndMul
 from .interfaces import (
     MultiModalEmbeddings,
+    SupportsEncoderCudaGraph,
     SupportsLoRA,
     SupportsMRoPE,
     SupportsMultiModal,
@@ -615,6 +617,11 @@ def __init__(
     ) -> None:
         super().__init__()
 
+        use_data_parallel = is_vit_use_data_parallel()
+        self.tp_size = (
+            1 if use_data_parallel else get_tensor_model_parallel_world_size()
+        )
+
         patch_size = vision_config.patch_size
         temporal_patch_size = vision_config.temporal_patch_size
         in_channels = vision_config.in_channels
@@ -662,6 +669,8 @@ def __init__(
             prefix=f"{prefix}.merger",
         )
         self.embeddings = Glm4vVisionEmbeddings(vision_config)
+        self.num_position_embeddings = self.embeddings.num_positions
+        self.num_grid_per_side = int(self.num_position_embeddings**0.5)
         self.post_conv_layernorm = RMSNorm(
             vision_config.hidden_size, eps=vision_config.rms_norm_eps
         )
@@ -689,43 +698,50 @@ def dtype(self) -> torch.dtype:
     def device(self) -> torch.device:
         return self.patch_embed.proj.weight.device
 
-    def rot_pos_emb(
-        self, grid_thw: torch.Tensor
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        pos_ids = []
-        for t, h, w in grid_thw:
-            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
-            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
-            hpos_ids = (
-                hpos_ids.reshape(
-                    h // self.spatial_merge_size,
-                    self.spatial_merge_size,
-                    w // self.spatial_merge_size,
-                    self.spatial_merge_size,
-                )
-                .permute(0, 2, 1, 3)
-                .flatten()
-            )
-            wpos_ids = (
-                wpos_ids.reshape(
-                    h // self.spatial_merge_size,
-                    self.spatial_merge_size,
-                    w // self.spatial_merge_size,
-                    self.spatial_merge_size,
-                )
-                .permute(0, 2, 1, 3)
-                .flatten()
-            )
-            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
-        pos_ids = torch.cat(pos_ids, dim=0)
-        max_grid_size = grid_thw[:, 1:].max()
+    @staticmethod
+    @lru_cache(maxsize=1024)
+    def rot_pos_ids(h: int, w: int, spatial_merge_size: int) -> torch.Tensor:
+        hpos_ids = np.broadcast_to(np.arange(h).reshape(h, 1), (h, w))
+        h_div = h // spatial_merge_size
+        w_div = w // spatial_merge_size
+        hpos_ids = hpos_ids.reshape(
+            h_div,
+            spatial_merge_size,
+            w_div,
+            spatial_merge_size,
+        )
+        hpos_ids = hpos_ids.transpose(0, 2, 1, 3)
+        hpos_ids = hpos_ids.flatten()
+
+        wpos_ids = np.broadcast_to(np.arange(w).reshape(1, w), (h, w))
+        wpos_ids = wpos_ids.reshape(
+            h_div,
+            spatial_merge_size,
+            w_div,
+            spatial_merge_size,
+        )
+        wpos_ids = wpos_ids.transpose(0, 2, 1, 3)
+        wpos_ids = wpos_ids.flatten()
+
+        return torch.from_numpy(np.stack([hpos_ids, wpos_ids], axis=-1))
+
+    def rot_pos_emb(self, grid_thw: list[list[int]]):
+        max_grid_size = max(max(h, w) for _, h, w in grid_thw)
+        pos_ids = [
+            self.rot_pos_ids(h, w, self.spatial_merge_size)
+            if t == 1
+            else self.rot_pos_ids(h, w, self.spatial_merge_size).repeat(t, 1)
+            for t, h, w in grid_thw
+        ]
+        pos_ids = torch.cat(pos_ids, dim=0).to(self.device, non_blocking=True)
         # Use pre-computed cos_sin_cache from RotaryEmbedding
         cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size)
         cos_combined = cos[pos_ids].flatten(1)
         sin_combined = sin[pos_ids].flatten(1)
-        return cos_combined, sin_combined, pos_ids
+
+        return cos_combined, sin_combined
 
     def compute_attn_mask_seqlen(
         self,
@@ -740,45 +756,162 @@ def compute_attn_mask_seqlen(
         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
         return max_seqlen
 
+    def pos_embeds_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor:
+        device = self.embeddings.position_embedding.weight.device
+        dtype = self.dtype
+        all_embeds = []
+
+        for t, h, w in grid_thw:
+            h_coords = (
+                torch.arange(h, device=device).unsqueeze(1).expand(h, w).reshape(-1)
+            )
+            w_coords = (
+                torch.arange(w, device=device).unsqueeze(0).expand(h, w).reshape(-1)
+            )
+
+            lengths = [h * w]
+            image_shapes = torch.tensor([[t, h, w]], device=device)
+
+            embeds = self.embeddings(
+                embeddings=torch.zeros(
+                    h * w, self.hidden_size, device=device, dtype=dtype
+                ),
+                lengths=lengths,
+                image_shapes=image_shapes,
+                h_coords=h_coords,
+                w_coords=w_coords,
+            )
+            embeds = embeds.repeat(t, 1)
+            all_embeds.append(embeds)
+
+        return torch.cat(all_embeds, dim=0).to(dtype)
+
+    def prepare_encoder_metadata(
+        self,
+        grid_thw_list: list[list[int]],
+        *,
+        max_batch_size: int | None = None,
+        max_frames_per_batch: int | None = None,
+        max_seqlen_override: int | None = None,
+        device: torch.device | None = None,
+    ) -> dict[str, torch.Tensor | None]:
+        """Compute encoder metadata from grid_thw_list.
+
+        Shared by the eager forward path, CUDA graph capture, and
+        CUDA graph replay to avoid duplicated implementation.
+
+        Args:
+            grid_thw_list: Grid configurations as list of [t, h, w].
+            max_batch_size: If set, pad cu_seqlens to this size
+                (needed for CUDA graph capture/replay).
+            max_frames_per_batch: If set, overrides max_batch_size for
+                cu_seqlens padding. For video inputs each item contributes
+                T attention sequences (frames); this sizes the buffer to
+                the total frame budget so video replays never overflow.
+            max_seqlen_override: If set, use this value for max_seqlen
+                instead of computing from cu_seqlens (needed for CUDA
+                graph capture to cover worst-case replay scenarios).
+            device: Device to place tensors on. Defaults to self.device.
+        """
+        if device is None:
+            device = self.device
+
+        metadata: dict[str, torch.Tensor | None] = {}
+
+        # Positional embeddings
+        metadata["pos_embeds"] = self.pos_embeds_interpolate(grid_thw_list)
+        rotary_cos, rotary_sin = self.rot_pos_emb(grid_thw_list)
+        metadata["rotary_pos_emb_cos"] = rotary_cos
+        metadata["rotary_pos_emb_sin"] = rotary_sin
+
+        # cu_seqlens from grid_thw
+        grid_thw_np = np.array(grid_thw_list, dtype=np.int32)
+        patches_per_frame = grid_thw_np[:, 1] * grid_thw_np[:, 2]
+        cu_seqlens = np.repeat(patches_per_frame, grid_thw_np[:, 0]).cumsum(
+            dtype=np.int32
+        )
+        cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
+
+        # Pad cu_seqlens to the required number of sequences.
+        # For videos each item contributes T frames = T attention sequences,
+        # so the total can exceed max_batch_size. max_frames_per_batch
+        # overrides the pad target when set.
+        pad_to = (
+            max_frames_per_batch if max_frames_per_batch is not None else max_batch_size
+        )
+        if pad_to is not None:
+            num_seqs = len(cu_seqlens) - 1
+            if num_seqs < pad_to:
+                cu_seqlens = np.concatenate(
+                    [
+                        cu_seqlens,
+                        np.full(
+                            pad_to - num_seqs,
+                            cu_seqlens[-1],
+                            dtype=np.int32,
+                        ),
+                    ]
+                )
+
+        # sequence_lengths (backend-specific)
+        metadata["sequence_lengths"] = MMEncoderAttention.maybe_compute_seq_lens(
+            self.attn_backend, cu_seqlens, device
+        )
+
+        # max_seqlen
+        if max_seqlen_override is not None:
+            max_seqlen_val = max_seqlen_override
+        else:
+            max_seqlen_val = MMEncoderAttention.compute_max_seqlen(
+                self.attn_backend, cu_seqlens
+            )
+        # Keep max_seqlen on CPU: attention wrappers call .item() on it,
+        # and having it on GPU would capture a wasteful D2H copy in CUDA
+        # graphs without changing behavior (the scalar is baked at capture).
+ metadata["max_seqlen"] = torch.tensor(max_seqlen_val, dtype=torch.int32) + + # Recompute cu_seqlens (backend-specific transformation) + metadata["cu_seqlens"] = MMEncoderAttention.maybe_recompute_cu_seqlens( + self.attn_backend, + cu_seqlens, + self.hidden_size, + self.tp_size, + device, + ) + + return metadata + def forward( self, x: torch.Tensor, grid_thw: torch.Tensor | list[list[int]], + *, + encoder_metadata: dict[str, torch.Tensor] | None = None, ) -> torch.Tensor: - if isinstance(grid_thw, list): - grid_thw = torch.tensor(grid_thw, dtype=torch.int32) + if encoder_metadata is None: + if isinstance(grid_thw, list): + grid_thw = torch.tensor(grid_thw, dtype=torch.int32) + else: + grid_thw = grid_thw.tolist() + encoder_metadata = self.prepare_encoder_metadata(grid_thw) # patchify x = x.to(device=self.device, dtype=self.dtype) x = self.patch_embed(x) x = self.post_conv_layernorm(x) - # compute position embedding - rotary_pos_emb_cos, rotary_pos_emb_sin, image_type_ids = self.rot_pos_emb( - grid_thw - ) - # compute cu_seqlens - cu_seqlens = torch.repeat_interleave( - grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] - ).cumsum(dim=0, dtype=torch.int32) - cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens]) - # pre-compute max_seqlen for attn mask to reduce cuMemcpy operations - max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens) - seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() - cu_seqlens = cu_seqlens.to(self.device, non_blocking=True) - x = self.embeddings( - x, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1] - ) + pos_embeds = encoder_metadata["pos_embeds"] + x = x + pos_embeds # transformers x = x.unsqueeze(1) for blk in self.blocks: x = blk( x, - cu_seqlens=cu_seqlens, - rotary_pos_emb_cos=rotary_pos_emb_cos, - rotary_pos_emb_sin=rotary_pos_emb_sin, - max_seqlen=max_seqlen, + cu_seqlens=encoder_metadata["cu_seqlens"], + rotary_pos_emb_cos=encoder_metadata["rotary_pos_emb_cos"], + rotary_pos_emb_sin=encoder_metadata["rotary_pos_emb_sin"], + max_seqlen=encoder_metadata["max_seqlen"], ) # adapter @@ -1385,7 +1518,12 @@ def get_video_replacement_glm4v(item_idx: int): dummy_inputs=Glm4vDummyInputsBuilder, ) class Glm4vForConditionalGeneration( - nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE + nn.Module, + SupportsMultiModal, + SupportsEncoderCudaGraph, + SupportsLoRA, + SupportsPP, + SupportsMRoPE, ): packed_modules_mapping = { "qkv_proj": [ @@ -1423,8 +1561,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): multimodal_config = vllm_config.model_config.multimodal_config self.config = config + self.model_config = vllm_config.model_config self.multimodal_config = multimodal_config self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" + self.is_multimodal_pruning_enabled = ( + multimodal_config.is_multimodal_pruning_enabled() + ) with self._mark_tower_model(vllm_config, {"image", "video"}): self.visual = Glm4vVisionTransformer( @@ -1550,6 +1692,280 @@ def _process_video_input( sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist() return video_embeds.split(sizes) + # -- SupportsEncoderCudaGraph protocol methods -- + + def get_encoder_cudagraph_config(self): + from vllm.v1.worker.encoder_cudagraph_defs import ( + EncoderCudaGraphConfig, + ) + + modalities = ["image"] + # NOTE: When EVS (Efficient Video Sampling) pruning is enabled, the number + # of tokens becomes data-dependent (i.e., the retained tokens are + # dynamically selected based on inter-frame differences) and therefore + # 
+        # only enabled when EVS is disabled.
+        if not self.is_multimodal_pruning_enabled:
+            modalities.append("video")
+
+        return EncoderCudaGraphConfig(
+            modalities=modalities,
+            input_key_by_modality={
+                "image": "pixel_values",
+                "video": "pixel_values_videos",
+            },
+            buffer_keys=[
+                "pos_embeds",
+                "rotary_pos_emb_cos",
+                "rotary_pos_emb_sin",
+                "cu_seqlens",
+                "max_seqlen",
+                "sequence_lengths",
+            ],
+            out_hidden_size=self.visual.out_hidden_size,
+        )
+
+    def get_input_modality(
+        self,
+        mm_kwargs: dict[str, Any],
+    ) -> str:
+        if "image_grid_thw" in mm_kwargs:
+            return "image"
+        return "video"
+
+    def get_max_frames_per_video(self) -> int:
+        mm_registry = MULTIMODAL_REGISTRY
+        info = mm_registry.get_processing_info(self.model_config)
+        max_frames_per_video = info.get_num_frames_with_most_features(
+            seq_len=self.model_config.max_model_len,
+            mm_counts={"video": self.multimodal_config.get_limit_per_prompt("video")},
+        )
+
+        image_longest = info.get_image_processor().size["longest_edge"]
+        video_longest = info.get_video_processor().size["longest_edge"]
+        max_frames_from_info = video_longest // image_longest
+
+        max_frames_per_video = max(max_frames_per_video, max_frames_from_info, 16)
+        return max_frames_per_video
+
+    def get_encoder_cudagraph_budget_range(
+        self,
+        vllm_config,
+    ) -> tuple[int, int]:
+        # Min: estimated smallest possible encoder input.
+        # 224x224 image -> 16x16 patches (patch_size=14)
+        # spatial_merge_size=2 -> 8x8 = 64 tokens
+        min_budget = 64
+        # Max: capped by max_num_batched_tokens
+        max_budget = min(
+            vllm_config.scheduler_config.max_num_batched_tokens,
+            vllm_config.model_config.max_model_len,
+        )
+        return (min_budget, max_budget)
+
+    def _get_pixel_values_by_modality(
+        self,
+        mm_kwargs: dict[str, Any],
+    ) -> torch.Tensor:
+        if self.get_input_modality(mm_kwargs) == "image":
+            pixel_values = mm_kwargs["pixel_values"]
+        else:
+            pixel_values = mm_kwargs["pixel_values_videos"]
+        return pixel_values
+
+    def _get_grid_thw_by_modality(
+        self,
+        mm_kwargs: dict[str, Any],
+    ) -> list[tuple[int, int, int]]:
+        grid_thw_key = f"{self.get_input_modality(mm_kwargs)}_grid_thw"
+        grid_thw = mm_kwargs[grid_thw_key]
+        if not isinstance(grid_thw, list):
+            grid_thw = grid_thw.tolist()
+        return grid_thw
+
+    def get_encoder_cudagraph_num_items(
+        self,
+        mm_kwargs: dict[str, Any],
+    ) -> int:
+        return len(self._get_grid_thw_by_modality(mm_kwargs))
+
+    def get_encoder_cudagraph_per_item_output_tokens(
+        self,
+        mm_kwargs: dict[str, Any],
+    ) -> list[int]:
+        m = self.visual.spatial_merge_size
+        grid_thw = self._get_grid_thw_by_modality(mm_kwargs)
+        return [t * (h // m) * (w // m) for t, h, w in grid_thw]
+
+    def get_encoder_cudagraph_per_item_input_sizes(
+        self,
+        mm_kwargs: dict[str, Any],
+    ) -> list[int]:
+        grid_thw = self._get_grid_thw_by_modality(mm_kwargs)
+        return [t * h * w for t, h, w in grid_thw]
+
+    def select_encoder_cudagraph_items(
+        self,
+        mm_kwargs: dict[str, Any],
+        indices: list[int],
+    ) -> dict[str, Any]:
+        grid_thw = self._get_grid_thw_by_modality(mm_kwargs)
+        pixel_values = self._get_pixel_values_by_modality(mm_kwargs)
+
+        if len(indices) == 0:
+            if self.get_input_modality(mm_kwargs) == "image":
+                return {
+                    "pixel_values": pixel_values[:0],
+                    "image_grid_thw": [],
+                }
+            else:
+                return {
+                    "pixel_values_videos": pixel_values[:0],
+                    "video_grid_thw": [],
+                }
+
+        # Compute cumulative patch offsets for slicing pixel_values
+        patches_per_item = [t * h * w for t, h, w in grid_thw]
+        cum_patches = [0]
+        for p in patches_per_item:
+            cum_patches.append(cum_patches[-1] + p)
+
+        selected_pv = torch.cat(
+            [pixel_values[cum_patches[i] : cum_patches[i + 1]] for i in indices]
+        )
+        selected_grid = [grid_thw[i] for i in indices]
+
+        if self.get_input_modality(mm_kwargs) == "image":
+            return {
+                "pixel_values": selected_pv,
+                "image_grid_thw": selected_grid,
+            }
+        else:
+            return {
+                "pixel_values_videos": selected_pv,
+                "video_grid_thw": selected_grid,
+            }
+
+    def prepare_encoder_cudagraph_capture_inputs(
+        self,
+        token_budget: int,
+        max_batch_size: int,
+        max_frames_per_batch: int,
+        device: torch.device,
+        dtype: torch.dtype,
+    ):
+        from vllm.v1.worker.encoder_cudagraph_defs import (
+            EncoderCudaGraphCaptureInputs,
+        )
+
+        spatial_merge_size = self.visual.spatial_merge_size
+        per_mm_item_output = token_budget // max_batch_size
+
+        frames_per_item = max_frames_per_batch // max_batch_size
+        if frames_per_item > 1:
+            # Build the capture grid using a video-format layout so that
+            # cu_seqlens is sized for video replays from the start.
+            # cu_seqlens has one entry per attention sequence (one per frame),
+            # so using T > 1 per item makes the buffer large enough without
+            # relying solely on padding.
+            # Ceiling ensures frames_per_item * tokens_per_frame >= per_mm_item_output
+            # so the pixel_values buffer covers any valid single-item replay.
+            tokens_per_frame = (
+                per_mm_item_output + frames_per_item - 1
+            ) // frames_per_item
+            # Video-format grid_config (T=frames_per_item).
+            grid_config = [
+                [
+                    frames_per_item,
+                    spatial_merge_size,
+                    tokens_per_frame * spatial_merge_size,
+                ]
+                for _ in range(max_batch_size)
+            ]
+        else:
+            # Image-format grid_config (T=1).
+            grid_config = [
+                [1, spatial_merge_size, per_mm_item_output * spatial_merge_size]
+                for _ in range(max_batch_size)
+            ]
+
+        # Create dummy pixel_values
+        patch_embed = self.visual.patch_embed
+        in_channels = patch_embed.proj.in_channels
+        patch_size = patch_embed.patch_size
+        temporal_patch_size = patch_embed.temporal_patch_size
+        total_patches = sum(t * h * w for t, h, w in grid_config)
+        flattened_patch_size = (
+            in_channels * temporal_patch_size * patch_size * patch_size
+        )
+        dummy_pixel_values = torch.randn(
+            total_patches, flattened_patch_size, device=device, dtype=dtype
+        )
+
+        # Override max_seqlen with a safe upper bound for capture.
+        # max_seqlen.item() gets baked into the CUDA graph (not replayed),
+        # so the capture value must cover any replay scenario.
+        # Worst case: 1 item consuming the full budget ->
+        # seq_len = token_budget * spatial_merge_size^2.
+        buffers = self.visual.prepare_encoder_metadata(
+            grid_config,
+            max_batch_size=max_batch_size,
+            max_frames_per_batch=max_frames_per_batch,
+            max_seqlen_override=token_budget * (spatial_merge_size**2),
+            device=device,
+        )
+
+        # Just use an image-modality dummy input buffer for capturing, since it is
+        # also compatible with video inputs (same shape: [num_patches, C*T*P*P]).
+        mm_kwargs = {
+            "pixel_values": dummy_pixel_values,
+            "image_grid_thw": grid_config,
+        }
+
+        return EncoderCudaGraphCaptureInputs(
+            mm_kwargs=mm_kwargs,
+            buffers=buffers,
+        )
+
+    def prepare_encoder_cudagraph_replay_buffers(
+        self,
+        mm_kwargs: dict[str, Any],
+        max_batch_size: int,
+        max_frames_per_batch: int,
+    ):
+        modality = self.get_input_modality(mm_kwargs)
+        grid_thw_list = self._get_grid_thw_by_modality(mm_kwargs)
+
+        if modality == "image":
+            buffers = self.visual.prepare_encoder_metadata(
+                grid_thw_list,
+                max_batch_size=max_batch_size,
+            )
+        else:
+            buffers = self.visual.prepare_encoder_metadata(
+                grid_thw_list,
+                max_frames_per_batch=max_frames_per_batch,
+            )
+
+        return EncoderCudaGraphReplayBuffers(buffers=buffers)
+
+    def encoder_cudagraph_forward(
+        self,
+        mm_kwargs: dict[str, Any],
+        buffers: dict[str, torch.Tensor],
+    ) -> torch.Tensor:
+        pixel_values = self._get_pixel_values_by_modality(mm_kwargs)
+        grid_thw = self._get_grid_thw_by_modality(mm_kwargs)
+        return self.visual(pixel_values, grid_thw, encoder_metadata=buffers)
+
+    def encoder_eager_forward(
+        self,
+        mm_kwargs: dict[str, Any],
+    ) -> torch.Tensor:
+        pixel_values = self._get_pixel_values_by_modality(mm_kwargs)
+        grid_thw = self._get_grid_thw_by_modality(mm_kwargs)
+        return self.visual(pixel_values, grid_thw)
+
     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
         mm_input_by_modality = {}
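
For reference, a minimal offline sketch (not part of the patch; the image file and sampling settings are illustrative placeholders) of how the GLM-4.1V encoder can be exercised once `enforce_eager=True` is dropped from the example, which is what lets the encoder CUDA Graph path added here be used:

from PIL import Image

from vllm import LLM, SamplingParams

# Illustrative only: build the engine without enforce_eager so the GLM-4.1V
# vision encoder can run through the captured CUDA Graphs when available.
llm = LLM(
    model="zai-org/GLM-4.1V-9B-Thinking",
    limit_mm_per_prompt={"image": 1},
)

prompt = (
    "[gMASK]<|system|>\nYou are a helpful assistant.<|user|>\n"
    "<|begin_of_image|><|image|><|end_of_image|>"
    "What is in this image?<|assistant|>\n"
)
image = Image.open("example.jpg")  # placeholder input image

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)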