diff --git a/docs/design/cuda_graphs_multimodal.md b/docs/design/cuda_graphs_multimodal.md
index f44ef359df38..3920ff34556f 100644
--- a/docs/design/cuda_graphs_multimodal.md
+++ b/docs/design/cuda_graphs_multimodal.md
@@ -87,6 +87,7 @@ Models opt-in to encoder CUDA Graphs by implementing the [SupportsEncoderCudaGra
| ------------ | ------ | ------------ | ------------ |
| `Qwen3VLForConditionalGeneration` | `Qwen3-VL` | ✅︎ | ✅︎ |
| `Qwen2_5_VLForConditionalGeneration` | `Qwen2.5-VL` | ✅︎ | ✅︎ |
+| `MiniCPMV` | `MiniCPM-V 2.5`, `MiniCPM-V 2.6`, `MiniCPM-V 4.0`, `MiniCPM-V 4.5` | ✅︎ | ✅︎ |
!!! note
Encoder CUDA Graphs have currently been tested with `--mm-encoder-attn-backend=FLASH_ATTN` and `--mm-encoder-attn-backend=FLASHINFER` on Blackwell GPUs.
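+
+For example, a minimal offline-inference setup pinning the encoder attention
+backend (a sketch: `mm_encoder_attn_backend` is assumed to be the engine-arg
+counterpart of the `--mm-encoder-attn-backend` CLI flag above):
+
+```python
+from vllm import LLM
+
+# Illustrative setup; adjust the model and backend to your environment.
+llm = LLM(
+    model="openbmb/MiniCPM-V-4_5",
+    trust_remote_code=True,
+    mm_encoder_attn_backend="FLASH_ATTN",
+)
+```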
diff --git a/examples/generate/multimodal/vision_language_offline.py b/examples/generate/multimodal/vision_language_offline.py
index 794f20dd0a52..c691da7ac64e 100644
--- a/examples/generate/multimodal/vision_language_offline.py
+++ b/examples/generate/multimodal/vision_language_offline.py
@@ -2467,6 +2467,7 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData:
"qwen3_vl",
"qwen3_vl_moe",
"qwen2_5_vl",
+ "minicpmv",
]
diff --git a/tests/models/multimodal/generation/test_vit_cudagraph.py b/tests/models/multimodal/generation/test_vit_cudagraph.py
index fb7bdfc8625d..f9b8fbccbd94 100644
--- a/tests/models/multimodal/generation/test_vit_cudagraph.py
+++ b/tests/models/multimodal/generation/test_vit_cudagraph.py
@@ -41,6 +41,20 @@ def qwen_vl_chat_template(content: str) -> str:
return f"<|im_start|>user\n{content}<|im_end|>\n<|im_start|>assistant\n"
+def minicpmv_25_chat_template(content: str) -> str:
+ """Llama3-style chat template used by MiniCPM-V 2.5."""
+ return (
+ f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
+ f"{content}"
+ f"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+ )
+
+
+def minicpmv_chat_template(content: str) -> str:
+ """ChatML template used by MiniCPM-V 2.6 / 4.0 / 4.5."""
+ return f"<|im_start|>user\n{content}<|im_end|>\n<|im_start|>assistant\n"
+
+
MODEL_CONFIGS: dict[str, VitCudagraphTestConfig] = {
"qwen3_vl": VitCudagraphTestConfig(
model="Qwen/Qwen3-VL-2B-Instruct",
@@ -66,6 +80,48 @@ def qwen_vl_chat_template(content: str) -> str:
needs_video_metadata=False,
marks=[pytest.mark.core_model],
),
+ "minicpmv_25": VitCudagraphTestConfig(
+ model="openbmb/MiniCPM-Llama3-V-2_5",
+ modalities=["image"],
+ image_prompt=minicpmv_25_chat_template(
+ "(./)\nWhat is in this image?"
+ ),
+ vllm_runner_kwargs={"trust_remote_code": True},
+ marks=[pytest.mark.core_model],
+ ),
+ "minicpmv_26": VitCudagraphTestConfig(
+ model="openbmb/MiniCPM-V-2_6",
+ image_prompt=minicpmv_chat_template(
+ "(./)\nWhat is in this image?"
+ ),
+ video_prompt=minicpmv_chat_template(
+ "()\nDescribe this video in one sentence."
+ ),
+ vllm_runner_kwargs={"trust_remote_code": True},
+ marks=[pytest.mark.core_model],
+ ),
+ "minicpmv_40": VitCudagraphTestConfig(
+ model="openbmb/MiniCPM-V-4",
+ image_prompt=minicpmv_chat_template(
+ "(./)\nWhat is in this image?"
+ ),
+ video_prompt=minicpmv_chat_template(
+ "()\nDescribe this video in one sentence."
+ ),
+ vllm_runner_kwargs={"trust_remote_code": True},
+ marks=[pytest.mark.core_model],
+ ),
+ "minicpmv_45": VitCudagraphTestConfig(
+ model="openbmb/MiniCPM-V-4_5",
+ image_prompt=minicpmv_chat_template(
+ "(./)\nWhat is in this image?"
+ ),
+ video_prompt=minicpmv_chat_template(
+ "()\nDescribe this video in one sentence."
+ ),
+ vllm_runner_kwargs={"trust_remote_code": True},
+ marks=[pytest.mark.core_model],
+ ),
}
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index cda07ea291ea..a1db369bbe1b 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -29,7 +29,7 @@
from collections.abc import Callable, Iterable, Mapping, Sequence
from functools import partial
from itertools import chain
-from typing import Annotated, Any, Literal, TypeAlias
+from typing import Annotated, Any, ClassVar, Literal, TypeAlias
import numpy as np
import torch
@@ -85,10 +85,16 @@
from vllm.utils.collection_utils import flatten_2d_lists
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from vllm.utils.torch_utils import set_default_torch_dtype
+from vllm.v1.worker.encoder_cudagraph_defs import (
+ EncoderCudaGraphCaptureInputs,
+ EncoderCudaGraphConfig,
+ EncoderCudaGraphReplayBuffers,
+)
from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import (
MultiModalEmbeddings,
+ SupportsEncoderCudaGraph,
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
@@ -126,6 +132,11 @@ class MiniCPMVImagePixelInputs(TensorSchema):
TensorShape("bn"),
]
+ # Handled as a batched input; a TensorShape check isn't necessary
+ # since it defaults to None and has a non-tensor type.
+ temporal_ids: list[list[int]] | None = None
+
class MiniCPMVImageEmbeddingInputs(TensorSchema):
"""
@@ -204,25 +215,34 @@ def forward(self, x: torch.Tensor, tgt_sizes: torch.Tensor) -> torch.Tensor:
patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]
- self._adjust_pos_cache(tgt_sizes, device=device)
+ # In eager mode or during capture, adjust the cache size if necessary.
+ # We can safely use .max().item() here because during capture tgt_sizes
+ # contains the max possible sizes, so it will grow the cache sufficiently.
+ max_h = tgt_sizes[:, 0].max().item()
+ max_w = tgt_sizes[:, 1].max().item()
+ if max_h > self.max_size[0] or max_w > self.max_size[1]:
+ self._adjust_pos_cache(tgt_sizes, device=device)
+
+ max_patch_len = x.shape[1]
- max_patch_len = patch_len.max().item()
- assert isinstance(max_patch_len, int)
+ seq_idx = torch.arange(max_patch_len, device=device).unsqueeze(0)
- key_padding_mask = torch.zeros(
- (bs, max_patch_len), dtype=torch.bool, device=device
- )
+ tgt_w = tgt_sizes[:, 1].unsqueeze(1)
+ tgt_w = torch.clamp(tgt_w, min=1)
+
+ h_idx = seq_idx // tgt_w
+ w_idx = seq_idx % tgt_w
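+ # Illustrative: tgt_sizes[i] = (2, 3) maps seq positions 0..5 to
+ # h_idx = [0, 0, 0, 1, 1, 1] and w_idx = [0, 1, 2, 0, 1, 2], i.e. row-major
+ # coordinates into the 2D pos_embed cache; positions past patch_len[i]
+ # are zeroed out by the padding mask below.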
- pos_embed = []
- for i in range(bs):
- tgt_h, tgt_w = tgt_sizes[i].tolist()
- pos_embed.append(
- self.pos_embed[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1)).to(dtype)
- ) # patches * D
- key_padding_mask[i, patch_len[i] :] = True
- pos_embed = torch.nn.utils.rnn.pad_sequence(
- pos_embed, batch_first=True, padding_value=0.0
+ h_idx = torch.clamp(h_idx, max=self.pos_embed.shape[0] - 1)
+ w_idx = torch.clamp(w_idx, max=self.pos_embed.shape[1] - 1)
+
+ key_padding_mask = seq_idx >= patch_len.unsqueeze(1)
+
+ pos_embed = self.pos_embed[h_idx, w_idx].to(dtype)
+ pos_embed = torch.where(
+ key_padding_mask.unsqueeze(-1), torch.zeros_like(pos_embed), pos_embed
).permute(1, 0, 2) # BLD => L * B * D
+
x, _ = self.kv_proj(x) # B * L * D
x = self.ln_kv(x).permute(1, 0, 2) # L * B * D
@@ -339,88 +359,108 @@ def forward(
patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]
- self._adjust_pos_cache(tgt_sizes, device=device)
+ # In eager mode or during capture, adjust the cache size if necessary.
+ # We can safely use .max().item() here because during capture tgt_sizes
+ # contains the max possible sizes, so it will grow the cache sufficiently.
+ max_h = tgt_sizes[:, 0].max().item()
+ max_w = tgt_sizes[:, 1].max().item()
+ if max_h > self.max_size[0] or max_w > self.max_size[1]:
+ self._adjust_pos_cache(tgt_sizes, device=device)
+
+ max_patch_len = x.shape[1]
temporal_pos_emb = False
temporal_ids_flatten = None
if temporal_ids is not None:
- # example: [[-1], [-1], [2, 6, 9]]
- temporal_ids_flatten = list(chain.from_iterable(temporal_ids))
- max_temporal_size = max(temporal_ids_flatten, default=0)
+ if isinstance(temporal_ids, torch.Tensor):
+ temporal_ids_flatten = temporal_ids
+ max_temporal_size = temporal_ids_flatten.max().item()
+ else:
+ # example: [[-1], [-1], [2, 6, 9]]
+ temporal_ids_flatten = list(chain.from_iterable(temporal_ids))
+ max_temporal_size = max(temporal_ids_flatten, default=0)
+ temporal_ids_flatten = torch.tensor(
+ temporal_ids_flatten, dtype=torch.long, device=device
+ )
+
if max_temporal_size > -1:
temporal_pos_emb = True
if max_temporal_size > self.max_temporal_size:
self._adjust_temporal_pos_cache(max_temporal_size, device)
- max_patch_len = patch_len.max().item()
- assert isinstance(max_patch_len, int)
+ seq_idx = torch.arange(max_patch_len, device=device).unsqueeze(0)
- key_padding_mask = torch.zeros(
- (bs, max_patch_len), dtype=torch.bool, device=device
- )
+ tgt_w = tgt_sizes[:, 1].unsqueeze(1)
+ tgt_w = torch.clamp(tgt_w, min=1)
+
+ h_idx = seq_idx // tgt_w
+ w_idx = seq_idx % tgt_w
+
+ h_idx = torch.clamp(h_idx, max=self.pos_embed.shape[0] - 1)
+ w_idx = torch.clamp(w_idx, max=self.pos_embed.shape[1] - 1)
+
+ pos_embed_2d = self.pos_embed[h_idx, w_idx].to(dtype)
+
+ key_padding_mask = seq_idx >= patch_len.unsqueeze(1)
+
+ pos_embed_2d = torch.where(
+ key_padding_mask.unsqueeze(-1), torch.zeros_like(pos_embed_2d), pos_embed_2d
+ ).permute(1, 0, 2) # (L, bs, D)
x, _ = self.kv_proj(x) # B * L * D
x = self.ln_kv(x).permute(1, 0, 2) # L * B * D
q = self.ln_q(self.query) # Q * D
- pos_embed_2d = []
- pos_embed_temporal = []
- for i in range(bs):
- tgt_h, tgt_w = tgt_sizes[i]
- if temporal_pos_emb:
- if temporal_ids_flatten[i] == -1:
- pos_embed_temporal.append(
- torch.zeros(self.embed_dim, dtype=dtype, device=device)
- )
- else:
- pos_embed_temporal.append(
- self.temporal_pos_embed[temporal_ids_flatten[i]].to(dtype)
- ) # D
-
- pos_embed_2d.append(
- self.pos_embed[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1)).to(dtype)
- ) # patches * D
- key_padding_mask[i, patch_len[i] :] = True
-
- pos_embed_2d = torch.nn.utils.rnn.pad_sequence(
- pos_embed_2d, batch_first=True, padding_value=0.0
- ).permute(1, 0, 2) # BLD => L * B * D
-
k = x + pos_embed_2d
v = x
- if pos_embed_temporal:
- k += torch.stack(pos_embed_temporal, dim=0)
- bs = len(temporal_ids)
- merge_k = []
- merge_v = []
- merge_key_padding_mask = []
-
- start = 0
- for tp in temporal_ids:
- end = start + len(tp)
- # L * (end-start) * D -> (end-start) * L * D
- # -> 1 * L*(end-start) * D
- merge_k.append(
- k[:, start:end, :].permute(1, 0, 2).reshape(-1, self.embed_dim)
- )
- merge_v.append(
- v[:, start:end, :].permute(1, 0, 2).reshape(-1, self.embed_dim)
- )
- merge_key_padding_mask.append(
- key_padding_mask[start:end, :].reshape(-1, 1)
- )
- start = end
+ if temporal_pos_emb:
+ # temporal_ids_flatten is 1D tensor of shape (bs,)
+ pos_embed_temporal = torch.where(
+ (temporal_ids_flatten == -1).unsqueeze(-1),
+ torch.zeros(self.embed_dim, dtype=dtype, device=device),
+ self.temporal_pos_embed[torch.clamp(temporal_ids_flatten, min=0)].to(
+ dtype
+ ),
+ ) # (bs, D)
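+ # E.g. temporal_ids_flatten = [-1, -1, 2]: rows 0 and 1 get a zero
+ # temporal embedding, row 2 gets temporal_pos_embed[2]; the clamp keeps
+ # the -1 entries in bounds for the gather before where() zeroes them.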
+
+ k += pos_embed_temporal.unsqueeze(0) # (L, bs, D) + (1, bs, D)
+
+ # Skip the cross-frame merge loop when capturing into a CUDA graph
+ # (signaled by temporal_ids arriving as a flat tensor), because its
+ # dynamic sequence lengths and batch sizes cannot be captured.
+ if not isinstance(temporal_ids, torch.Tensor):
+ bs = len(temporal_ids)
+ merge_k = []
+ merge_v = []
+ merge_key_padding_mask = []
+
+ start = 0
+ for tp in temporal_ids:
+ end = start + len(tp)
+ # L * (end-start) * D -> (end-start) * L * D
+ # -> 1 * L*(end-start) * D
+ merge_k.append(
+ k[:, start:end, :].permute(1, 0, 2).reshape(-1, self.embed_dim)
+ )
+ merge_v.append(
+ v[:, start:end, :].permute(1, 0, 2).reshape(-1, self.embed_dim)
+ )
+ merge_key_padding_mask.append(
+ key_padding_mask[start:end, :].reshape(-1, 1)
+ )
+
+ start = end
- k = torch.nn.utils.rnn.pad_sequence(
- merge_k, batch_first=True, padding_value=0.0
- ).permute(1, 0, 2) # L*(end-start)
- v = torch.nn.utils.rnn.pad_sequence(
- merge_v, batch_first=True, padding_value=0.0
- ).permute(1, 0, 2) # L*(end-start)
- key_padding_mask = torch.nn.utils.rnn.pad_sequence(
- merge_key_padding_mask, batch_first=True, padding_value=True
- ).squeeze(-1)
+ k = torch.nn.utils.rnn.pad_sequence(
+ merge_k, batch_first=True, padding_value=0.0
+ ).permute(1, 0, 2) # L*(end-start)
+ v = torch.nn.utils.rnn.pad_sequence(
+ merge_v, batch_first=True, padding_value=0.0
+ ).permute(1, 0, 2) # L*(end-start)
+ key_padding_mask = torch.nn.utils.rnn.pad_sequence(
+ merge_key_padding_mask, batch_first=True, padding_value=True
+ ).squeeze(-1)
out = self.attn(
self._repeat(q, bs), # Q * B * D
@@ -459,6 +499,7 @@ def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
video_image_sizes=MultiModalFieldConfig.batched("video"),
video_tgt_sizes=MultiModalFieldConfig.batched("video"),
video_embeds=MultiModalFieldConfig.batched("video"),
+ temporal_ids=MultiModalFieldConfig.batched("video"),
)
@@ -1068,6 +1109,7 @@ def _parse_and_validate_vision_input(
) -> MiniCPMVImageInputs | None:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
+ temporal_ids = kwargs.pop("temporal_ids", None)
if pixel_values is None and image_embeds is None:
return None
@@ -1089,6 +1131,7 @@ def _parse_and_validate_vision_input(
pixel_values=pixel_values_flat,
tgt_sizes=tgt_sizes_flat,
num_slices=num_slices_flat,
+ temporal_ids=temporal_ids,
)
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
@@ -1219,6 +1262,392 @@ def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tens
raise NotImplementedError
+# --- Encoder CUDA graph (MiniCPM-V 2.5 / 2.6 / 4.0 / 4.5; mixed into subclasses) ---
+# Buffer keys
+_MINICPMV_CUDAGRAPH_BUF_KEY_TGT_SIZES = "minicpmv_tgt_sizes"
+_MINICPMV_CUDAGRAPH_BUF_KEY_PATCH_MASK = "minicpmv_patch_attn_mask"
+# v4.5 only; consumed in encoder_cudagraph_forward.
+_MINICPMV_CUDAGRAPH_BUF_KEY_TEMPORAL_IDS = "minicpmv_temporal_ids"
+
+# mm_kwargs keys for the flat pixel-value tensor
+_MINICPMV_CUDAGRAPH_FLAT_KEY_IMAGE = "minicpmv_encoder_input_flat"
+_MINICPMV_CUDAGRAPH_FLAT_KEY_VIDEO = "minicpmv_video_encoder_input_flat"
+
+
+def _mcpmv_tgt_sizes_tensor(mm_kwargs: dict[str, Any], *, video: bool) -> torch.Tensor:
+ key = "video_tgt_sizes" if video else "tgt_sizes"
+ return mm_kwargs[key]
+
+
+def _mcpmv_pack_flat_pixels(
+ slices: list[torch.Tensor],
+ *,
+ pixel_height: int,
+ pixel_width: int,
+ max_num_slices: int,
+ device: torch.device,
+ dtype: torch.dtype,
+) -> torch.Tensor:
+ """Pack slice tensors into a fixed ``(max_num_slices, 3*H*W)`` buffer.
+
+ Every slice is ``image_size × image_size`` (see ``_mcpmv_slice_pixel_size``),
+ so all rows in the buffer are fully occupied.
+ """
+ flat_dim = 3 * pixel_height * pixel_width
+ packed = torch.zeros((max_num_slices, flat_dim), device=device, dtype=dtype)
+ n = min(len(slices), max_num_slices)
+ if n > 0:
+ packed[:n] = torch.stack(slices[:n]).reshape(n, -1).to(dtype=dtype)
+ return packed
+
+
+class _MiniCPMVEncoderCudaGraphMixin(SupportsEncoderCudaGraph):
+ """SupportsEncoderCudaGraph for MiniCPM-V Idefics2 + resampler (not 2.0)."""
+
+ supports_encoder_cudagraph: ClassVar[Literal[True]] = True
+
+ def _mcpmv_slice_pixel_size(self) -> tuple[int, int]:
+ """Return (pixel_height, pixel_width) for each slice fed into vpm.
+
+ Every slice is resized to image_size x image_size pixels before being
+ passed to the vision encoder, so this is always (image_size, image_size)
+ regardless of max_slice_num.
+ """
+ image_size = int(self.vpm.embeddings.image_size)
+ return image_size, image_size
+
+ def _mcpmv_max_patches_per_slice(self) -> int:
+ """Return max patch count per slice for patch_attention_mask sizing.
+
+ The vision encoder divides each (image_size x image_size) slice into
+ (image_size // patch_size)^2 patches, so this equals that value.
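+
+ For example, image_size=448 with patch_size=14 (the typical SigLIP
+ config for these checkpoints) gives (448 // 14) ** 2 = 1024 patches.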
+ """
+ image_size = int(self.vpm.embeddings.image_size)
+ patch_size = int(self.vpm.embeddings.patch_size)
+ return (image_size // patch_size) ** 2
+
+ def _mcpmv_max_slices_cap(
+ self,
+ token_budget: int,
+ max_batch_size: int,
+ max_frames_per_batch: int,
+ ) -> int:
+ max_slice_num = int(getattr(self.config, "max_slice_num", 9))
+ query_num = max(1, int(self.config.query_num))
+ # Each slice produces query_num output tokens, so token_budget caps slices.
+ max_slices_by_token_budget = max(1, token_budget // query_num)
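+ # (E.g. token_budget=8192 with query_num=64 caps the buffer at 128 slices.)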
+ # Buffer must fit the largest possible input from either modality.
+ # Image batch: max_batch_size images × (max_slice_num + 1) slices each.
+ # Video batch: max_frames_per_batch frames × (max_slice_num + 1) slices each.
+ # Both modalities share the same captured graph, so take the larger of the two.
+ max_slices_by_content = max_batch_size * (max_slice_num + 1)
+ if self.version in {(2, 6), (4, 0), (4, 5)} and max_frames_per_batch > 0:
+ max_slices_by_content = max(
+ max_slices_by_content,
+ max_frames_per_batch * (max_slice_num + 1),
+ )
+ return max(1, min(max_slices_by_token_budget, max_slices_by_content))
+
+ def get_encoder_cudagraph_config(self) -> EncoderCudaGraphConfig:
+ buffer_keys = [
+ _MINICPMV_CUDAGRAPH_BUF_KEY_TGT_SIZES,
+ _MINICPMV_CUDAGRAPH_BUF_KEY_PATCH_MASK,
+ ]
+ if self.version == (4, 5):
+ buffer_keys.append(_MINICPMV_CUDAGRAPH_BUF_KEY_TEMPORAL_IDS)
+ # Video is only supported from 2.6 onward.
+ modalities = ["image"]
+ if self.version in {(2, 6), (4, 0), (4, 5)}:
+ modalities.append("video")
+ return EncoderCudaGraphConfig(
+ modalities=modalities,
+ input_key_by_modality={
+ "image": _MINICPMV_CUDAGRAPH_FLAT_KEY_IMAGE,
+ "video": _MINICPMV_CUDAGRAPH_FLAT_KEY_VIDEO,
+ },
+ buffer_keys=buffer_keys,
+ out_hidden_size=int(self.embed_dim),
+ )
+
+ def get_input_modality(self, mm_kwargs: dict[str, Any]) -> str:
+ if "video_pixel_values" in mm_kwargs:
+ return "video"
+ return "image"
+
+ def get_max_frames_per_video(self) -> int:
+ info = MULTIMODAL_REGISTRY.get_processing_info(self.vllm_config.model_config)
+ return int(
+ info.get_num_frames_with_most_features(
+ seq_len=self.vllm_config.model_config.max_model_len,
+ mm_counts={
+ "video": self.multimodal_config.get_limit_per_prompt("video")
+ },
+ )
+ )
+
+ def get_encoder_cudagraph_budget_range(
+ self, vllm_config: VllmConfig
+ ) -> tuple[int, int]:
+ # Each slice produces exactly query_num resampler output tokens.
+ # A thumbnail-only image has 1 slice, so query_num is the smallest
+ # possible encoder output and the natural minimum budget.
+ min_budget = int(self.config.query_num)
+ max_budget = min(
+ vllm_config.scheduler_config.max_num_batched_tokens,
+ vllm_config.model_config.max_model_len,
+ )
+ return (min_budget, max_budget)
+
+ def get_encoder_cudagraph_num_items(self, mm_kwargs: dict[str, Any]) -> int:
+ video = self.get_input_modality(mm_kwargs) == "video"
+ pixel_values_key = "video_pixel_values" if video else "pixel_values"
+ return len(mm_kwargs[pixel_values_key])
+
+ def get_encoder_cudagraph_per_item_output_tokens(
+ self, mm_kwargs: dict[str, Any]
+ ) -> list[int]:
+ query_num = int(self.config.query_num)
+ video = self.get_input_modality(mm_kwargs) == "video"
+ pixel_values_key = "video_pixel_values" if video else "pixel_values"
+ pixel_values: list[list[torch.Tensor]] = mm_kwargs[pixel_values_key]
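+ # Each slice contributes query_num tokens after resampling, so an image
+ # split into a thumbnail plus 4 slices yields 5 * query_num output tokens.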
+ return [len(img) * query_num for img in pixel_values]
+
+ def get_encoder_cudagraph_per_item_input_sizes(
+ self, mm_kwargs: dict[str, Any]
+ ) -> list[int]:
+ video = self.get_input_modality(mm_kwargs) == "video"
+ tgt_sizes = _mcpmv_tgt_sizes_tensor(mm_kwargs, video=video)
+ pixel_values: list[list[torch.Tensor]] = mm_kwargs[
+ "video_pixel_values" if video else "pixel_values"
+ ]
+ slice_counts = [len(img) for img in pixel_values]
+ patch_sums = tgt_sizes.prod(-1)
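+ # E.g. slice_counts = [2, 3]: the split below regroups the flat per-slice
+ # patch counts into per-item groups of 2 and 3 rows before summing.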
+ return [
+ int(group.sum().item()) for group in torch.split(patch_sums, slice_counts)
+ ]
+
+ def select_encoder_cudagraph_items(
+ self,
+ mm_kwargs: dict[str, Any],
+ indices: list[int],
+ ) -> dict[str, Any]:
+ video = self.get_input_modality(mm_kwargs) == "video"
+ pixel_values_key = "video_pixel_values" if video else "pixel_values"
+ tgt_key = "video_tgt_sizes" if video else "tgt_sizes"
+ flat_key = (
+ _MINICPMV_CUDAGRAPH_FLAT_KEY_VIDEO
+ if video
+ else _MINICPMV_CUDAGRAPH_FLAT_KEY_IMAGE
+ )
+ device = next(self.vpm.parameters()).device
+ pixel_h, pixel_w = self._mcpmv_slice_pixel_size()
+
+ pixel_values: list[list[torch.Tensor]] = mm_kwargs[pixel_values_key]
+ tgt_sizes = _mcpmv_tgt_sizes_tensor(mm_kwargs, video=video)
+
+ # Base dict without the stale flat buffer (recomputed at the end).
+ subset = {k: v for k, v in mm_kwargs.items() if k != flat_key}
+
+ if not indices:
+ subset.update(
+ {
+ pixel_values_key: [],
+ tgt_key: torch.zeros((0, 2), dtype=torch.long, device=device),
+ flat_key: torch.zeros(
+ (0, 3 * pixel_h * pixel_w), device=device, dtype=torch.float32
+ ),
+ }
+ )
+ if self.version == (4, 5):
+ subset["temporal_ids"] = None
+ return subset
+
+ # Select per-item nested slices and matching tgt_sizes rows.
+ slice_counts = [len(item_slices) for item_slices in pixel_values]
+ tgt_groups = torch.split(tgt_sizes, slice_counts)
+ selected_pixel_values = [pixel_values[i] for i in indices]
+ selected_tgt_sizes = torch.cat([tgt_groups[i] for i in indices], dim=0)
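+ # E.g. slice_counts = [5, 3] with indices = [1] keeps the second item's
+ # 3 slices together with its 3 matching tgt_sizes rows.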
+
+ # Pack ragged [3, H, W_i] slices into fixed [num_slices, 3*H*W] buffer.
+ selected_slices = flatten_2d_lists(selected_pixel_values)
+ packed_flat_pixels = _mcpmv_pack_flat_pixels(
+ selected_slices,
+ pixel_height=pixel_h,
+ pixel_width=pixel_w,
+ max_num_slices=len(selected_slices),
+ device=selected_slices[0].device,
+ dtype=selected_slices[0].dtype,
+ )
+
+ subset.update(
+ {
+ pixel_values_key: selected_pixel_values,
+ tgt_key: selected_tgt_sizes,
+ flat_key: packed_flat_pixels,
+ }
+ )
+ if self.version == (4, 5):
+ temporal_ids = mm_kwargs.get("temporal_ids")
+ if temporal_ids is not None:
+ subset["temporal_ids"] = [temporal_ids[i] for i in indices]
+ return subset
+
+ def prepare_encoder_cudagraph_capture_inputs(
+ self,
+ token_budget: int,
+ max_batch_size: int,
+ max_frames_per_batch: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ ) -> EncoderCudaGraphCaptureInputs:
+ pixel_h, pixel_w = self._mcpmv_slice_pixel_size()
+ max_patches = self._mcpmv_max_patches_per_slice()
+ max_num_slices = self._mcpmv_max_slices_cap(
+ token_budget,
+ max_batch_size,
+ max_frames_per_batch,
+ )
+ flat_dim = 3 * pixel_h * pixel_w
+ flat_pixel_buffer = torch.zeros(
+ (max_num_slices, flat_dim), device=device, dtype=dtype
+ )
+ patch_hw = pixel_h // int(self.vpm.embeddings.patch_size)
+ dummy_tgt_sizes = torch.full(
+ (max_num_slices, 2), patch_hw, dtype=torch.long, device=device
+ )
+ dummy_patch_mask = torch.ones(
+ (max_num_slices, max_patches), dtype=torch.bool, device=device
+ )
+ buffers: dict[str, torch.Tensor] = {
+ _MINICPMV_CUDAGRAPH_BUF_KEY_TGT_SIZES: dummy_tgt_sizes,
+ _MINICPMV_CUDAGRAPH_BUF_KEY_PATCH_MASK: dummy_patch_mask,
+ }
+ if self.version == (4, 5):
+ buffers[_MINICPMV_CUDAGRAPH_BUF_KEY_TEMPORAL_IDS] = torch.full(
+ (max_num_slices,), -1, dtype=torch.long, device=device
+ )
+ mm_kwargs: dict[str, Any] = {
+ _MINICPMV_CUDAGRAPH_FLAT_KEY_IMAGE: flat_pixel_buffer,
+ }
+ return EncoderCudaGraphCaptureInputs(mm_kwargs=mm_kwargs, buffers=buffers)
+
+ def prepare_encoder_cudagraph_replay_buffers(
+ self,
+ mm_kwargs: dict[str, Any],
+ max_batch_size: int,
+ max_frames_per_batch: int,
+ ) -> EncoderCudaGraphReplayBuffers:
+ _ = max_frames_per_batch
+ _ = max_batch_size
+ video = self.get_input_modality(mm_kwargs) == "video"
+ max_patches = self._mcpmv_max_patches_per_slice()
+ device = next(self.vpm.parameters()).device
+ # After select_encoder_cudagraph_items, tgt_sizes contains exactly one
+ # row per selected slice, so its length equals the total slice count.
+ tgt_sizes = _mcpmv_tgt_sizes_tensor(mm_kwargs, video=video).to(
+ device=device, dtype=torch.long
+ )
+
+ patches_per_slice = tgt_sizes.prod(-1).clamp(max=max_patches)
+ col_idx = torch.arange(max_patches, device=device)
+ patch_attention_mask = col_idx.unsqueeze(0) < patches_per_slice.unsqueeze(1)
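+ # E.g. max_patches = 1024 and patches_per_slice = [600, 1024] yields a
+ # mask whose first row is True only for columns 0..599.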
+
+ buffers: dict[str, torch.Tensor] = {
+ _MINICPMV_CUDAGRAPH_BUF_KEY_TGT_SIZES: tgt_sizes.clone(),
+ _MINICPMV_CUDAGRAPH_BUF_KEY_PATCH_MASK: patch_attention_mask,
+ }
+ if self.version == (4, 5):
+ temporal_ids = mm_kwargs.get("temporal_ids")
+ if temporal_ids is not None:
+ # temporal_ids is list[list[int]] (per-image, per-slice).
+ flat_ids = torch.tensor(
+ flatten_2d_lists(temporal_ids), dtype=torch.long, device=device
+ )
+ else:
+ flat_ids = torch.full(
+ (len(tgt_sizes),), -1, dtype=torch.long, device=device
+ )
+ buffers[_MINICPMV_CUDAGRAPH_BUF_KEY_TEMPORAL_IDS] = flat_ids
+ return EncoderCudaGraphReplayBuffers(buffers=buffers)
+
+ def encoder_cudagraph_forward(
+ self,
+ mm_kwargs: dict[str, Any],
+ buffers: dict[str, torch.Tensor],
+ ) -> torch.Tensor:
+ modality = self.get_input_modality(mm_kwargs)
+ flat_key = (
+ _MINICPMV_CUDAGRAPH_FLAT_KEY_VIDEO
+ if modality == "video"
+ else _MINICPMV_CUDAGRAPH_FLAT_KEY_IMAGE
+ )
+ flat_pixel_buffer = mm_kwargs[flat_key]
+ pixel_h, pixel_w = self._mcpmv_slice_pixel_size()
+ max_num_slices, flat_dim = flat_pixel_buffer.shape
+ assert flat_dim == 3 * pixel_h * pixel_w
+ all_pixel_values = flat_pixel_buffer.view(max_num_slices, 3, pixel_h, pixel_w)
+
+ tgt_sizes = buffers[_MINICPMV_CUDAGRAPH_BUF_KEY_TGT_SIZES]
+ patch_attention_mask = buffers[
+ _MINICPMV_CUDAGRAPH_BUF_KEY_PATCH_MASK
+ ].unsqueeze(1)
+
+ # v2.5 vpm does not accept tgt_sizes.
+ vpm_tgt_sizes = None if self.version == (2, 5) else tgt_sizes
+ vision_embedding = self.vpm(
+ all_pixel_values,
+ patch_attention_mask=patch_attention_mask,
+ tgt_sizes=vpm_tgt_sizes,
+ )
+
+ if self.version == (4, 5):
+ temporal_ids = buffers[_MINICPMV_CUDAGRAPH_BUF_KEY_TEMPORAL_IDS]
+ resampler_out = self.resampler(vision_embedding, tgt_sizes, temporal_ids)
+ else:
+ resampler_out = self.resampler(vision_embedding, tgt_sizes)
+
+ query_num = int(self.config.query_num)
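+ # Output layout: slice i occupies rows [i * query_num, (i + 1) * query_num).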
+ return resampler_out.reshape(max_num_slices * query_num, int(self.embed_dim))
+
+ def encoder_eager_forward(self, mm_kwargs: dict[str, Any]) -> torch.Tensor:
+ """Eager encoder path; returns ``(total_tokens, embed_dim)`` like
+ ``encoder_cudagraph_forward``.
+
+ Called by the manager only for images/videos that exceed all token
+ budgets (single-item batches), so ``segments`` always has exactly one
+ element in practice. Version-specific logic (e.g. temporal embeddings
+ for v4.5) is handled transparently by the polymorphic dispatch inside
+ ``get_vision_hidden_states``.
+ """
+ mm_kwargs_no_flat = {
+ k: v
+ for k, v in mm_kwargs.items()
+ if k
+ not in (
+ _MINICPMV_CUDAGRAPH_FLAT_KEY_IMAGE,
+ _MINICPMV_CUDAGRAPH_FLAT_KEY_VIDEO,
+ )
+ }
+ modalities = self._parse_and_validate_multimodal_inputs(**mm_kwargs_no_flat)
+ segments: list[torch.Tensor] = []
+ embed_dim = self.embed_dim
+ for modality in modalities:
+ if modality == "images":
+ image_input = modalities["images"]
+ image_embeddings = self.get_vision_hidden_states(image_input)
+ segments.append(image_embeddings.reshape(-1, embed_dim))
+ elif modality == "videos":
+ video_input = modalities["videos"]
+ video_embeddings = self.get_vision_hidden_states(video_input)
+ segments.append(video_embeddings.reshape(-1, embed_dim))
+ if not segments:
+ raise RuntimeError(
+ "MiniCPM-V encoder cudagraph eager path expects pixel_values "
+ "or video_pixel_values"
+ )
+ return torch.cat(segments, dim=0)
+
+
class MiniCPMV2_0(MiniCPMVBaseModel):
supports_encoder_tp_data = False
@@ -1310,7 +1739,7 @@ def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tens
return torch.vstack(res)
-class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
+class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA, _MiniCPMVEncoderCudaGraphMixin):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -1401,7 +1830,7 @@ def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tens
return self.resampler(vision_embedding, tgt_sizes)
-class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
+class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA, _MiniCPMVEncoderCudaGraphMixin):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -1499,7 +1928,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
return loaded
-class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA):
+class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA, _MiniCPMVEncoderCudaGraphMixin):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -1596,7 +2025,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
return loaded
-class MiniCPMV4_5(MiniCPMVBaseModel, SupportsLoRA):
+class MiniCPMV4_5(MiniCPMVBaseModel, SupportsLoRA, _MiniCPMVEncoderCudaGraphMixin):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -1668,9 +2097,6 @@ def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tens
dtype = pixel_values[0].dtype
all_pixel_values = torch.zeros((B, 3, P, L), dtype=dtype, device=device)
- all_temporal_ids = (
- None if temporal_ids is None else flatten_2d_lists(temporal_ids)
- )
for i, pixel_values_item in enumerate(pixel_values):
L_item = pixel_values_item.shape[-1]
all_pixel_values[i, ..., :L_item] = pixel_values_item
@@ -1689,7 +2115,7 @@ def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tens
tgt_sizes=tgt_sizes,
)
- return self.resampler(vision_embedding, tgt_sizes, all_temporal_ids)
+ return self.resampler(vision_embedding, tgt_sizes, temporal_ids)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self, skip_prefixes=["apm.", "audio", "tts"])