diff --git a/python/sglang/benchmark/datasets/image.py b/python/sglang/benchmark/datasets/image.py
index e84c6a622a9c..5efeb98b7a54 100644
--- a/python/sglang/benchmark/datasets/image.py
+++ b/python/sglang/benchmark/datasets/image.py
@@ -260,7 +260,7 @@ def _gen_random_image_data_uri(
# Generate text prompt
text_prompt = gen_mm_prompt(
- processor.tokenizer,
+ processor.tokenizer if hasattr(processor, "tokenizer") else processor,
processor.image_token_id if hasattr(processor, "image_token_id") else None,
int(input_lens[i]),
)
diff --git a/python/sglang/srt/configs/__init__.py b/python/sglang/srt/configs/__init__.py
index 35e3193ebfac..0c1fc170d18b 100644
--- a/python/sglang/srt/configs/__init__.py
+++ b/python/sglang/srt/configs/__init__.py
@@ -20,6 +20,7 @@
from sglang.srt.configs.lfm2_moe import Lfm2MoeConfig
from sglang.srt.configs.lfm2_vl import Lfm2VlConfig
from sglang.srt.configs.longcat_flash import LongcatFlashConfig
+from sglang.srt.configs.minicpmv4_6 import MiniCPMV4_6Config, MiniCPMV4_6VisionConfig
from sglang.srt.configs.nano_nemotron_vl import (
NemotronH_Nano_Omni_Reasoning_V3_Config,
NemotronH_Nano_VL_V2_Config,
@@ -64,6 +65,8 @@
"Lfm2Config",
"Lfm2MoeConfig",
"Lfm2VlConfig",
+ "MiniCPMV4_6Config",
+ "MiniCPMV4_6VisionConfig",
"NemotronHConfig",
"NemotronH_Nano_VL_V2_Config",
"NemotronH_Nano_Omni_Reasoning_V3_Config",
diff --git a/python/sglang/srt/configs/minicpmv4_6.py b/python/sglang/srt/configs/minicpmv4_6.py
new file mode 100644
index 000000000000..472224a892ba
--- /dev/null
+++ b/python/sglang/srt/configs/minicpmv4_6.py
@@ -0,0 +1,159 @@
+# Copyright 2026 The SGLang team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+"""Sglang-side ``PretrainedConfig`` classes for MiniCPM-V 4.6.
+
+Mirrors HF ref ``transformers/models/minicpmv4_6/configuration_minicpmv4_6.py``
+so we can register the configs ourselves until transformers main ships a
+native ``MiniCPMV4_6Config`` (expected in 5.7+).
+"""
+
+from typing import Any, Dict, Optional, Union
+
+from transformers import AutoConfig, PretrainedConfig
+from transformers.models.auto import CONFIG_MAPPING
+
+from sglang.srt.configs.qwen3_5 import Qwen3_5TextConfig
+
+
+class MiniCPMV4_6VisionConfig(PretrainedConfig):
+ model_type = "minicpmv4_6_vision"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ hidden_size: int = 1152,
+ intermediate_size: int = 4304,
+ num_hidden_layers: int = 27,
+ num_attention_heads: int = 16,
+ num_channels: int = 3,
+ image_size: int = 980,
+ patch_size: int = 14,
+ hidden_act: str = "gelu_pytorch_tanh",
+ layer_norm_eps: float = 1e-6,
+ attention_dropout: float = 0.0,
+ insert_layer_id: int = 6,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_channels = num_channels
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.hidden_act = hidden_act
+ self.layer_norm_eps = layer_norm_eps
+ self.attention_dropout = attention_dropout
+ self.insert_layer_id = insert_layer_id
+
+
+def _resolve_text_config_class(model_type: Optional[str]) -> type:
+ """``model_type`` -> registered config class. sglang's ``Qwen3_5TextConfig``
+ wins over the stock entry when both exist (it carries ``layers_block_type``
+ etc. that the model code reads); ``AutoConfig.register`` doesn't replace
+ existing entries so we have to short-circuit here. Note that
+ ``CONFIG_MAPPING.get`` returns ``None`` even on hit — go through
+ ``__getitem__`` to trigger the lazy class import.
+ """
+ if model_type == Qwen3_5TextConfig.model_type:
+ return Qwen3_5TextConfig
+ if model_type and model_type in CONFIG_MAPPING:
+ return CONFIG_MAPPING[model_type]
+ raise KeyError(f"Unknown text_config model_type: {model_type!r}")
+
+
+def _build_text_config(
+ text_config: Union[None, dict, PretrainedConfig],
+) -> PretrainedConfig:
+ """Coerce ``text_config`` into the right registered backbone class.
+
+ ``AutoConfig.from_pretrained`` resolves the ``"text_config"`` entry of
+ ``sub_configs`` and hands us a pre-built ``PretrainedConfig``; manual
+ construction in tests / examples passes a dict or ``None``.
+ """
+ if text_config is None:
+ return _resolve_text_config_class(Qwen3_5TextConfig.model_type)()
+ if isinstance(text_config, PretrainedConfig):
+ cls = _resolve_text_config_class(getattr(text_config, "model_type", None))
+ if isinstance(text_config, cls):
+ return text_config
+ return cls(**text_config.to_dict())
+ if isinstance(text_config, dict):
+ cfg = dict(text_config)
+ cls = _resolve_text_config_class(cfg.pop("model_type", None))
+ return cls(**cfg)
+ raise TypeError(f"Unsupported text_config type: {type(text_config)}")
+
+
+class MiniCPMV4_6Config(PretrainedConfig):
+ model_type = "minicpmv4_6"
+ # No type annotation: transformers 5+ wraps PretrainedConfig subclasses
+ # with @dataclass(kw_only=True), and an annotated mutable default would be
+ # rejected as a dataclass field. Matches qwen3_5/qwen3_vl/qwen3_omni.
+ sub_configs = {
+ "vision_config": MiniCPMV4_6VisionConfig,
+ "text_config": AutoConfig,
+ }
+
+ def __init__(
+ self,
+ text_config: Optional[Union[Dict[str, Any], PretrainedConfig]] = None,
+ vision_config: Optional[Union[Dict[str, Any], PretrainedConfig]] = None,
+ insert_layer_id: int = 6,
+ image_size: int = 448,
+ drop_vision_last_layer: bool = False,
+ image_token_id: Optional[int] = None,
+ video_token_id: Optional[int] = None,
+ tie_word_embeddings: bool = False,
+ downsample_mode: str = "16x",
+ merge_kernel_size=(2, 2),
+ merger_times: int = 1,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+ if isinstance(vision_config, dict):
+ vc = dict(vision_config)
+ vc.pop("model_type", None)
+ self.vision_config = MiniCPMV4_6VisionConfig(**vc)
+ elif vision_config is None:
+ self.vision_config = MiniCPMV4_6VisionConfig()
+ else:
+ self.vision_config = vision_config
+
+ # Mirror the ref ``__post_init__``: keep ``insert_layer_id`` in sync on
+ # both the top-level and the vision sub-config.
+ self.vision_config.insert_layer_id = insert_layer_id
+ self.patch_size = self.vision_config.patch_size
+
+ self.text_config = _build_text_config(text_config)
+
+ self.insert_layer_id = insert_layer_id
+ self.image_size = image_size
+ self.drop_vision_last_layer = drop_vision_last_layer
+ self.image_token_id = image_token_id
+ self.video_token_id = video_token_id
+ self.downsample_mode = downsample_mode
+ self.merge_kernel_size = tuple(merge_kernel_size)
+ self.merger_times = merger_times
+
+ # ``MiniCPMBaseModel.__init__`` reads ``self.config.hidden_size`` (written
+ # against flat 2.6/4.0/4.5 configs) and ``LogitsProcessor.__init__`` reads
+ # ``config.vocab_size`` — proxy both to ``text_config`` so we don't have to
+ # fork the base class / logits processor.
+ @property
+ def hidden_size(self) -> int:
+ return self.text_config.hidden_size
+
+ @property
+ def vocab_size(self) -> int:
+ return self.text_config.vocab_size
+
+
+__all__ = ["MiniCPMV4_6Config", "MiniCPMV4_6VisionConfig"]
diff --git a/python/sglang/srt/models/minicpmv.py b/python/sglang/srt/models/minicpmv.py
index 588c356a473c..f1121a91a21f 100644
--- a/python/sglang/srt/models/minicpmv.py
+++ b/python/sglang/srt/models/minicpmv.py
@@ -61,8 +61,13 @@
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.idefics2 import Idefics2VisionTransformer
from sglang.srt.models.llama import LlamaConfig, LlamaForCausalLM
+from sglang.srt.models.minicpmv_vit import (
+ MiniCPMV_Merger,
+ MiniCPMV_VisionTransformer,
+)
from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
from sglang.srt.models.qwen3 import Qwen3Config, Qwen3ForCausalLM
+from sglang.srt.models.qwen3_5 import Qwen3_5ForCausalLM
from sglang.srt.utils import add_prefix, flatten_nested_list
RawImageType = Union[Image.Image, torch.Tensor]
@@ -576,6 +581,10 @@ def forward(
def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
+ # 4.6 ships its own ``model_type`` instead of a numeric ``version``.
+ if getattr(config, "model_type", None) == "minicpmv4_6":
+ return 4, 6
+
version_float = getattr(config, "version", None)
# The old configs do not include version number
@@ -1342,7 +1351,277 @@ def eval(self):
return self
-_SUPPORT_VERSION = {(2, 6): MiniCPMV2_6, (4, 0): MiniCPMV4_0, (4, 5): MiniCPMV4_5}
+class MiniCPMV4_6(MiniCPMBaseModel):
+ """MiniCPM-V 4.6.
+
+ Differences vs 4.5:
+ * mid-ViT compression (``MiniCPMV_VisionTransformer`` fires a 2x2 window
+ attention + 2x2 fold at ``config.insert_layer_id``);
+ * post-encoder connector is a pure MLP chain (``MiniCPMV_Merger``),
+ not a Perceiver resampler;
+ * LLM backbone is Qwen3.5;
+ * ``config.downsample_mode`` toggles ``"16x"`` (mid-ViT + post merger)
+ vs ``"4x"`` (skip mid-ViT, keep 4x more visual tokens).
+ """
+
+ packed_modules_mapping = {
+ "qkv_proj": [
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ ],
+ "gate_up_proj": [
+ "gate_proj",
+ "up_proj",
+ ],
+ }
+ supported_lora_modules = [
+ # vision encoder + mid-ViT merger
+ "fc1",
+ "fc2",
+ "out_proj",
+ "linear_1",
+ "linear_2",
+ # language model
+ "qkv_proj",
+ "o_proj",
+ "gate_up_proj",
+ "down_proj",
+ ]
+
+ bitsandbytes_stacked_params_mapping = {
+ "q_proj": ("qkv_proj", 0),
+ "k_proj": ("qkv_proj", 1),
+ "v_proj": ("qkv_proj", 2),
+ "gate_proj": ("gate_up_proj", 0),
+ "up_proj": ("gate_up_proj", 1),
+ }
+
+ embedding_modules = {}
+ embedding_padding_modules = []
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ):
+ super().__init__(config=config, quant_config=quant_config, prefix=prefix)
+ assert self.version == (4, 6)
+ # ``Qwen3_5ForCausalLM`` returns plain hidden states (body only, no LM
+ # head, no LogitsProcessor). Add them here so the downstream sampler
+ # sees a ``LogitsProcessorOutput``. With ``tie_word_embeddings=True``
+ # (4.6 default) the head shares weights with the embedding.
+ text_config = config.text_config
+ if getattr(text_config, "tie_word_embeddings", False):
+ self.lm_head = self.llm.embed_tokens
+ else:
+ from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+
+ self.lm_head = ParallelLMHead(
+ text_config.vocab_size,
+ text_config.hidden_size,
+ quant_config=quant_config,
+ prefix=add_prefix("lm_head", prefix),
+ )
+
+ def init_llm(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> nn.Module:
+ # 4.6 nests the LLM config under ``text_config``.
+ return Qwen3_5ForCausalLM(
+ config=config.text_config, quant_config=quant_config, prefix=prefix
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ forward_batch: ForwardBatch,
+ **kwargs: Any,
+ ) -> torch.Tensor:
+ # Apply our lm_head + LogitsProcessor on top of the base routine; the
+ # 4.6 LLM body (``Qwen3_5ForCausalLM``) returns plain hidden states,
+ # unlike the ``Qwen3ForCausalLM`` 4.5 used.
+ hidden_states = super().forward(
+ input_ids=input_ids,
+ positions=positions,
+ forward_batch=forward_batch,
+ **kwargs,
+ )
+ return self.logits_processor(
+ input_ids, hidden_states, self.lm_head, forward_batch
+ )
+
+ def init_vision_module(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig],
+ prefix: str = "",
+ ) -> nn.Module:
+ model = MiniCPMV_VisionTransformer(
+ config=config.vision_config, quant_config=quant_config, prefix=prefix
+ )
+ if getattr(self.config, "drop_vision_last_layer", False):
+ # The mid-ViT merger sits on the transformer (not encoder.layers),
+ # so popping the last encoder layer leaves it untouched — same
+ # behaviour as 4.5.
+ model.encoder.layers = model.encoder.layers[:-1]
+
+ setattr(model, "embed_dim", model.embeddings.embed_dim)
+ setattr(model, "patch_size", model.embeddings.patch_size)
+ return model
+
+ def init_resampler(
+ self,
+ embed_dim: int,
+ vision_dim: int,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> nn.Module:
+ # 4.6 replaces Resampler4_5 with a pure MLP. Method name kept so
+ # ``MiniCPMBaseModel.__init__`` doesn't need to branch.
+ with set_default_torch_dtype(torch.float16):
+ merger = MiniCPMV_Merger(
+ config=self.config,
+ quant_config=quant_config,
+ prefix=prefix,
+ )
+ return merger.to(device="cuda", dtype=torch.get_default_dtype())
+
+ def get_vision_embedding(
+ self,
+ pixel_values: List[torch.Tensor],
+ patch_attn_mask: Optional[torch.Tensor] = None,
+ tgt_sizes: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ hidden, _ = self.vpm(
+ pixel_values,
+ patch_attention_mask=patch_attn_mask,
+ target_sizes=tgt_sizes,
+ )
+ return hidden
+
+ def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+ if items and items[0].format == MultimodalInputFormat.PRECOMPUTED_EMBEDDING:
+ result = torch.cat([item.feature for item in items])
+ return result.reshape(-1, result.shape[-1])
+
+ pixel_values = flatten_nested_list([item.feature for item in items])
+ tgt_sizes = torch.stack(
+ flatten_nested_list([item.tgt_size for item in items]), dim=0
+ )
+ assert len(pixel_values) == tgt_sizes.shape[0]
+
+ device = self.vpm.embeddings.position_embedding.weight.device
+ dtype = self.vpm.embeddings.position_embedding.weight.dtype
+ all_pixel_values_lst = [
+ i.flatten(end_dim=1).permute(1, 0) for i in pixel_values
+ ]
+
+ max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
+ assert isinstance(max_patches, int)
+ all_pixel_values = torch.nn.utils.rnn.pad_sequence(
+ all_pixel_values_lst, batch_first=True, padding_value=0.0
+ )
+
+ B, L, _ = all_pixel_values.shape
+ all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
+ patch_attn_mask = torch.zeros(
+ (B, 1, max_patches), dtype=torch.bool, device=device
+ )
+
+ tgt_sizes_tensor = tgt_sizes.clone().to(device=patch_attn_mask.device)
+ mask_shapes = tgt_sizes_tensor[:, 0] * tgt_sizes_tensor[:, 1]
+ patch_attn_mask[:, 0, :] = torch.arange(
+ patch_attn_mask.size(2), device=patch_attn_mask.device
+ ).unsqueeze(0) < mask_shapes.unsqueeze(1)
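+        # e.g. tgt_sizes [[28, 40], [28, 28]] (illustrative): max_patches=1120;
+        # row 0 is fully valid, row 1 only for its first 784 positions.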
+
+ use_vit_merger = getattr(self.config, "downsample_mode", "16x") != "4x"
+
+ vision_embedding, tgt_sizes_out = self.vpm(
+ all_pixel_values.type(dtype),
+ patch_attention_mask=patch_attn_mask,
+ target_sizes=tgt_sizes,
+ use_vit_merger=use_vit_merger,
+ )
+ return self.resampler(vision_embedding, tgt_sizes_out)
+
+ # Video frames take the same vision path as image patches; the mm
+ # processor emits one ``MultimodalDataItem`` per patch regardless of
+ # source. sglang's dispatcher routes by ``get_{modality}_feature``.
+ get_video_feature = get_image_feature
+
+ def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
+ im_start_id: int = image_inputs.im_start_id
+ im_end_id: int = image_inputs.im_end_id
+ slice_start_id: int = image_inputs.slice_start_id
+ slice_end_id: int = image_inputs.slice_end_id
+
+ media_token_pairs = [(im_start_id, im_end_id), (slice_start_id, slice_end_id)]
+ pattern = MultiModalityDataPaddingPatternTokenPairs(
+ media_token_pairs, data_start_token_ids=[im_start_id]
+ )
+ return pattern.pad_input_tokens(input_ids, image_inputs)
+
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+ """Remap 4.6 prefixes (``model.{vision_tower,merger,language_model}``)
+ to sglang's (``vpm`` / ``resampler`` / ``llm``) and delegate the LLM
+ portion to ``Qwen3_5ForCausalLM.load_weights`` — the Qwen3.5 hybrid
+ backbone has its own stacked-param logic (``in_proj_a/b -> in_proj_ba``,
+ ``in_proj_qkv/z -> in_proj_qkvz``) the legacy loader doesn't know.
+ Vision-side still needs QKV stacking + ``out_proj -> proj`` rename.
+ """
+
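+        # Illustrative key remaps (the checkpoint key names shown are
+        # assumptions about the HF layout, not verified against a release):
+        #   model.language_model.layers.0.mlp.up_proj.weight
+        #       -> llm loader sees: layers.0.mlp.up_proj.weight
+        #   model.vision_tower.encoder.layers.0.self_attn.q_proj.weight
+        #       -> vpm.encoder.layers.0.self_attn.q_proj.weight
+        #          (then stacked into ...self_attn.qkv_proj, shard "q")
+        #   model.merger.mlp.0.linear_1.weight -> resampler.mlp.0.linear_1.weight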
+ llm_weights: List[Tuple[str, torch.Tensor]] = []
+ vision_weights: List[Tuple[str, torch.Tensor]] = []
+ for name, w in weights:
+ if name.startswith("model.language_model."):
+ llm_weights.append((name[len("model.language_model.") :], w))
+ continue
+ if name.startswith("model.vision_tower."):
+ name = "vpm." + name[len("model.vision_tower.") :]
+ elif name.startswith("model.merger."):
+ name = "resampler." + name[len("model.merger.") :]
+ vision_weights.append((name, w))
+
+ self.llm.load_weights(iter(llm_weights))
+
+ stacked_params_mapping = [
+ ("self_attn.qkv_proj", "self_attn.q_proj", "q"),
+ ("self_attn.qkv_proj", "self_attn.k_proj", "k"),
+ ("self_attn.qkv_proj", "self_attn.v_proj", "v"),
+ ]
+ params_dict = dict(self.named_parameters())
+ for name, loaded_weight in vision_weights:
+ name = name.replace("self_attn.out_proj", "self_attn.proj")
+
+ for param_name, weight_name, shard_id in stacked_params_mapping:
+ if weight_name not in name:
+ continue
+ target = name.replace(weight_name, param_name)
+ if target not in params_dict:
+ continue
+ param = params_dict[target]
+ param.weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ if name not in params_dict:
+ continue
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
+ weight_loader(param, loaded_weight)
+
+
+_SUPPORT_VERSION = {
+ (2, 6): MiniCPMV2_6,
+ (4, 0): MiniCPMV4_0,
+ (4, 5): MiniCPMV4_5,
+ (4, 6): MiniCPMV4_6,
+}
class MiniCPMV:
@@ -1369,7 +1648,12 @@ def __init__(
) -> None:
super().__init__()
- if not hasattr(config, "version"):
+ # 4.6 carries ``model_type == "minicpmv4_6"`` instead of a numeric
+ # ``config.version``; older versionless configs keep the legacy
+ # ``(2, 6)`` default.
+ if getattr(config, "model_type", None) == "minicpmv4_6":
+ version = (4, 6)
+ elif not hasattr(config, "version"):
version = (2, 6)
else:
version = str(config.version).split(".")
@@ -1404,6 +1688,13 @@ def __call__(self, *args, **kwargs):
return self.minicpmv(*args, **kwargs)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+ # Defer to the version-specific subclass loader if it overrides the
+ # base (4.6 does — it needs prefix remap + Qwen3.5 LLM delegation).
+ sub_loader = getattr(type(self.minicpmv), "load_weights", None)
+ base_loader = getattr(MiniCPMBaseModel, "load_weights", None)
+ if sub_loader is not None and sub_loader is not base_loader:
+ return self.minicpmv.load_weights(weights)
+
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
@@ -1455,4 +1746,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader(param, loaded_weight)
-EntryClass = MiniCPMV
+# Real subclass (not an `=` alias) so the model registry — which keys by
+# ``__name__`` — resolves the canonical 4.6 architecture name through
+# ``MiniCPMV``'s version-dispatch factory.
+class MiniCPMV4_6ForConditionalGeneration(MiniCPMV):
+ pass
+
+
+EntryClass = [MiniCPMV, MiniCPMV4_6ForConditionalGeneration]
diff --git a/python/sglang/srt/models/minicpmv_vit.py b/python/sglang/srt/models/minicpmv_vit.py
new file mode 100644
index 000000000000..915dd434c0e2
--- /dev/null
+++ b/python/sglang/srt/models/minicpmv_vit.py
@@ -0,0 +1,526 @@
+# Copyright 2026 The SGLang team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+"""Vision Transformer for MiniCPM-V 4.6.
+
+Compared to 4.5 (Idefics2VisionTransformer end-to-end + Perceiver-style
+Resampler4_5), 4.6 compresses visual tokens *twice*:
+
+ patchify -> [layer 0 .. insert_layer_id] full-res tokens
+ -> ViTWindowAttentionMerger 2x2 window attn + 2x2 fold
+ -> [layer insert_layer_id+1 .. N-1] compressed tokens
+ -> post_layernorm
+ -> Merger (merger_times x DownsampleMLP, project to LLM dim)
+
+With defaults (insert_layer_id=6, merger_times=1) the combined compression
+is 16x. ``downsample_mode="4x"`` skips the mid-ViT merger.
+
+Class structure mirrors the HF ref one-to-one to make weight loading and
+upstream tracking easy.
+"""
+
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers import PretrainedConfig
+
+from sglang.srt.layers.activation import get_act_fn
+from sglang.srt.layers.attention.vision import VisionAttention
+from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.models.idefics2 import (
+ Idefics2Encoder,
+ Idefics2EncoderLayer,
+ Idefics2VisionEmbeddings,
+)
+from sglang.srt.utils import add_prefix, is_npu
+
+
+class MiniCPMV_ViTWindowAttentionMerger(nn.Module):
+ """Mid-ViT 2x2 window attention + 2x2 fold.
+
+ Stage 1: reorder tokens so each 2x2 spatial window becomes 4 contiguous
+ tokens; run packed self-attention with one window per cu_seqlens segment;
+ un-reorder; add residual. (No length reduction yet.)
+
+ Stage 2: fold each 2x2 window into a single token by concatenating the
+ four hidden vectors along channel; pass through ``hidden*4 ->
+ intermediate*4 -> hidden`` MLP; add the mean of the four window vectors
+ as residual. ``target_sizes`` halves on each axis; ``cu_seqlens`` /
+ ``max_seqlens`` are rebuilt for the compressed grid.
+ """
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.window_kernel_size = (2, 2)
+ self.embed_dim = config.hidden_size
+
+ # The "FFN" here is the linear_1/linear_2 pair applied after the 2x2
+ # fold below (it operates on hidden*4 -> intermediate*4 -> hidden).
+ # ``flatten_batch=True``: input is one packed sequence
+ # ``(1, sum_windows * window_area, D)`` with cu_seqlens demarcating
+ # per-window segments. The outer encoder layers use ``False`` because
+ # there each batch row is one image padded to max_patches.
+ self.self_attn = VisionAttention(
+ embed_dim=config.hidden_size,
+ num_heads=config.num_attention_heads,
+ projection_size=config.hidden_size,
+ use_qkv_parallel=True,
+ quant_config=quant_config,
+ dropout=config.attention_dropout,
+ softmax_in_single_precision=True,
+ flatten_batch=True,
+ prefix=add_prefix("self_attn", prefix),
+ )
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+ window_area = self.window_kernel_size[0] * self.window_kernel_size[1]
+ hidden_4x = self.embed_dim * window_area
+ inter_4x = config.intermediate_size * window_area
+
+ self.pre_norm = nn.LayerNorm(hidden_4x, eps=config.layer_norm_eps)
+ self.linear_1 = ColumnParallelLinear(
+ hidden_4x,
+ inter_4x,
+ bias=True,
+ quant_config=quant_config,
+ prefix=add_prefix("linear_1", prefix),
+ )
+ self.act = get_act_fn("gelu_pytorch_tanh")
+ self.linear_2 = RowParallelLinear(
+ inter_4x,
+ self.embed_dim,
+ bias=True,
+ quant_config=quant_config,
+ prefix=add_prefix("linear_2", prefix),
+ )
+
+ def get_window_index(
+ self, target_sizes: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, int]:
+ """Return ``(permutation, per-window cu_seqlens, max_seqlens=4)``.
+
+ Kept on CPU because mixing device-bound offsets with CPU arange trips
+        strict device checks in PyTorch 2.10+.
+ """
+ window_h, window_w = self.window_kernel_size
+ max_seqlens = window_h * window_w # 4
+
+ window_index_list: List[torch.Tensor] = []
+ cu_seqlens: List[int] = [0]
+ token_offset = 0
+
+ for height, width in target_sizes:
+ height, width = int(height), int(width)
+ if height % window_h != 0 or width % window_w != 0:
+ raise ValueError(
+ f"height={height}, width={width} must be divisible by "
+ f"window size ({window_h}, {window_w})"
+ )
+ index = torch.arange(height * width).reshape(height, width)
+ num_windows_h = height // window_h
+ num_windows_w = width // window_w
+ num_windows = num_windows_h * num_windows_w
+
+ index = index.reshape(num_windows_h, window_h, num_windows_w, window_w)
+ index = index.permute(0, 2, 1, 3).reshape(num_windows, window_h * window_w)
+
+ window_index_list.append(index.reshape(-1) + token_offset)
+
+ cu_this = (
+ torch.arange(1, num_windows + 1) * (window_h * window_w)
+ + cu_seqlens[-1]
+ )
+ cu_seqlens.extend(cu_this.tolist())
+
+ token_offset += height * width
+
+ window_index = torch.cat(window_index_list)
+ cu_seqlens_t = torch.tensor(cu_seqlens, dtype=torch.int32)
+ return window_index, cu_seqlens_t, max_seqlens
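+
+    # Worked example (illustrative): ``target_sizes=[[2, 4]]`` gives the patch
+    # grid
+    #     0 1 2 3
+    #     4 5 6 7
+    # -> 2x2 windows [0, 1, 4, 5] and [2, 3, 6, 7], so
+    # permutation=[0, 1, 4, 5, 2, 3, 6, 7], cu_seqlens=[0, 4, 8], max_seqlens=4.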
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ target_sizes: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ max_seqlens: int,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
+ device = hidden_states.device
+
+ # Stage 1: 2x2 window self-attention + residual.
+ residual = hidden_states
+ hidden_states = self.layer_norm1(hidden_states)
+
+ window_index, window_cu_seqlens, _ = self.get_window_index(target_sizes)
+ window_index = window_index.to(device)
+ window_cu_seqlens = window_cu_seqlens.to(device)
+ if is_npu():
+ window_cu_seqlens = window_cu_seqlens.to("cpu")
+
+ hidden_states = hidden_states[:, window_index, :]
+ hidden_states = self.self_attn(hidden_states, cu_seqlens=window_cu_seqlens)
+ hidden_states = hidden_states[:, torch.argsort(window_index), :]
+ hidden_states = residual + hidden_states
+
+ # Stage 2: 2x2 spatial fold + MLP + mean residual.
+ if (target_sizes % 2 != 0).any():
+ raise ValueError(
+ f"All target_sizes must be divisible by 2, got {target_sizes}"
+ )
+ new_target_sizes = target_sizes // 2
+
+ window_h, window_w = self.window_kernel_size
+ batch_size = target_sizes.shape[0]
+ all_pixel_values = []
+ for batch_idx in range(batch_size):
+ height, width = target_sizes[batch_idx]
+ patch = hidden_states[
+ 0, cu_seqlens[batch_idx] : cu_seqlens[batch_idx + 1], :
+ ].squeeze(0)
+
+ embed_dim = patch.shape[-1]
+ merged_h, merged_w = height // window_h, width // window_w
+ patch_5d = patch.view(
+ merged_h, window_h, merged_w, window_w, embed_dim
+ ).permute(0, 2, 1, 3, 4)
+ hidden_state = patch_5d.reshape(
+ merged_h * merged_w, window_h * window_w * embed_dim
+ )
+ res = patch_5d.reshape(
+ merged_h * merged_w, window_h * window_w, embed_dim
+ ).mean(dim=1)
+
+ hidden_state = self.pre_norm(hidden_state)
+ hidden_state, _ = self.linear_1(hidden_state)
+ hidden_state = self.act(hidden_state)
+ hidden_state, _ = self.linear_2(hidden_state)
+
+ all_pixel_values.append(hidden_state + res)
+
+ new_hidden_states = torch.concat(all_pixel_values, dim=0).unsqueeze(0)
+ new_cu_seqlens = F.pad(
+ torch.cumsum(
+ new_target_sizes[:, 0] * new_target_sizes[:, 1],
+ dim=0,
+ dtype=torch.int32,
+ ).to(device),
+ (1, 0),
+ )
+ if max_seqlens % 4 != 0:
+ raise ValueError(f"max_seqlens ({max_seqlens}) must be divisible by 4")
+ new_max_seqlens = max_seqlens // 4
+
+ return new_hidden_states, new_target_sizes, new_cu_seqlens, new_max_seqlens
+
+
+class MiniCPMV_DownsampleMLP(nn.Module):
+ """One round of 2x2 spatial merge + MLP, used inside ``MiniCPMV_Merger``.
+
+ Input channel dim is ``hidden_size * 4`` (already folded by the caller).
+ Output is ``hidden_size`` for an intermediate round or ``llm_embed_dim``
+ for the final round.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ llm_embed_dim: int,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ merged_hidden_size = hidden_size * 4
+
+ self.pre_norm = nn.LayerNorm(merged_hidden_size, eps=1e-6)
+ self.linear_1 = ColumnParallelLinear(
+ merged_hidden_size,
+ merged_hidden_size,
+ bias=True,
+ quant_config=quant_config,
+ prefix=add_prefix("linear_1", prefix),
+ )
+ self.act = nn.GELU()
+ self.linear_2 = RowParallelLinear(
+ merged_hidden_size,
+ llm_embed_dim,
+ bias=True,
+ quant_config=quant_config,
+ prefix=add_prefix("linear_2", prefix),
+ )
+ self.in_features = merged_hidden_size
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.pre_norm(hidden_states).view(-1, self.in_features)
+ hidden_states, _ = self.linear_1(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states, _ = self.linear_2(hidden_states)
+ return hidden_states
+
+
+class MiniCPMV_Merger(nn.Module):
+ """Iterative 2x2 fold + MLP chain between ViT and LLM.
+
+ With ``merger_times == 1`` (the 4.6 release default) it's a single
+ DownsampleMLP projecting straight into ``text_config.hidden_size``. Each
+ additional round halves the grid and keeps the channel width at
+ ``vision_config.hidden_size`` until the last round.
+ """
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+
+ self.merge_kernel_size = tuple(config.merge_kernel_size)
+ self.merger_times = config.merger_times
+ hidden_size = config.vision_config.hidden_size
+ llm_embed_dim = config.text_config.hidden_size
+
+ self.mlp = nn.ModuleList(
+ [
+ MiniCPMV_DownsampleMLP(
+ hidden_size,
+ llm_embed_dim if i == self.merger_times - 1 else hidden_size,
+ quant_config=quant_config,
+ prefix=add_prefix(f"mlp.{i}", prefix),
+ )
+ for i in range(self.merger_times)
+ ]
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ target_sizes: torch.Tensor,
+ ) -> torch.Tensor:
+ merge_h, merge_w = self.merge_kernel_size
+
+ start = 0
+ processed = []
+ for batch_idx in range(len(target_sizes)):
+ height, width = target_sizes[batch_idx]
+ num_patches = int(height * width)
+
+ embed_dim = hidden_states.shape[-1]
+ merged_h, merged_w = int(height) // merge_h, int(width) // merge_w
+ hidden_state = (
+ hidden_states[0, start : start + num_patches, :]
+ .view(merged_h, merge_h, merged_w, merge_w, embed_dim)
+ .permute(0, 2, 1, 3, 4)
+ .reshape(merged_h * merged_w, merge_h * merge_w * embed_dim)
+ )
+ hidden_state = self.mlp[0](hidden_state)
+
+ height, width = int(height), int(width)
+ for i in range(1, self.merger_times):
+ if height % merge_h != 0 or width % merge_w != 0:
+ raise ValueError(
+ f"Patch grid ({height}, {width}) must be divisible by "
+ f"merge kernel size {self.merge_kernel_size} at round {i}"
+ )
+ height //= merge_h
+ width //= merge_w
+
+ inner_dim = hidden_state.shape[-1]
+ merged_h, merged_w = height // merge_h, width // merge_w
+ hidden_state = (
+ hidden_state.view(merged_h, merge_h, merged_w, merge_w, inner_dim)
+ .permute(0, 2, 1, 3, 4)
+ .reshape(merged_h * merged_w, merge_h * merge_w * inner_dim)
+ )
+ hidden_state = self.mlp[i](hidden_state)
+
+ start += num_patches
+ processed.append(hidden_state)
+
+ return torch.cat(processed, dim=0)
+
+
+class MiniCPMV_VisionEncoderLayer(Idefics2EncoderLayer):
+ """SigLip-style pre-norm encoder layer for packed NaViT input.
+
+ Inherits Idefics2's forward and submodule layout (so HF weights map
+ verbatim), then rebuilds ``self_attn`` with ``flatten_batch=True`` for
+ per-image block-diagonal attention on a single packed sequence
+    (Idefics2 uses padded ``(B, max_patches, D)``) and the SigLIP-correct
+ ``projection_size = hidden_size`` (Idefics2 sets it to ``intermediate_size``).
+ """
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__(config, quant_config=quant_config, prefix=prefix)
+ self.self_attn = VisionAttention(
+ embed_dim=config.hidden_size,
+ num_heads=config.num_attention_heads,
+ projection_size=config.hidden_size,
+ use_qkv_parallel=True,
+ quant_config=quant_config,
+ dropout=config.attention_dropout,
+ softmax_in_single_precision=True,
+ flatten_batch=True,
+ prefix=add_prefix("self_attn", prefix),
+ )
+
+
+class MiniCPMV_VisionEncoder(Idefics2Encoder):
+ """Stack of ``MiniCPMV_VisionEncoderLayer``.
+
+ ``vit_merger`` lives one level up on ``MiniCPMV_VisionTransformer`` so the
+ HF checkpoint key ``vision_tower.vit_merger.*`` lands at the matching
+ sglang path.
+ """
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__(config, quant_config=quant_config, prefix=prefix)
+ self.layers = nn.ModuleList(
+ [
+ MiniCPMV_VisionEncoderLayer(
+ config,
+ quant_config=quant_config,
+ prefix=add_prefix(f"layers.{i}", prefix),
+ )
+ for i in range(config.num_hidden_layers)
+ ]
+ )
+
+
+class MiniCPMV_VisionTransformer(nn.Module):
+ """Vision Transformer for MiniCPM-V 4.6.
+
+ Reuses sglang's SigLIP-style ``Idefics2VisionEmbeddings`` + encoder layers,
+ inserts ``MiniCPMV_ViTWindowAttentionMerger`` after layer ``insert_layer_id``,
+ and applies post-encoder LayerNorm. ``forward`` returns
+ ``(hidden_states, target_sizes)``; in ``"16x"`` mode ``target_sizes``
+ reflects the post-merger grid, which downstream code must use.
+ """
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ require_post_norm: bool = True,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ embed_dim = config.hidden_size
+ self.config = config
+
+ if not hasattr(config, "insert_layer_id"):
+ raise ValueError(
+ "MiniCPMV_VisionTransformer requires `config.insert_layer_id`"
+ )
+
+ self.insert_layer_id = config.insert_layer_id
+ self.embeddings = Idefics2VisionEmbeddings(config)
+ self.encoder = MiniCPMV_VisionEncoder(
+ config=config,
+ quant_config=quant_config,
+ prefix=add_prefix("encoder", prefix),
+ )
+ self.post_layernorm = (
+ nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+ if require_post_norm
+ else nn.Identity()
+ )
+ self.vit_merger = MiniCPMV_ViTWindowAttentionMerger(
+ config,
+ quant_config=quant_config,
+ prefix=add_prefix("vit_merger", prefix),
+ )
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.embeddings
+
+ @staticmethod
+ def compute_cu_seqlens(target_sizes: torch.Tensor) -> Tuple[torch.Tensor, int]:
+ seqlen = (target_sizes[:, 0] * target_sizes[:, 1]).to(torch.int32)
+ cu_seqlens = torch.cat(
+ [
+ torch.tensor([0], device=seqlen.device, dtype=torch.int32),
+ torch.cumsum(seqlen, dim=0, dtype=torch.int32),
+ ],
+ dim=0,
+ )
+ max_seqlens = int(seqlen.max().item())
+ return cu_seqlens, max_seqlens
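+
+    # e.g. target_sizes=[[2, 4], [2, 2]] -> cu_seqlens=[0, 8, 12], max_seqlens=8.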
+
+ @staticmethod
+ def _pad_to_pack(padded: torch.Tensor, target_sizes: torch.Tensor) -> torch.Tensor:
+ """``(B, max_patches, D) -> (1, sum_patches, D)``.
+
+ ``Idefics2VisionEmbeddings`` emits padded shape with valid tokens at
+ ``[0, h_b * w_b)`` of each batch row. Strip the padding so the rest
+ of the ViT runs in flat NaViT form.
+ """
+ seqlens = (target_sizes[:, 0] * target_sizes[:, 1]).to(torch.long)
+ if padded.shape[0] == 1:
+ return padded[:, : int(seqlens[0].item()), :]
+ parts = [padded[b, : int(seqlens[b].item()), :] for b in range(padded.shape[0])]
+ return torch.cat(parts, dim=0).unsqueeze(0)
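+
+    # e.g. padded (2, 1120, D) with target_sizes=[[28, 40], [28, 28]] packs to
+    # (1, 1120 + 784, D) = (1, 1904, D) (illustrative sizes).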
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ patch_attention_mask: Optional[torch.BoolTensor] = None,
+ target_sizes: Optional[torch.IntTensor] = None,
+ use_vit_merger: bool = True,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ if target_sizes is None:
+ raise ValueError("MiniCPMV_VisionTransformer requires `target_sizes`.")
+
+ hidden_states = self.embeddings(
+ pixel_values=pixel_values,
+ patch_attention_mask=patch_attention_mask,
+ tgt_sizes=target_sizes,
+ )
+ hidden_states = self._pad_to_pack(hidden_states, target_sizes)
+ cu_seqlens, max_seqlens = self.compute_cu_seqlens(target_sizes)
+ if is_npu():
+ cu_seqlens = cu_seqlens.to("cpu")
+
+ if use_vit_merger:
+ # Encoder loop lives here (not inside ``MiniCPMV_VisionEncoder``)
+ # so we can fire ``vit_merger`` after layer ``insert_layer_id``
+ # without coupling the encoder module to it.
+ for layer_index, layer in enumerate(self.encoder.layers):
+ hidden_states = layer(hidden_states, cu_seqlens=cu_seqlens)
+ if layer_index == self.insert_layer_id:
+ (
+ hidden_states,
+ target_sizes,
+ cu_seqlens,
+ max_seqlens,
+ ) = self.vit_merger(
+ hidden_states, target_sizes, cu_seqlens, max_seqlens
+ )
+ if is_npu():
+ cu_seqlens = cu_seqlens.to("cpu")
+ else:
+ hidden_states = self.encoder(hidden_states, cu_seqlens=cu_seqlens)
+
+ hidden_states = self.post_layernorm(hidden_states)
+ return hidden_states, target_sizes
diff --git a/python/sglang/srt/multimodal/processors/minicpmv4_6.py b/python/sglang/srt/multimodal/processors/minicpmv4_6.py
new file mode 100644
index 000000000000..25529b9b86e1
--- /dev/null
+++ b/python/sglang/srt/multimodal/processors/minicpmv4_6.py
@@ -0,0 +1,548 @@
+# Copyright 2026 The SGLang team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+"""sglang multimodal processor for MiniCPM-V 4.6.
+
+Ports per-image preprocessing + chat-template expansion sglang-side because
+no working HF ``MiniCPMV4_6Processor`` is reachable yet: transformers main
+does not ship one until 5.7+, and the released 4.6 checkpoints ship only a
+tokenizer (no remote-code processor), so ``AutoProcessor.from_pretrained``
+falls through to a bare tokenizer. Once a real processor is loadable, this
+module collapses to a thin wrapper that delegates to it.
+"""
+
+from __future__ import annotations
+
+import math
+from itertools import chain
+from typing import Any, List, Optional, Sequence, Tuple, Union
+
+import torch
+import torchvision.transforms.functional as F
+from PIL import Image
+
+from sglang.srt.managers.schedule_batch import (
+ Modality,
+ MultimodalDataItem,
+ MultimodalProcessorOutput,
+)
+from sglang.srt.models.minicpmv import MiniCPMV4_6ForConditionalGeneration
+from sglang.srt.multimodal.processors.base_processor import (
+ BaseMultimodalProcessor,
+ MultimodalSpecialTokens,
+)
+
+IMAGENET_STANDARD_MEAN = (0.5, 0.5, 0.5)
+IMAGENET_STANDARD_STD = (0.5, 0.5, 0.5)
+
+# Inner per-feature pad sentinel: prevents the next per-image
+# ``replace(image_token, ...)`` from clobbering a previous expansion's inner
+# pads. Swapped back to the real pad token once per modality after splicing.
+_PAD_PLACEHOLDER = "<|placeholder|>"
+
+
+def _ensure_divide(length: int, divisor: int) -> int:
+ return max(round(length / divisor) * divisor, divisor)
+
+
+def _to_chw_tensor(image) -> torch.Tensor:
+ """PIL / torch / numpy -> ``(C, H, W)`` float32 in ``[0, 255]``.
+
+ Image inputs from ``load_mm_data`` are PIL; video frames from sglang's
+ video decoder come back as numpy arrays.
+ """
+ if isinstance(image, torch.Tensor):
+ if image.dim() == 4:
+ image = image.squeeze(0)
+ if image.dim() != 3:
+ raise ValueError(f"expected 3-D image tensor, got {image.shape}")
+ if image.shape[0] not in (1, 3, 4):
+ image = image.permute(2, 0, 1).contiguous()
+ if image.shape[0] == 4:
+ image = image[:3]
+ if image.shape[0] == 1:
+ image = image.repeat(3, 1, 1)
+ return image.float()
+
+ if isinstance(image, Image.Image):
+ if image.mode != "RGB":
+ image = image.convert("RGB")
+ return F.pil_to_tensor(image).float()
+
+ import numpy as np
+
+ if isinstance(image, np.ndarray):
+ t = torch.from_numpy(image)
+ if t.dim() == 3 and t.shape[-1] in (1, 3, 4):
+ t = t.permute(2, 0, 1).contiguous()
+ if t.shape[0] == 4:
+ t = t[:3]
+ if t.shape[0] == 1:
+ t = t.repeat(3, 1, 1)
+ return t.float()
+
+ raise TypeError(f"Unsupported image type: {type(image)!r}")
+
+
+def _resize(image: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ return F.resize(
+ image,
+ size=[height, width],
+ interpolation=F.InterpolationMode.BICUBIC,
+ antialias=True,
+ )
+
+
+def _divide_to_patches(
+ image: torch.Tensor, patch_h: int, patch_w: int
+) -> List[torch.Tensor]:
+ _, H, W = image.shape
+ if H % patch_h != 0 or W % patch_w != 0:
+ raise ValueError(f"image ({H}, {W}) not divisible by ({patch_h}, {patch_w})")
+ rows = H // patch_h
+ cols = W // patch_w
+ patches: List[torch.Tensor] = []
+ for r in range(rows):
+ for c in range(cols):
+ patches.append(
+ image[
+ :, r * patch_h : (r + 1) * patch_h, c * patch_w : (c + 1) * patch_w
+ ]
+ )
+ return patches
+
+
+def _reshape_by_patch(image: torch.Tensor, patch_size: int) -> torch.Tensor:
+ """``(C, H, W) -> (C, P, H*W/P)`` NaViT packing."""
+ C = image.shape[0]
+ patches = torch.nn.functional.unfold(
+ image.unsqueeze(0), (patch_size, patch_size), stride=(patch_size, patch_size)
+ )
+ patches = patches.reshape(C, patch_size, patch_size, -1)
+ patches = patches.permute(0, 1, 3, 2).reshape(C, patch_size, -1)
+ return patches
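+
+# Shape walk-through (illustrative): a 3x28x28 tile with patch_size=14 unfolds
+# to (1, 3*14*14, 4) -> (3, 14, 14, 4) -> permute -> (3, 14, 4, 14) -> reshape
+# (3, 14, 56), i.e. (C, P, H*W/P) with the four patches laid out side by side
+# in row-major patch order.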
+
+
+def _flatten_patches(
+ per_item_pv: List[List[torch.Tensor]],
+ per_item_ts: List[List[List[int]]],
+) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+ """Per-item per-patch -> flat per-patch (source first, slices row-major)."""
+ flat_pv = list(chain.from_iterable(per_item_pv))
+ flat_ts = [
+ torch.tensor(ts, dtype=torch.int32) for ts in chain.from_iterable(per_item_ts)
+ ]
+ return flat_pv, flat_ts
+
+
+class MiniCPMV4_6ImageProcessor:
+ """Per-image preprocessing.
+
+ Pipeline: pick a slice grid (rows x cols, up to ``max_slice_nums``); resize
+ source and (optionally) tiles to multiples of ``patch_size * 4`` (factor 4
+ = the two successive 2x2 spatial merges: mid-ViT merger + DownsampleMLP);
+ rescale, normalize, and NaViT-pack each tile into ``(C, P, H*W/P)``.
+ """
+
+ def __init__(
+ self,
+ max_slice_nums: int = 9,
+ scale_resolution: int = 448,
+ patch_size: int = 14,
+ slice_mode: bool = True,
+ downsample_mode: str = "16x",
+ use_image_id: bool = True,
+ image_mean: Sequence[float] = IMAGENET_STANDARD_MEAN,
+ image_std: Sequence[float] = IMAGENET_STANDARD_STD,
+ rescale_factor: float = 1.0 / 255.0,
+ ) -> None:
+ self.max_slice_nums = max_slice_nums
+ self.scale_resolution = scale_resolution
+ self.patch_size = patch_size
+ self.slice_mode = slice_mode
+ self.downsample_mode = downsample_mode
+ self.use_image_id = use_image_id
+ self.image_mean = torch.tensor(image_mean, dtype=torch.float32).view(3, 1, 1)
+ self.image_std = torch.tensor(image_std, dtype=torch.float32).view(3, 1, 1)
+ self.rescale_factor = rescale_factor
+
+ def _find_best_resize(
+ self,
+ image_size: Tuple[int, int],
+ allow_upscale: bool = False,
+ ) -> Tuple[int, int]:
+ height, width = image_size
+ scale = self.scale_resolution
+ # factor 4 = two successive 2x2 spatial merges (mid-ViT + DownsampleMLP)
+ divisor = self.patch_size * 4
+ if (height * width > scale * scale) or allow_upscale:
+ aspect_ratio = width / height
+ height = int(scale / math.sqrt(aspect_ratio))
+ width = int(height * aspect_ratio)
+ best_w = _ensure_divide(width, divisor)
+ best_h = _ensure_divide(height, divisor)
+ return best_h, best_w
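+
+    # e.g. image_size=(800, 1200) with scale=448, divisor=14*4=56 (illustrative
+    # numbers): aspect 1.5 gives a 365x547 target, rounded to (392, 560) -> a
+    # 28x40 patch grid, each axis a multiple of 4 as the two 2x2 merges require.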
+
+ def _get_refine_size(
+ self,
+ image_size: Tuple[int, int],
+ grid: Tuple[int, int],
+ allow_upscale: bool = False,
+ ) -> Tuple[int, int]:
+ height, width = image_size
+ grid_y, grid_x = grid
+ refine_w = _ensure_divide(width, grid_x)
+ refine_h = _ensure_divide(height, grid_y)
+ bh, bw = self._find_best_resize(
+ (refine_h // grid_y, refine_w // grid_x),
+ allow_upscale=allow_upscale,
+ )
+ return bh * grid_y, bw * grid_x
+
+ def _get_sliced_grid(
+ self, image_size: Tuple[int, int]
+ ) -> Optional[Tuple[int, int]]:
+ original_h, original_w = image_size
+ scale = self.scale_resolution
+ log_ratio = math.log(original_w / original_h)
+ ratio = original_w * original_h / (scale * scale)
+ multiple = min(math.ceil(ratio), self.max_slice_nums)
+ if multiple <= 1:
+ return None
+
+ best_grid = (1, 1)
+ min_error = float("inf")
+ for num_slices in (multiple - 1, multiple, multiple + 1):
+ if num_slices == 1 or num_slices > self.max_slice_nums:
+ continue
+            for num_cols in range(1, num_slices + 1):
+                if num_slices % num_cols != 0:
+                    continue
+                num_rows = num_slices // num_cols
+                # Minimize the log-space gap between the grid's cols/rows
+                # aspect and the image's width/height aspect.
+                error = abs(log_ratio - math.log(num_cols / num_rows))
+                if error < min_error:
+                    # Stored as ``(rows, cols)``, matching the downstream
+                    # ``grid_y, grid_x = grid`` unpacking.
+                    best_grid = (num_rows, num_cols)
+                    min_error = error
+        return best_grid
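+
+    # e.g. image_size=(800, 1200) with scale=448 (illustrative): ratio ~ 4.78
+    # -> multiple=5, candidate num_slices {4, 5, 6}; log_ratio=log(1.5) is
+    # matched exactly by cols/rows = 3/2, so best_grid=(2, 3): 2 rows x 3 cols.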
+
+ def _normalize(self, t: torch.Tensor) -> torch.Tensor:
+ t = t * self.rescale_factor
+ return (t - self.image_mean.to(t.dtype)) / self.image_std.to(t.dtype)
+
+ def __call__(self, images: List) -> dict:
+ return self.preprocess(images)
+
+ def preprocess(self, images: List) -> dict:
+ """Returns ``{pixel_values, tgt_sizes, grids, num_patches_per_image}``.
+
+ Per image, ``pixel_values[i]`` is a list whose first entry is the
+ source patch and remaining entries are slice tiles in row-major grid
+        order. ``grids[i]`` is ``[rows, cols]`` (zeros if no slicing).
+ """
+ per_image_pv: List[List[torch.Tensor]] = []
+ per_image_ts: List[List[List[int]]] = []
+ all_grids: List[List[int]] = []
+ num_patches_per_image: List[int] = []
+
+ for image in images:
+ chw = _to_chw_tensor(image)
+ H0, W0 = chw.shape[-2], chw.shape[-1]
+ best_grid = self._get_sliced_grid((H0, W0)) if self.slice_mode else None
+
+ allow_upscale_src = best_grid is None
+ src_h, src_w = self._find_best_resize(
+ (H0, W0), allow_upscale=allow_upscale_src
+ )
+ source = _resize(chw, src_h, src_w)
+
+ patches: List[torch.Tensor] = [source]
+ patch_h = patch_w = 0
+ if best_grid is not None:
+ refine_h, refine_w = self._get_refine_size(
+ (H0, W0), best_grid, allow_upscale=True
+ )
+ refined = _resize(chw, refine_h, refine_w)
+ grid_y, grid_x = best_grid
+ patch_h = refine_h // grid_y
+ patch_w = refine_w // grid_x
+ patches.extend(_divide_to_patches(refined, patch_h, patch_w))
+
+ patches = [self._normalize(p) for p in patches]
+
+ pv = [_reshape_by_patch(patches[0], self.patch_size)]
+ ts = [[src_h // self.patch_size, src_w // self.patch_size]]
+ for p in patches[1:]:
+ pv.append(_reshape_by_patch(p, self.patch_size))
+ ts.append([patch_h // self.patch_size, patch_w // self.patch_size])
+
+ per_image_pv.append(pv)
+ per_image_ts.append(ts)
+ all_grids.append(list(best_grid) if best_grid is not None else [0, 0])
+ num_patches_per_image.append(len(pv))
+
+ return {
+ "pixel_values": per_image_pv,
+ "tgt_sizes": per_image_ts,
+ "grids": all_grids,
+ "num_patches_per_image": num_patches_per_image,
+ }
+
+
+class MiniCPMV4_6MultimodalProcessor(BaseMultimodalProcessor):
+ """4.6-only mm processor.
+
+ The legacy ``MiniCPMMultimodalProcessor`` stays for 2.6/4.0/4.5 because its
+    ``_processor.tokenizer`` shape and ``(<image>./</image>)`` placeholder
+ format don't fit 4.6.
+ """
+
+ models = [MiniCPMV4_6ForConditionalGeneration]
+ support_dynamic_frame_expansion = False
+ gpu_image_decode = False
+
+ def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+ super().__init__(hf_config, server_args, _processor, *args, **kwargs)
+
+ # ``_processor`` is either the bare tokenizer (current state — no
+ # ``MiniCPMV4_6Processor`` shipped) or a real processor whose
+ # ``.tokenizer`` exposes the same.
+ self.tokenizer = getattr(_processor, "tokenizer", _processor)
+
+ vision_cfg = getattr(hf_config, "vision_config", None)
+ patch_size = (
+ getattr(vision_cfg, "patch_size", 14) if vision_cfg is not None else 14
+ )
+ downsample_mode = getattr(hf_config, "downsample_mode", "16x")
+ # Per-image preprocessor; reused for video frames (HF ref's
+ # video slicing geometry matches image slicing exactly).
+ self.image_processor = MiniCPMV4_6ImageProcessor(
+ max_slice_nums=9,
+ scale_resolution=448,
+ patch_size=patch_size,
+ slice_mode=True,
+ downsample_mode=downsample_mode,
+ use_image_id=True,
+ )
+
+ self.image_token = "<|image_pad|>"
+ self.video_token = "<|video_pad|>"
+ self.image_token_id = getattr(hf_config, "image_token_id", None)
+ if self.image_token_id is None:
+ self.image_token_id = self._token_id(self.image_token)
+ self.video_token_id = getattr(hf_config, "video_token_id", None)
+ if self.video_token_id is None:
+ self.video_token_id = self._token_id(self.video_token)
+
+        # ``<image>``/``</image>`` wrap the expanded regions for both images
+        # and video frames; only the inner per-feature pad token differs.
+        self.image_start_token = "<image>"
+        self.image_end_token = "</image>"
+        self.slice_start_token = "<slice>"
+        self.slice_end_token = "</slice>"
+        self.image_id_start_token = "<image_id>"
+        self.image_id_end_token = "</image_id>"
+
+ self.image_start_id = self._token_id(self.image_start_token)
+ self.image_end_id = self._token_id(self.image_end_token)
+ self.slice_start_id = self._token_id(self.slice_start_token)
+ self.slice_end_id = self._token_id(self.slice_end_token)
+
+ self.pad_divisor = 16 if downsample_mode != "4x" else 4
+
+ self.mm_tokens = MultimodalSpecialTokens(
+ image_token=self.image_token,
+ image_token_id=self.image_token_id,
+ video_token=self.video_token,
+ video_token_id=self.video_token_id,
+ ).build(_processor)
+
+ def _token_id(self, token: str):
+ try:
+ ids = self.tokenizer.convert_tokens_to_ids([token])
+ if ids and ids[0] is not None:
+ return int(ids[0])
+ except Exception:
+ pass
+ return None
+
+ def _expand_frame(
+ self,
+ tgt_sizes: List[List[int]],
+ grid: List[int],
+ ) -> str:
+ """``...`` (+ optional ``...`` rows) for
+ one image or video frame; inner pads are ``_PAD_PLACEHOLDER`` (caller
+ swaps back after splicing).
+ """
+ h0, w0 = tgt_sizes[0]
+ n_src = (h0 * w0) // self.pad_divisor
+ out = self.image_start_token + _PAD_PLACEHOLDER * n_src + self.image_end_token
+
+ if len(tgt_sizes) > 1 and grid and grid[0] > 0 and grid[1] > 0:
+ grid_y, grid_x = int(grid[0]), int(grid[1])
+ h_s, w_s = tgt_sizes[1]
+ n_slice = (h_s * w_s) // self.pad_divisor
+ slice_chunk = (
+ self.slice_start_token
+ + _PAD_PLACEHOLDER * n_slice
+ + self.slice_end_token
+ )
+ row_chunks = [slice_chunk * grid_x for _ in range(grid_y)]
+ out += "\n".join(row_chunks)
+ return out
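+
+    # e.g. a 28x40 source grid in "16x" mode (illustrative sizes):
+    # pad_divisor=16 -> 1120 // 16 = 70 inner pads between <image>/</image>,
+    # one per final visual embedding.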
+
+ def _expand_media(
+ self,
+ index: int,
+ frames: Sequence[Tuple[List[List[int]], List[int]]],
+ ) -> str:
+ """One image or one video. Image is a single-frame video."""
+ body = "".join(self._expand_frame(ts, grid) for ts, grid in frames)
+ return f"{self.image_id_start_token}{index}{self.image_id_end_token}" + body
+
+ async def process_mm_data_async(
+ self,
+ image_data: Sequence[Union[str, bytes]],
+ audio_data: Sequence[Union[str, bytes]],
+ input_text,
+ request_obj,
+ **kwargs: Any,
+ ):
+ # ``TokenizerManager`` does not pass ``video_data`` through the
+ # processor signature; read it off the request the way qwen_vl does.
+ video_data = getattr(request_obj, "video_data", None) or kwargs.get(
+ "video_data"
+ )
+ base = self.load_mm_data(
+ prompt=input_text,
+ audio_data=audio_data,
+ image_data=image_data,
+ video_data=video_data,
+ multimodal_tokens=self.mm_tokens,
+ )
+ if base is None:
+ return None
+
+ prompt: str = base.input_text or ""
+ images = base.images or []
+ videos = base.videos or []
+
+ # Image: one "frame" per image. Video: per-frame nesting kept so each
+        # frame becomes its own ``<image>...</image>`` block in the expansion.
+ img_per_pv, img_per_ts, img_grids = self._preprocess_images(images)
+ vid_per_pv, vid_per_ts, vid_grids = self._preprocess_videos(videos)
+
+ prompt = self._splice_expansions(
+ prompt,
+ (
+ self._expand_media(i, [(ts, gd)])
+ for i, (ts, gd) in enumerate(zip(img_per_ts, img_grids))
+ ),
+ (
+ self._expand_media(i, list(zip(fts, fgd)))
+ for i, (fts, fgd) in enumerate(zip(vid_per_ts, vid_grids))
+ ),
+ )
+
+ input_ids: List[int] = self.tokenizer.encode(prompt, add_special_tokens=False)
+ input_ids_tensor = torch.tensor(input_ids, dtype=torch.long)
+
+ # Each patch's pad tokens are guaranteed contiguous (the expansion
+        # functions wrap them in ``<image>...</image>`` / ``<slice>...</slice>``
+ # with nothing else in between), so a per-token-id contiguous-run scan
+ # — base's ``get_mm_items_offset`` — gives one (start, end) per patch.
+ mm_items: List[MultimodalDataItem] = []
+ mm_items.extend(
+ self._build_items(
+ input_ids_tensor,
+ self.image_token_id,
+ _flatten_patches(img_per_pv, img_per_ts),
+ Modality.IMAGE,
+ )
+ )
+ # Video: extra ``per-frame -> per-patch`` nesting; pre-flatten one
+ # level so ``_flatten_patches`` sees the same shape as image.
+ vid_pv_flat = [list(chain.from_iterable(v)) for v in vid_per_pv]
+ vid_ts_flat = [list(chain.from_iterable(v)) for v in vid_per_ts]
+ mm_items.extend(
+ self._build_items(
+ input_ids_tensor,
+ self.video_token_id,
+ _flatten_patches(vid_pv_flat, vid_ts_flat),
+ Modality.VIDEO,
+ )
+ )
+
+ return MultimodalProcessorOutput(
+ mm_items=mm_items,
+ input_ids=input_ids,
+ im_token_id=self.image_token_id,
+ im_start_id=self.image_start_id,
+ im_end_id=self.image_end_id,
+ slice_start_id=self.slice_start_id,
+ slice_end_id=self.slice_end_id,
+ )
+
+ def _preprocess_images(self, images):
+ if not images:
+ return [], [], []
+ out = self.image_processor.preprocess(images)
+ return out["pixel_values"], out["tgt_sizes"], out["grids"]
+
+ def _preprocess_videos(self, videos):
+ per_video_pv: List[List[List[torch.Tensor]]] = []
+ per_video_ts: List[List[List[List[int]]]] = []
+ per_video_grids: List[List[List[int]]] = []
+ for frames in videos:
+ out = self.image_processor.preprocess(list(frames))
+ per_video_pv.append(out["pixel_values"])
+ per_video_ts.append(out["tgt_sizes"])
+ per_video_grids.append(out["grids"])
+ return per_video_pv, per_video_ts, per_video_grids
+
+ def _splice_expansions(self, prompt, image_expansions, video_expansions):
+        # The chat template emits exactly one marker per media item; a
+        # sequential ``replace(token, expansion, 1)`` walk lines them up in
+        # left-to-right order. Expansions carry ``_PAD_PLACEHOLDER`` for inner
+        # pads so the next replace doesn't trip on a previous expansion's pads;
+        # placeholders are swapped back to the real pad token in one pass per
+        # modality.
+ for token, expansions in (
+ (self.image_token, image_expansions),
+ (self.video_token, video_expansions),
+ ):
+ for expansion in expansions:
+ if token not in prompt:
+ break
+ prompt = prompt.replace(token, expansion, 1)
+ prompt = prompt.replace(_PAD_PLACEHOLDER, token)
+ return prompt
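+
+    # Sketch of one image splice (70 pads is an illustrative count):
+    #   "...<|image_pad|>..." -> "...<image_id>0</image_id><image>" +
+    #   "<|placeholder|>" * 70 + "</image>...", then every <|placeholder|>
+    #   flips back to <|image_pad|> in the final pass.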
+
+ def _build_items(
+ self,
+ input_ids: torch.Tensor,
+ pad_token_id: int,
+ flat: Tuple[List[torch.Tensor], List[torch.Tensor]],
+ modality: Modality,
+ ) -> List[MultimodalDataItem]:
+ flat_pv, flat_ts = flat
+ runs = self.get_mm_items_offset(input_ids, pad_token_id)
+ if len(runs) != len(flat_pv):
+ raise RuntimeError(
+ f"[minicpmv4_6] {modality} pad run / feature count mismatch: "
+ f"{len(runs)} runs vs {len(flat_pv)} patches"
+ )
+ return [
+ MultimodalDataItem(
+ feature=[pv],
+ offsets=[run],
+ model_specific_data={"tgt_size": [ts]},
+ modality=modality,
+ )
+ for run, pv, ts in zip(runs, flat_pv, flat_ts)
+ ]
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 5da9e433740a..6ba3d4507f88 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -2280,6 +2280,16 @@ def _handle_model_specific_adjustments(self):
sm100_default_attention_backend=sm100_default_attn_backend,
)
+ elif model_arch == "MiniCPMV4_6ForConditionalGeneration":
+ # 4.6 wraps a Qwen3.5 hybrid GDN backbone, so it needs the same
+ # mamba radix cache handling as Qwen3_5ForConditionalGeneration.
+ self._handle_mamba_radix_cache(
+ model_arch=model_arch,
+ support_mamba_cache=True,
+ support_mamba_cache_extra_buffer=True,
+ sm100_default_attention_backend="triton",
+ )
+
elif model_arch in ["Glm4MoeForCausalLM"]:
if is_sm100_supported():
quantization_config = getattr(hf_config, "quantization_config", None)
diff --git a/python/sglang/srt/utils/hf_transformers/common.py b/python/sglang/srt/utils/hf_transformers/common.py
index cd8729798d21..88f77dbcbb04 100644
--- a/python/sglang/srt/utils/hf_transformers/common.py
+++ b/python/sglang/srt/utils/hf_transformers/common.py
@@ -39,6 +39,8 @@
KimiVLConfig,
LagunaConfig,
LongcatFlashConfig,
+ MiniCPMV4_6Config,
+ MiniCPMV4_6VisionConfig,
MultiModalityConfig,
NemotronH_Nano_Omni_Reasoning_V3_Config,
NemotronH_Nano_VL_V2_Config,
@@ -100,6 +102,8 @@
JetVLMConfig,
KimiK25Config,
Step3p5Config,
+ MiniCPMV4_6Config,
+ MiniCPMV4_6VisionConfig,
]
}