diff --git a/docs/design/cuda_graphs_multimodal.md b/docs/design/cuda_graphs_multimodal.md
index f44ef359df38..3920ff34556f 100644
--- a/docs/design/cuda_graphs_multimodal.md
+++ b/docs/design/cuda_graphs_multimodal.md
@@ -87,6 +87,7 @@ Models opt-in to encoder CUDA Graphs by implementing the [SupportsEncoderCudaGra
 | ------------ | ------ | ------------ | ------------ |
 | `Qwen3VLForConditionalGeneration` | `Qwen3-VL` | ✅︎ | ✅︎ |
 | `Qwen2_5_VLForConditionalGeneration` | `Qwen2.5-VL` | ✅︎ | ✅︎ |
+| `MiniCPMV` | `MiniCPMV2.5`, `MiniCPMV2.6`, `MiniCPMV4.0`, `MiniCPMV4.5` | ✅︎ | ✅︎ |
 
 !!! note
     Encoder CUDA Graphs have currently been tested with `--mm-encoder-attn-backend=FLASH_ATTN` and `--mm-encoder-attn-backend=FLASHINFER` on Blackwell GPUs.
diff --git a/examples/generate/multimodal/vision_language_offline.py b/examples/generate/multimodal/vision_language_offline.py
index 794f20dd0a52..c691da7ac64e 100644
--- a/examples/generate/multimodal/vision_language_offline.py
+++ b/examples/generate/multimodal/vision_language_offline.py
@@ -2467,6 +2467,7 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData:
     "qwen3_vl",
     "qwen3_vl_moe",
     "qwen2_5_vl",
+    "minicpmv",
 ]
diff --git a/tests/models/multimodal/generation/test_vit_cudagraph.py b/tests/models/multimodal/generation/test_vit_cudagraph.py
index fb7bdfc8625d..f9b8fbccbd94 100644
--- a/tests/models/multimodal/generation/test_vit_cudagraph.py
+++ b/tests/models/multimodal/generation/test_vit_cudagraph.py
@@ -41,6 +41,20 @@ def qwen_vl_chat_template(content: str) -> str:
     return f"<|im_start|>user\n{content}<|im_end|>\n<|im_start|>assistant\n"
 
 
+def minicpmv_25_chat_template(content: str) -> str:
+    """Llama3-style chat template used by MiniCPM-V 2.5."""
+    return (
+        f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
+        f"{content}"
+        f"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+    )
+
+
+def minicpmv_chat_template(content: str) -> str:
+    """ChatML template used by MiniCPM-V 2.6 / 4.0 / 4.5."""
+    return f"<|im_start|>user\n{content}<|im_end|>\n<|im_start|>assistant\n"
+
+
 MODEL_CONFIGS: dict[str, VitCudagraphTestConfig] = {
     "qwen3_vl": VitCudagraphTestConfig(
         model="Qwen/Qwen3-VL-2B-Instruct",
@@ -66,6 +80,48 @@ def qwen_vl_chat_template(content: str) -> str:
         needs_video_metadata=False,
         marks=[pytest.mark.core_model],
     ),
+    "minicpmv_25": VitCudagraphTestConfig(
+        model="openbmb/MiniCPM-Llama3-V-2_5",
+        modalities=["image"],
+        image_prompt=minicpmv_25_chat_template(
+            "(<image>./</image>)\nWhat is in this image?"
+        ),
+        vllm_runner_kwargs={"trust_remote_code": True},
+        marks=[pytest.mark.core_model],
+    ),
+    "minicpmv_26": VitCudagraphTestConfig(
+        model="openbmb/MiniCPM-V-2_6",
+        image_prompt=minicpmv_chat_template(
+            "(<image>./</image>)\nWhat is in this image?"
+        ),
+        video_prompt=minicpmv_chat_template(
+            "(<video>./</video>)\nDescribe this video in one sentence."
+        ),
+        vllm_runner_kwargs={"trust_remote_code": True},
+        marks=[pytest.mark.core_model],
+    ),
+    "minicpmv_40": VitCudagraphTestConfig(
+        model="openbmb/MiniCPM-V-4",
+        image_prompt=minicpmv_chat_template(
+            "(<image>./</image>)\nWhat is in this image?"
+        ),
+        video_prompt=minicpmv_chat_template(
+            "(<video>./</video>)\nDescribe this video in one sentence."
+        ),
+        vllm_runner_kwargs={"trust_remote_code": True},
+        marks=[pytest.mark.core_model],
+    ),
+    "minicpmv_45": VitCudagraphTestConfig(
+        model="openbmb/MiniCPM-V-4_5",
+        image_prompt=minicpmv_chat_template(
+            "(<image>./</image>)\nWhat is in this image?"
+        ),
+        video_prompt=minicpmv_chat_template(
+            "(<video>./</video>)\nDescribe this video in one sentence."
+        ),
+        vllm_runner_kwargs={"trust_remote_code": True},
+        marks=[pytest.mark.core_model],
+    ),
 }
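The helper templates added above only wrap MiniCPM-V's native media placeholders (`(<image>./</image>)` and `(<video>./</video>)`) in chat markup. As a quick illustration (standalone Python, independent of vLLM), the ChatML helper renders an image question as:

```python
# Illustrative only: the prompt string minicpmv_chat_template() builds for an
# image question, with MiniCPM-V's (<image>./</image>) placeholder marking
# where the image embeddings are spliced in.
content = "(<image>./</image>)\nWhat is in this image?"
prompt = f"<|im_start|>user\n{content}<|im_end|>\n<|im_start|>assistant\n"
assert prompt == (
    "<|im_start|>user\n"
    "(<image>./</image>)\n"
    "What is in this image?<|im_end|>\n"
    "<|im_start|>assistant\n"
)
```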
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index cda07ea291ea..a1db369bbe1b 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -29,7 +29,7 @@
 from collections.abc import Callable, Iterable, Mapping, Sequence
 from functools import partial
 from itertools import chain
-from typing import Annotated, Any, Literal, TypeAlias
+from typing import Annotated, Any, ClassVar, Literal, TypeAlias
 
 import numpy as np
 import torch
@@ -85,10 +85,16 @@
 from vllm.utils.collection_utils import flatten_2d_lists
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from vllm.utils.torch_utils import set_default_torch_dtype
+from vllm.v1.worker.encoder_cudagraph_defs import (
+    EncoderCudaGraphCaptureInputs,
+    EncoderCudaGraphConfig,
+    EncoderCudaGraphReplayBuffers,
+)
 
 from .idefics2_vision_model import Idefics2VisionTransformer
 from .interfaces import (
     MultiModalEmbeddings,
+    SupportsEncoderCudaGraph,
     SupportsLoRA,
     SupportsMultiModal,
     SupportsPP,
@@ -126,6 +132,11 @@ class MiniCPMVImagePixelInputs(TensorSchema):
         TensorShape("bn"),
     ]
 
+    # Handled as a batched input; a TensorShape check is not strictly
+    # necessary here because the field defaults to None and holds a
+    # non-tensor type.
+    temporal_ids: list[list[int]] | None = None
+
 
 class MiniCPMVImageEmbeddingInputs(TensorSchema):
     """
@@ -204,25 +215,34 @@ def forward(self, x: torch.Tensor, tgt_sizes: torch.Tensor) -> torch.Tensor:
         patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]
 
-        self._adjust_pos_cache(tgt_sizes, device=device)
+        # In eager mode or during capture, adjust the cache size if necessary.
+        # We can safely use .max().item() here because during capture tgt_sizes
+        # contains the max possible sizes, so it will grow the cache sufficiently.
+        max_h = tgt_sizes[:, 0].max().item()
+        max_w = tgt_sizes[:, 1].max().item()
+        if max_h > self.max_size[0] or max_w > self.max_size[1]:
+            self._adjust_pos_cache(tgt_sizes, device=device)
+
+        max_patch_len = x.shape[1]
 
-        max_patch_len = patch_len.max().item()
-        assert isinstance(max_patch_len, int)
+        seq_idx = torch.arange(max_patch_len, device=device).unsqueeze(0)
 
-        key_padding_mask = torch.zeros(
-            (bs, max_patch_len), dtype=torch.bool, device=device
-        )
+        tgt_w = tgt_sizes[:, 1].unsqueeze(1)
+        tgt_w = torch.clamp(tgt_w, min=1)
+
+        h_idx = seq_idx // tgt_w
+        w_idx = seq_idx % tgt_w
 
-        pos_embed = []
-        for i in range(bs):
-            tgt_h, tgt_w = tgt_sizes[i].tolist()
-            pos_embed.append(
-                self.pos_embed[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1)).to(dtype)
-            )  # patches * D
-            key_padding_mask[i, patch_len[i] :] = True
-        pos_embed = torch.nn.utils.rnn.pad_sequence(
-            pos_embed, batch_first=True, padding_value=0.0
+        h_idx = torch.clamp(h_idx, max=self.pos_embed.shape[0] - 1)
+        w_idx = torch.clamp(w_idx, max=self.pos_embed.shape[1] - 1)
+
+        key_padding_mask = seq_idx >= patch_len.unsqueeze(1)
+
+        pos_embed = self.pos_embed[h_idx, w_idx].to(dtype)
+        pos_embed = torch.where(
+            key_padding_mask.unsqueeze(-1), torch.zeros_like(pos_embed), pos_embed
         ).permute(1, 0, 2)  # BLD => L * B * D
+
         x, _ = self.kv_proj(x)  # B * L * D
         x = self.ln_kv(x).permute(1, 0, 2)  # L * B * D
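The rewrite above removes the per-sample Python loop: for a slice with patch grid `(h, w)`, the flattened patch index `i` maps to row `i // w` and column `i % w` of the cached 2D position table, so a single advanced-indexing gather covers the whole batch. A small self-contained sketch (toy sizes, not the real `pos_embed` cache) showing the two formulations agree on the valid positions:

```python
import torch

# Hypothetical sizes for illustration; the real values come from tgt_sizes.
pos_embed = torch.arange(70 * 70 * 2, dtype=torch.float32).view(70, 70, 2)
tgt_sizes = torch.tensor([[3, 4], [2, 5]])          # (bs, 2): (h, w) per slice
max_patch_len = int((tgt_sizes[:, 0] * tgt_sizes[:, 1]).max())

seq_idx = torch.arange(max_patch_len).unsqueeze(0)  # (1, L)
tgt_w = tgt_sizes[:, 1].clamp(min=1).unsqueeze(1)   # (bs, 1)
h_idx, w_idx = seq_idx // tgt_w, seq_idx % tgt_w    # (bs, L)

vectorized = pos_embed[h_idx, w_idx]                # (bs, L, D) in one gather

# Reference: the old per-sample loop yields the same rows for the
# non-padded positions; padded tail positions are masked out later.
for i, (h, w) in enumerate(tgt_sizes.tolist()):
    ref = pos_embed[:h, :w].reshape(h * w, -1)
    assert torch.equal(vectorized[i, : h * w], ref)
```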
@@ -339,88 +359,108 @@ def forward(
         patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]
 
-        self._adjust_pos_cache(tgt_sizes, device=device)
+        # In eager mode or during capture, adjust the cache size if necessary.
+        # We can safely use .max().item() here: it only runs once per graph
+        # capture and sees the maximum possible sizes.
+        max_h = tgt_sizes[:, 0].max().item()
+        max_w = tgt_sizes[:, 1].max().item()
+        if max_h > self.max_size[0] or max_w > self.max_size[1]:
+            self._adjust_pos_cache(tgt_sizes, device=device)
+
+        max_patch_len = x.shape[1]
 
         temporal_pos_emb = False
         temporal_ids_flatten = None
         if temporal_ids is not None:
-            # example: [[-1], [-1], [2, 6, 9]]
-            temporal_ids_flatten = list(chain.from_iterable(temporal_ids))
-            max_temporal_size = max(temporal_ids_flatten, default=0)
+            if isinstance(temporal_ids, torch.Tensor):
+                temporal_ids_flatten = temporal_ids
+                max_temporal_size = temporal_ids_flatten.max().item()
+            else:
+                # example: [[-1], [-1], [2, 6, 9]]
+                temporal_ids_flatten = list(chain.from_iterable(temporal_ids))
+                max_temporal_size = max(temporal_ids_flatten, default=0)
+                temporal_ids_flatten = torch.tensor(
+                    temporal_ids_flatten, dtype=torch.long, device=device
+                )
+
             if max_temporal_size > -1:
                 temporal_pos_emb = True
                 if max_temporal_size > self.max_temporal_size:
                     self._adjust_temporal_pos_cache(max_temporal_size, device)
 
-        max_patch_len = patch_len.max().item()
-        assert isinstance(max_patch_len, int)
+        seq_idx = torch.arange(max_patch_len, device=device).unsqueeze(0)
 
-        key_padding_mask = torch.zeros(
-            (bs, max_patch_len), dtype=torch.bool, device=device
-        )
+        tgt_w = tgt_sizes[:, 1].unsqueeze(1)
+        tgt_w = torch.clamp(tgt_w, min=1)
+
+        h_idx = seq_idx // tgt_w
+        w_idx = seq_idx % tgt_w
+
+        h_idx = torch.clamp(h_idx, max=self.pos_embed.shape[0] - 1)
+        w_idx = torch.clamp(w_idx, max=self.pos_embed.shape[1] - 1)
+
+        pos_embed_2d = self.pos_embed[h_idx, w_idx].to(dtype)
+
+        key_padding_mask = seq_idx >= patch_len.unsqueeze(1)
+
+        pos_embed_2d = torch.where(
+            key_padding_mask.unsqueeze(-1), torch.zeros_like(pos_embed_2d), pos_embed_2d
+        ).permute(1, 0, 2)  # (L, bs, D)
 
         x, _ = self.kv_proj(x)  # B * L * D
         x = self.ln_kv(x).permute(1, 0, 2)  # L * B * D
 
         q = self.ln_q(self.query)  # Q * D
 
-        pos_embed_2d = []
-        pos_embed_temporal = []
-        for i in range(bs):
-            tgt_h, tgt_w = tgt_sizes[i]
-            if temporal_pos_emb:
-                if temporal_ids_flatten[i] == -1:
-                    pos_embed_temporal.append(
-                        torch.zeros(self.embed_dim, dtype=dtype, device=device)
-                    )
-                else:
-                    pos_embed_temporal.append(
-                        self.temporal_pos_embed[temporal_ids_flatten[i]].to(dtype)
-                    )  # D
-
-            pos_embed_2d.append(
-                self.pos_embed[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1)).to(dtype)
-            )  # patches * D
-            key_padding_mask[i, patch_len[i] :] = True
-
-        pos_embed_2d = torch.nn.utils.rnn.pad_sequence(
-            pos_embed_2d, batch_first=True, padding_value=0.0
-        ).permute(1, 0, 2)  # BLD => L * B * D
-
         k = x + pos_embed_2d
         v = x
-        if pos_embed_temporal:
-            k += torch.stack(pos_embed_temporal, dim=0)
-            bs = len(temporal_ids)
-            merge_k = []
-            merge_v = []
-            merge_key_padding_mask = []
-
-            start = 0
-            for tp in temporal_ids:
-                end = start + len(tp)
-                # L * (end-start) * D -> (end-start) * L * D
-                # -> 1 * L*(end-start) * D
-                merge_k.append(
-                    k[:, start:end, :].permute(1, 0, 2).reshape(-1, self.embed_dim)
-                )
-                merge_v.append(
-                    v[:, start:end, :].permute(1, 0, 2).reshape(-1, self.embed_dim)
-                )
-                merge_key_padding_mask.append(
-                    key_padding_mask[start:end, :].reshape(-1, 1)
-                )
-                start = end
+        if temporal_pos_emb:
+            # temporal_ids_flatten is a 1D tensor of shape (bs,)
+            pos_embed_temporal = torch.where(
+                (temporal_ids_flatten == -1).unsqueeze(-1),
+                torch.zeros(self.embed_dim, dtype=dtype, device=device),
+                self.temporal_pos_embed[torch.clamp(temporal_ids_flatten, min=0)].to(
+                    dtype
+                ),
+            )  # (bs, D)
+
+            k += pos_embed_temporal.unsqueeze(0)  # (L, bs, D) + (1, bs, D)
+
+            # Skip the cross-frame merge loop when capturing into a CUDA graph
+            # (signalled by temporal_ids being passed as a flat Tensor), because
+            # its dynamic sequence lengths and batch sizes cannot be captured.
+            if not isinstance(temporal_ids, torch.Tensor):
+                bs = len(temporal_ids)
+                merge_k = []
+                merge_v = []
+                merge_key_padding_mask = []
+
+                start = 0
+                for tp in temporal_ids:
+                    end = start + len(tp)
+                    # L * (end-start) * D -> (end-start) * L * D
+                    # -> 1 * L*(end-start) * D
+                    merge_k.append(
+                        k[:, start:end, :].permute(1, 0, 2).reshape(-1, self.embed_dim)
+                    )
+                    merge_v.append(
+                        v[:, start:end, :].permute(1, 0, 2).reshape(-1, self.embed_dim)
+                    )
+                    merge_key_padding_mask.append(
+                        key_padding_mask[start:end, :].reshape(-1, 1)
+                    )
+
+                    start = end
 
-            k = torch.nn.utils.rnn.pad_sequence(
-                merge_k, batch_first=True, padding_value=0.0
-            ).permute(1, 0, 2)  # L*(end-start)
-            v = torch.nn.utils.rnn.pad_sequence(
-                merge_v, batch_first=True, padding_value=0.0
-            ).permute(1, 0, 2)  # L*(end-start)
-            key_padding_mask = torch.nn.utils.rnn.pad_sequence(
-                merge_key_padding_mask, batch_first=True, padding_value=True
-            ).squeeze(-1)
+                k = torch.nn.utils.rnn.pad_sequence(
+                    merge_k, batch_first=True, padding_value=0.0
+                ).permute(1, 0, 2)  # L*(end-start)
+                v = torch.nn.utils.rnn.pad_sequence(
+                    merge_v, batch_first=True, padding_value=0.0
+                ).permute(1, 0, 2)  # L*(end-start)
+                key_padding_mask = torch.nn.utils.rnn.pad_sequence(
+                    merge_key_padding_mask, batch_first=True, padding_value=True
+                ).squeeze(-1)
 
         out = self.attn(
             self._repeat(q, bs),  # Q * B * D
@@ -459,6 +499,7 @@ def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
         video_image_sizes=MultiModalFieldConfig.batched("video"),
         video_tgt_sizes=MultiModalFieldConfig.batched("video"),
         video_embeds=MultiModalFieldConfig.batched("video"),
+        temporal_ids=MultiModalFieldConfig.batched("video"),
     )
@@ -1068,6 +1109,7 @@ def _parse_and_validate_vision_input(
     ) -> MiniCPMVImageInputs | None:
         pixel_values = kwargs.pop("pixel_values", None)
         image_embeds = kwargs.pop("image_embeds", None)
+        temporal_ids = kwargs.pop("temporal_ids", None)
 
         if pixel_values is None and image_embeds is None:
             return None
@@ -1089,6 +1131,7 @@
             pixel_values=pixel_values_flat,
             tgt_sizes=tgt_sizes_flat,
            num_slices=num_slices_flat,
+            temporal_ids=temporal_ids,
         )
 
     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
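To make the new temporal-embedding path above concrete, here is a small standalone sketch (illustrative tensors only, not vLLM API) of how the nested `temporal_ids` list is flattened and how the `-1` sentinel rows end up with a zero temporal embedding:

```python
import torch
from itertools import chain

embed_dim = 4
temporal_pos_embed = torch.arange(10 * embed_dim, dtype=torch.float32).view(10, embed_dim)

# Example from the code comment: two images (no temporal id) and one
# three-frame video with frame ids 2, 6 and 9.
temporal_ids = [[-1], [-1], [2, 6, 9]]
flat = torch.tensor(list(chain.from_iterable(temporal_ids)))  # tensor([-1, -1, 2, 6, 9])

pos_embed_temporal = torch.where(
    (flat == -1).unsqueeze(-1),                    # rows with no temporal id...
    torch.zeros(embed_dim),                        # ...get a zero embedding,
    temporal_pos_embed[torch.clamp(flat, min=0)],  # others gather their frame row.
)
print(pos_embed_temporal.shape)  # torch.Size([5, 4])
```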
@@ -1219,6 +1262,392 @@ def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tens
         raise NotImplementedError
 
 
+# --- Encoder CUDA graph (MiniCPM-V 2.5 / 2.6 / 4.0 / 4.5; mixed into subclasses) ---
+# Buffer keys
+_MINICPMV_CUDAGRAPH_BUF_KEY_TGT_SIZES = "minicpmv_tgt_sizes"
+_MINICPMV_CUDAGRAPH_BUF_KEY_PATCH_MASK = "minicpmv_patch_attn_mask"
+# v4.5 only; consumed by encoder_cudagraph_forward below.
+_MINICPMV_CUDAGRAPH_BUF_KEY_TEMPORAL_IDS = "minicpmv_temporal_ids"
+
+# mm_kwargs keys for the flat pixel-value tensor
+_MINICPMV_CUDAGRAPH_FLAT_KEY_IMAGE = "minicpmv_encoder_input_flat"
+_MINICPMV_CUDAGRAPH_FLAT_KEY_VIDEO = "minicpmv_video_encoder_input_flat"
+
+
+def _mcpmv_tgt_sizes_tensor(mm_kwargs: dict[str, Any], *, video: bool) -> torch.Tensor:
+    key = "video_tgt_sizes" if video else "tgt_sizes"
+    return mm_kwargs[key]
+
+
+def _mcpmv_pack_flat_pixels(
+    slices: list[torch.Tensor],
+    *,
+    pixel_height: int,
+    pixel_width: int,
+    max_num_slices: int,
+    device: torch.device,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    """Pack slice tensors into a fixed ``(max_num_slices, 3*H*W)`` buffer.
+
+    Every slice is ``image_size × image_size`` (see ``_mcpmv_slice_pixel_size``),
+    so all rows in the buffer are fully occupied.
+    """
+    flat_dim = 3 * pixel_height * pixel_width
+    packed = torch.zeros((max_num_slices, flat_dim), device=device, dtype=dtype)
+    n = min(len(slices), max_num_slices)
+    if n > 0:
+        packed[:n] = torch.stack(slices[:n]).reshape(n, -1).to(dtype=dtype)
+    return packed
+
+
+class _MiniCPMVEncoderCudaGraphMixin(SupportsEncoderCudaGraph):
+    """SupportsEncoderCudaGraph implementation for the MiniCPM-V variants built
+    on the Idefics2 vision encoder + resampler (every version except 2.0)."""
+
+    supports_encoder_cudagraph: ClassVar[Literal[True]] = True
+
+    def _mcpmv_slice_pixel_size(self) -> tuple[int, int]:
+        """Return (pixel_height, pixel_width) for each slice fed into vpm.
+
+        Every slice is resized to image_size x image_size pixels before being
+        passed to the vision encoder, so this is always (image_size, image_size)
+        regardless of max_slice_num.
+        """
+        image_size = int(self.vpm.embeddings.image_size)
+        return image_size, image_size
+
+    def _mcpmv_max_patches_per_slice(self) -> int:
+        """Return max patch count per slice for patch_attention_mask sizing.
+
+        The vision encoder divides each (image_size x image_size) slice into
+        (image_size // patch_size)^2 patches, so this equals that value.
+        """
+        image_size = int(self.vpm.embeddings.image_size)
+        patch_size = int(self.vpm.embeddings.patch_size)
+        return (image_size // patch_size) ** 2
+
+    def _mcpmv_max_slices_cap(
+        self,
+        token_budget: int,
+        max_batch_size: int,
+        max_frames_per_batch: int,
+    ) -> int:
+        max_slice_num = int(getattr(self.config, "max_slice_num", 9))
+        query_num = max(1, int(self.config.query_num))
+        # Each slice produces query_num output tokens, so token_budget caps slices.
+        max_slices_by_token_budget = max(1, token_budget // query_num)
+        # Buffer must fit the largest possible input from either modality.
+        # Image batch: max_batch_size images × (max_slice_num + 1) slices each.
+        # Video batch: max_frames_per_batch frames × (max_slice_num + 1) slices each.
+        # Both modalities share the same captured graph, so take the larger of the two.
+        max_slices_by_content = max_batch_size * (max_slice_num + 1)
+        if self.version in {(2, 6), (4, 0), (4, 5)} and max_frames_per_batch > 0:
+            max_slices_by_content = max(
+                max_slices_by_content,
+                max_frames_per_batch * (max_slice_num + 1),
+            )
+        return max(1, min(max_slices_by_token_budget, max_slices_by_content))
+
+    def get_encoder_cudagraph_config(self) -> EncoderCudaGraphConfig:
+        buffer_keys = [
+            _MINICPMV_CUDAGRAPH_BUF_KEY_TGT_SIZES,
+            _MINICPMV_CUDAGRAPH_BUF_KEY_PATCH_MASK,
+        ]
+        if self.version == (4, 5):
+            buffer_keys.append(_MINICPMV_CUDAGRAPH_BUF_KEY_TEMPORAL_IDS)
+        # Video is only supported from 2.6 onward.
+        modalities = ["image"]
+        if self.version in {(2, 6), (4, 0), (4, 5)}:
+            modalities.append("video")
+        return EncoderCudaGraphConfig(
+            modalities=modalities,
+            input_key_by_modality={
+                "image": _MINICPMV_CUDAGRAPH_FLAT_KEY_IMAGE,
+                "video": _MINICPMV_CUDAGRAPH_FLAT_KEY_VIDEO,
+            },
+            buffer_keys=buffer_keys,
+            out_hidden_size=int(self.embed_dim),
+        )
+
+    def get_input_modality(self, mm_kwargs: dict[str, Any]) -> str:
+        if "video_pixel_values" in mm_kwargs:
+            return "video"
+        return "image"
+
+    def get_max_frames_per_video(self) -> int:
+        info = MULTIMODAL_REGISTRY.get_processing_info(self.vllm_config.model_config)
+        return int(
+            info.get_num_frames_with_most_features(
+                seq_len=self.vllm_config.model_config.max_model_len,
+                mm_counts={
+                    "video": self.multimodal_config.get_limit_per_prompt("video")
+                },
+            )
+        )
+
+    def get_encoder_cudagraph_budget_range(
+        self, vllm_config: VllmConfig
+    ) -> tuple[int, int]:
+        # Each slice produces exactly query_num resampler output tokens.
+        # A thumbnail-only image has 1 slice, so query_num is the smallest
+        # possible encoder output and the natural minimum budget.
+        min_budget = int(self.config.query_num)
+        max_budget = min(
+            vllm_config.scheduler_config.max_num_batched_tokens,
+            vllm_config.model_config.max_model_len,
+        )
+        return (min_budget, max_budget)
+
+    def get_encoder_cudagraph_num_items(self, mm_kwargs: dict[str, Any]) -> int:
+        video = self.get_input_modality(mm_kwargs) == "video"
+        pixel_values_key = "video_pixel_values" if video else "pixel_values"
+        return len(mm_kwargs[pixel_values_key])
+
+    def get_encoder_cudagraph_per_item_output_tokens(
+        self, mm_kwargs: dict[str, Any]
+    ) -> list[int]:
+        query_num = int(self.config.query_num)
+        video = self.get_input_modality(mm_kwargs) == "video"
+        pixel_values_key = "video_pixel_values" if video else "pixel_values"
+        pixel_values: list[list[torch.Tensor]] = mm_kwargs[pixel_values_key]
+        return [len(img) * query_num for img in pixel_values]
+
+    def get_encoder_cudagraph_per_item_input_sizes(
+        self, mm_kwargs: dict[str, Any]
+    ) -> list[int]:
+        video = self.get_input_modality(mm_kwargs) == "video"
+        tgt_sizes = _mcpmv_tgt_sizes_tensor(mm_kwargs, video=video)
+        pixel_values: list[list[torch.Tensor]] = mm_kwargs[
+            "video_pixel_values" if video else "pixel_values"
+        ]
+        slice_counts = [len(img) for img in pixel_values]
+        patch_sums = tgt_sizes.prod(-1)
+        return [
+            int(group.sum().item()) for group in torch.split(patch_sums, slice_counts)
+        ]
+
+    def select_encoder_cudagraph_items(
+        self,
+        mm_kwargs: dict[str, Any],
+        indices: list[int],
+    ) -> dict[str, Any]:
+        video = self.get_input_modality(mm_kwargs) == "video"
+        pixel_values_key = "video_pixel_values" if video else "pixel_values"
+        tgt_key = "video_tgt_sizes" if video else "tgt_sizes"
+        flat_key = (
+            _MINICPMV_CUDAGRAPH_FLAT_KEY_VIDEO
+            if video
+            else _MINICPMV_CUDAGRAPH_FLAT_KEY_IMAGE
+        )
+        device = next(self.vpm.parameters()).device
+        pixel_h, pixel_w = self._mcpmv_slice_pixel_size()
+
+        pixel_values: list[list[torch.Tensor]] = mm_kwargs[pixel_values_key]
+        tgt_sizes = _mcpmv_tgt_sizes_tensor(mm_kwargs, video=video)
+
+        # Base dict without the stale flat buffer (recomputed at the end).
+        subset = {k: v for k, v in mm_kwargs.items() if k != flat_key}
+
+        if not indices:
+            subset.update(
+                {
+                    pixel_values_key: [],
+                    tgt_key: torch.zeros((0, 2), dtype=torch.long, device=device),
+                    flat_key: torch.zeros(
+                        (0, 3 * pixel_h * pixel_w), device=device, dtype=torch.float32
+                    ),
+                }
+            )
+            if self.version == (4, 5):
+                subset["temporal_ids"] = None
+            return subset
+
+        # Select per-item nested slices and matching tgt_sizes rows.
+        slice_counts = [len(item_slices) for item_slices in pixel_values]
+        tgt_groups = torch.split(tgt_sizes, slice_counts)
+        selected_pixel_values = [pixel_values[i] for i in indices]
+        selected_tgt_sizes = torch.cat([tgt_groups[i] for i in indices], dim=0)
+
+        # Pack ragged [3, H, W_i] slices into fixed [num_slices, 3*H*W] buffer.
+        selected_slices = flatten_2d_lists(selected_pixel_values)
+        packed_flat_pixels = _mcpmv_pack_flat_pixels(
+            selected_slices,
+            pixel_height=pixel_h,
+            pixel_width=pixel_w,
+            max_num_slices=len(selected_slices),
+            device=selected_slices[0].device,
+            dtype=selected_slices[0].dtype,
+        )
+
+        subset.update(
+            {
+                pixel_values_key: selected_pixel_values,
+                tgt_key: selected_tgt_sizes,
+                flat_key: packed_flat_pixels,
+            }
+        )
+        if self.version == (4, 5):
+            temporal_ids = mm_kwargs.get("temporal_ids")
+            if temporal_ids is not None:
+                subset["temporal_ids"] = [temporal_ids[i] for i in indices]
+        return subset
+
+    def prepare_encoder_cudagraph_capture_inputs(
+        self,
+        token_budget: int,
+        max_batch_size: int,
+        max_frames_per_batch: int,
+        device: torch.device,
+        dtype: torch.dtype,
+    ) -> EncoderCudaGraphCaptureInputs:
+        pixel_h, pixel_w = self._mcpmv_slice_pixel_size()
+        max_patches = self._mcpmv_max_patches_per_slice()
+        max_num_slices = self._mcpmv_max_slices_cap(
+            token_budget,
+            max_batch_size,
+            max_frames_per_batch,
+        )
+        flat_dim = 3 * pixel_h * pixel_w
+        flat_pixel_buffer = torch.zeros(
+            (max_num_slices, flat_dim), device=device, dtype=dtype
+        )
+        patch_hw = pixel_h // int(self.vpm.embeddings.patch_size)
+        dummy_tgt_sizes = torch.full(
+            (max_num_slices, 2), patch_hw, dtype=torch.long, device=device
+        )
+        dummy_patch_mask = torch.ones(
+            (max_num_slices, max_patches), dtype=torch.bool, device=device
+        )
+        buffers: dict[str, torch.Tensor] = {
+            _MINICPMV_CUDAGRAPH_BUF_KEY_TGT_SIZES: dummy_tgt_sizes,
+            _MINICPMV_CUDAGRAPH_BUF_KEY_PATCH_MASK: dummy_patch_mask,
+        }
+        if self.version == (4, 5):
+            buffers[_MINICPMV_CUDAGRAPH_BUF_KEY_TEMPORAL_IDS] = torch.full(
+                (max_num_slices,), -1, dtype=torch.long, device=device
+            )
+        mm_kwargs: dict[str, Any] = {
+            _MINICPMV_CUDAGRAPH_FLAT_KEY_IMAGE: flat_pixel_buffer,
+        }
+        return EncoderCudaGraphCaptureInputs(mm_kwargs=mm_kwargs, buffers=buffers)
+
+    def prepare_encoder_cudagraph_replay_buffers(
+        self,
+        mm_kwargs: dict[str, Any],
+        max_batch_size: int,
+        max_frames_per_batch: int,
+    ) -> EncoderCudaGraphReplayBuffers:
+        _ = max_frames_per_batch
+        _ = max_batch_size
+        video = self.get_input_modality(mm_kwargs) == "video"
+        max_patches = self._mcpmv_max_patches_per_slice()
+        device = next(self.vpm.parameters()).device
+        # After select_encoder_cudagraph_items, tgt_sizes contains exactly one
+        # row per selected slice, so its length equals the total slice count.
+        tgt_sizes = _mcpmv_tgt_sizes_tensor(mm_kwargs, video=video).to(
+            device=device, dtype=torch.long
+        )
+
+        patches_per_slice = tgt_sizes.prod(-1).clamp(max=max_patches)
+        col_idx = torch.arange(max_patches, device=device)
+        patch_attention_mask = col_idx.unsqueeze(0) < patches_per_slice.unsqueeze(1)
+
+        buffers: dict[str, torch.Tensor] = {
+            _MINICPMV_CUDAGRAPH_BUF_KEY_TGT_SIZES: tgt_sizes.clone(),
+            _MINICPMV_CUDAGRAPH_BUF_KEY_PATCH_MASK: patch_attention_mask,
+        }
+        if self.version == (4, 5):
+            temporal_ids = mm_kwargs.get("temporal_ids")
+            if temporal_ids is not None:
+                # temporal_ids is list[list[int]] (per-image, per-slice).
+                flat_ids = torch.tensor(
+                    flatten_2d_lists(temporal_ids), dtype=torch.long, device=device
+                )
+            else:
+                flat_ids = torch.full(
+                    (len(tgt_sizes),), -1, dtype=torch.long, device=device
+                )
+            buffers[_MINICPMV_CUDAGRAPH_BUF_KEY_TEMPORAL_IDS] = flat_ids
+        return EncoderCudaGraphReplayBuffers(buffers=buffers)
+
+    def encoder_cudagraph_forward(
+        self,
+        mm_kwargs: dict[str, Any],
+        buffers: dict[str, torch.Tensor],
+    ) -> torch.Tensor:
+        modality = self.get_input_modality(mm_kwargs)
+        flat_key = (
+            _MINICPMV_CUDAGRAPH_FLAT_KEY_VIDEO
+            if modality == "video"
+            else _MINICPMV_CUDAGRAPH_FLAT_KEY_IMAGE
+        )
+        flat_pixel_buffer = mm_kwargs[flat_key]
+        pixel_h, pixel_w = self._mcpmv_slice_pixel_size()
+        max_num_slices, flat_dim = flat_pixel_buffer.shape
+        assert flat_dim == 3 * pixel_h * pixel_w
+        all_pixel_values = flat_pixel_buffer.view(max_num_slices, 3, pixel_h, pixel_w)
+
+        tgt_sizes = buffers[_MINICPMV_CUDAGRAPH_BUF_KEY_TGT_SIZES]
+        patch_attention_mask = buffers[
+            _MINICPMV_CUDAGRAPH_BUF_KEY_PATCH_MASK
+        ].unsqueeze(1)
+
+        # v2.5 vpm does not accept tgt_sizes.
+        vpm_tgt_sizes = None if self.version == (2, 5) else tgt_sizes
+        vision_embedding = self.vpm(
+            all_pixel_values,
+            patch_attention_mask=patch_attention_mask,
+            tgt_sizes=vpm_tgt_sizes,
+        )
+
+        if self.version == (4, 5):
+            temporal_ids = buffers[_MINICPMV_CUDAGRAPH_BUF_KEY_TEMPORAL_IDS]
+            resampler_out = self.resampler(vision_embedding, tgt_sizes, temporal_ids)
+        else:
+            resampler_out = self.resampler(vision_embedding, tgt_sizes)
+
+        query_num = int(self.config.query_num)
+        return resampler_out.reshape(max_num_slices * query_num, int(self.embed_dim))
+
+    def encoder_eager_forward(self, mm_kwargs: dict[str, Any]) -> torch.Tensor:
+        """Eager encoder path; returns ``(total_tokens, embed_dim)`` like
+        ``encoder_cudagraph_forward``.
+
+        Called by the manager only for images/videos that exceed all token
+        budgets (single-item batches), so ``segments`` always has exactly one
+        element in practice. Version-specific logic (e.g. temporal embeddings
+        for v4.5) is handled transparently by the polymorphic dispatch inside
+        ``get_vision_hidden_states``.
+ """ + mm_kwargs_no_flat = { + k: v + for k, v in mm_kwargs.items() + if k + not in ( + _MINICPMV_CUDAGRAPH_FLAT_KEY_IMAGE, + _MINICPMV_CUDAGRAPH_FLAT_KEY_VIDEO, + ) + } + modalities = self._parse_and_validate_multimodal_inputs(**mm_kwargs_no_flat) + segments: list[torch.Tensor] = [] + embed_dim = self.embed_dim + for modality in modalities: + if modality == "images": + image_input = modalities["images"] + image_embeddings = self.get_vision_hidden_states(image_input) + segments.append(image_embeddings.reshape(-1, embed_dim)) + elif modality == "videos": + video_input = modalities["videos"] + video_embeddings = self.get_vision_hidden_states(video_input) + segments.append(video_embeddings.reshape(-1, embed_dim)) + if not segments: + raise RuntimeError( + "MiniCPM-V encoder cudagraph eager path expects pixel_values " + "or video_pixel_values" + ) + return torch.cat(segments, dim=0) + + class MiniCPMV2_0(MiniCPMVBaseModel): supports_encoder_tp_data = False @@ -1310,7 +1739,7 @@ def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tens return torch.vstack(res) -class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): +class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA, _MiniCPMVEncoderCudaGraphMixin): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -1401,7 +1830,7 @@ def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tens return self.resampler(vision_embedding, tgt_sizes) -class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): +class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA, _MiniCPMVEncoderCudaGraphMixin): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -1499,7 +1928,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: return loaded -class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA): +class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA, _MiniCPMVEncoderCudaGraphMixin): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -1596,7 +2025,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: return loaded -class MiniCPMV4_5(MiniCPMVBaseModel, SupportsLoRA): +class MiniCPMV4_5(MiniCPMVBaseModel, SupportsLoRA, _MiniCPMVEncoderCudaGraphMixin): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -1668,9 +2097,6 @@ def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tens dtype = pixel_values[0].dtype all_pixel_values = torch.zeros((B, 3, P, L), dtype=dtype, device=device) - all_temporal_ids = ( - None if temporal_ids is None else flatten_2d_lists(temporal_ids) - ) for i, pixel_values_item in enumerate(pixel_values): L_item = pixel_values_item.shape[-1] all_pixel_values[i, ..., :L_item] = pixel_values_item @@ -1689,7 +2115,7 @@ def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tens tgt_sizes=tgt_sizes, ) - return self.resampler(vision_embedding, tgt_sizes, all_temporal_ids) + return self.resampler(vision_embedding, tgt_sizes, temporal_ids) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self, skip_prefixes=["apm.", "audio", "tts"])