From 45530e680760f90bfe7a670edf450f546cb90c98 Mon Sep 17 00:00:00 2001
From: Isotr0py
Date: Mon, 5 Jan 2026 23:39:21 +0800
Subject: [PATCH 1/7] update mha interface

Signed-off-by: Isotr0py
---
 vllm/model_executor/models/moonvit.py | 125 +++-----------------------
 1 file changed, 11 insertions(+), 114 deletions(-)

diff --git a/vllm/model_executor/models/moonvit.py b/vllm/model_executor/models/moonvit.py
index 63ea6b259a71..2b01358d31d4 100644
--- a/vllm/model_executor/models/moonvit.py
+++ b/vllm/model_executor/models/moonvit.py
@@ -51,118 +51,14 @@
 import torch.nn.functional as F
 from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import is_flash_attn_2_available
 
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.model_executor.layers.conv import Conv2dLayer
 from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.models.utils import maybe_prefix
 from vllm.platforms import current_platform
 from vllm.transformers_utils.configs.moonvit import MoonViTConfig
 
-if is_flash_attn_2_available():
-    from flash_attn import flash_attn_varlen_func
-elif current_platform.is_xpu():
-    from vllm.attention.utils.fa_utils import flash_attn_varlen_func
-else:
-    flash_attn_varlen_func = None
-
-
-def multihead_attention(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
-    q_cu_seqlens: torch.Tensor | None = None,
-    k_cu_seqlens: torch.Tensor | None = None,
-) -> torch.Tensor:
-    """Multi-head attention using flash attention 2.
-
-    Args:
-        q: Query tensor of shape (batch_size, seqlen, num_heads, head_dim),
-            or (tot_seqlens, num_heads, head_dim) if packing.
-        k: Key tensor of shape (batch_size, seqlen, num_heads, head_dim),
-            or (tot_seqlens, num_heads, head_dim) if packing.
-        v: Value tensor of shape (batch_size, seqlen, num_heads, head_dim),
-            or (tot_seqlens, num_heads, head_dim) if packing.
-        q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q.
-            The first element should be 0 and the last element should be q.shape[0].
-        k_cu_seqlens (torch.Tensor): cumulative sequence lengths of k.
-            The first element should be 0 and the last element should be k.shape[0].
-
-    Returns:
-        output: shape (batch_size, seqlen, dim) or (tot_seqlens, dim) if packing,
-            where dim = num_heads * head_dim
-    """
-    # Unified format legal check
-    assert q.dim() == k.dim() == v.dim() == 3, "q, k, v must have 3 dims"
-    assert q_cu_seqlens[-1] == q.shape[0], "q_cu_seqlens must sum to q.shape[0]"
-    assert k_cu_seqlens[-1] == k.shape[0] == v.shape[0], (
-        "k_cu_seqlens must sum to k.shape[0]"
-    )
-    assert q.dtype in [
-        torch.bfloat16,
-        torch.float16,
-    ], f"unsupported dtype {q.dtype} for multihead attn"
-
-    max_seqlen_q = (q_cu_seqlens[1:] - q_cu_seqlens[:-1]).max().item()
-    max_seqlen_k = (k_cu_seqlens[1:] - k_cu_seqlens[:-1]).max().item()
-    attn_out = flash_attn_varlen_func(
-        q,
-        k,
-        v,
-        cu_seqlens_q=q_cu_seqlens,
-        cu_seqlens_k=k_cu_seqlens,
-        max_seqlen_q=max_seqlen_q,
-        max_seqlen_k=max_seqlen_k,
-        causal=False,
-    )
-    attn_out = attn_out.flatten(start_dim=-2)
-
-    return attn_out
-
-
-def sdpa_attention(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
-    q_cu_seqlens: torch.Tensor | None = None,
-    k_cu_seqlens: torch.Tensor | None = None,
-) -> torch.Tensor:
-    """SDPA attention.
-
-    Args:
-        q: Query tensor of shape (batch_size, seqlen, num_heads, head_dim),
-            or (tot_seqlens, num_heads, head_dim) if packing.
-        k: Key tensor of shape (batch_size, seqlen, num_heads, head_dim),
-            or (tot_seqlens, num_heads, head_dim) if packing.
-        v: Value tensor of shape (batch_size, seqlen, num_heads, head_dim),
-            or (tot_seqlens, num_heads, head_dim) if packing.
-        q_cu_seqlens: Optional cumulative sequence lengths of q.
-        k_cu_seqlens: Optional cumulative sequence lengths of k.
-    """
-    seq_length = q.shape[0]
-    attention_mask = torch.zeros(
-        [1, seq_length, seq_length], device=q.device, dtype=torch.bool
-    )
-    for i in range(1, len(q_cu_seqlens)):
-        attention_mask[
-            ...,
-            q_cu_seqlens[i - 1] : q_cu_seqlens[i],
-            q_cu_seqlens[i - 1] : q_cu_seqlens[i],
-        ] = True
-    q = q.transpose(0, 1)
-    k = k.transpose(0, 1)
-    v = v.transpose(0, 1)
-    attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
-    attn_output = attn_output.transpose(0, 1)
-    attn_output = attn_output.reshape(seq_length, -1)
-    return attn_output
-
-
-VL_VISION_ATTENTION_FUNCTIONS = {
-    "flash_attention_2": multihead_attention,
-    "sdpa": sdpa_attention,
-}
-
 
 def _apply_rope_input_validation(x, freqs_cis):
     assert x.ndim == freqs_cis.ndim + 1, (x.shape, freqs_cis.shape)
@@ -435,7 +331,6 @@ def __init__(
         prefix: str = "",
         use_data_parallel: bool = False,
         *,
-        attn_implementation: str = "sdpa",
         activation=F.gelu,
         attn_bias: bool = False,
     ):
@@ -443,10 +338,6 @@ def __init__(
         self.num_heads = num_heads
         self.hidden_dim = hidden_dim
         self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
-        self.attn_implementation = attn_implementation
-        # use fa2 in vllm by default
-        if is_flash_attn_2_available() or current_platform.is_xpu():
-            self.attn_implementation = "flash_attention_2"
 
         self.norm0 = nn.LayerNorm(hidden_dim)
         self.norm1 = nn.LayerNorm(hidden_dim)
@@ -463,6 +354,11 @@ def __init__(
         self.wo = ReplicatedLinear(
             hidden_dim, hidden_dim, bias=attn_bias, prefix=f"{prefix}.wo"
         )
+        self.attn = MMEncoderAttention(
+            num_heads=self.num_heads,
+            head_size=self.hidden_size_per_attention_head,
+            prefix=f"{prefix}.attn",
+        )
 
     def attention_qkvpacked(
         self,
@@ -488,9 +384,11 @@ def attention_qkvpacked(
 
         xq, xk = apply_rope(xq, xk, rope_freqs_cis)
 
-        attn_func = VL_VISION_ATTENTION_FUNCTIONS[self.attn_implementation]
-        attn_out = attn_func(
-            xq, xk, xv, q_cu_seqlens=cu_seqlens, k_cu_seqlens=cu_seqlens
+        attn_out = self.attn(
+            xq,
+            xk,
+            xv,
+            cu_seqlens=cu_seqlens,
         )
         attn_out, _ = self.wo(attn_out)
         return attn_out
@@ -662,7 +560,6 @@ def __init__(
                 "mlp_dim": config.intermediate_size,
                 "activation": ACT2FN["gelu_pytorch_tanh"],
                 "attn_bias": True,
-                "attn_implementation": config._attn_implementation,
             },
             prefix=f"{prefix}.encoder",
         )
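Note (PATCH 1/7): this patch deletes the model-local attention helpers (multihead_attention over flash_attn_varlen_func, sdpa_attention with an explicit block-diagonal mask, and the VL_VISION_ATTENTION_FUNCTIONS dispatch) and routes attention through the shared MMEncoderAttention layer, which selects a backend internally. The contract being delegated is the one the removed helpers implemented: non-causal attention over variable-length sequences packed along dim 0, with cu_seqlens marking the boundaries. A minimal standalone sketch of that contract in plain PyTorch (hypothetical helper name, not the vLLM API):

    import torch
    import torch.nn.functional as F

    def packed_noncausal_attention(q, k, v, cu_seqlens):
        """q, k, v: (total_seqlen, num_heads, head_dim); cu_seqlens: (num_seqs + 1,)."""
        total = q.shape[0]
        # Block-diagonal mask: each token attends only within its own sequence,
        # mirroring the mask the removed sdpa_attention helper built.
        mask = torch.zeros(total, total, dtype=torch.bool, device=q.device)
        for i in range(1, len(cu_seqlens)):
            s, e = cu_seqlens[i - 1], cu_seqlens[i]
            mask[s:e, s:e] = True
        q, k, v = (t.transpose(0, 1) for t in (q, k, v))  # (heads, total, head_dim)
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        return out.transpose(0, 1).reshape(total, -1)  # (total, heads * head_dim)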
From 919e0d459931331e3a5ae069b2633811af6d7626 Mon Sep 17 00:00:00 2001
From: Isotr0py
Date: Tue, 6 Jan 2026 00:14:44 +0800
Subject: [PATCH 2/7] fix

Signed-off-by: Isotr0py
---
 vllm/model_executor/models/moonvit.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/models/moonvit.py b/vllm/model_executor/models/moonvit.py
index 2b01358d31d4..647ede94be70 100644
--- a/vllm/model_executor/models/moonvit.py
+++ b/vllm/model_executor/models/moonvit.py
@@ -368,9 +368,10 @@ def attention_qkvpacked(
     ):
         """
         Args:
-            x (torch.Tensor): (batch_size, seqlen, hidden_dim)
+            x (torch.Tensor): (seqlen, hidden_dim)
             cu_seqlens (torch.Tensor):
         """
+        seq_length = x.size(0)
         xqkv, _ = self.wqkv(x)
 
         qkv_shape = xqkv.size()[:-1] + (
@@ -384,11 +385,16 @@ def attention_qkvpacked(
 
         xq, xk = apply_rope(xq, xk, rope_freqs_cis)
 
+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
         attn_out = self.attn(
-            xq,
-            xk,
-            xv,
+            xq.unsqueeze(0),
+            xk.unsqueeze(0),
+            xv.unsqueeze(0),
             cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
+        )
+        attn_out = attn_out.reshape(
+            seq_length, self.num_heads * self.hidden_size_per_attention_head
        )
         attn_out, _ = self.wo(attn_out)
         return attn_out
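Note (PATCH 2/7): the docstring fix records that the input is already packed as (seqlen, hidden_dim), and the call site is adapted to the layout the new layer expects: inputs gain a leading batch dim of 1, max_seqlen is derived from cu_seqlens, and the output is flattened back to (seqlen, num_heads * head_dim) before the output projection self.wo. A small shape walk-through under those assumptions (standalone PyTorch, not the vLLM API):

    import torch

    num_heads, head_dim = 4, 16
    cu_seqlens = torch.tensor([0, 3, 8])        # two packed sequences: lengths 3 and 5
    total = int(cu_seqlens[-1])                 # 8 tokens in the packed batch
    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()  # tensor(5)

    xq = torch.randn(total, num_heads, head_dim)  # (seqlen, heads, head_dim)
    batched = xq.unsqueeze(0)                     # (1, seqlen, heads, head_dim)
    flat = batched.squeeze(0).reshape(total, num_heads * head_dim)  # (seqlen, hidden)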
From e5c667dc0ca1603f0ba213edcfa5154223100683 Mon Sep 17 00:00:00 2001
From: Isotr0py
Date: Tue, 6 Jan 2026 00:16:15 +0800
Subject: [PATCH 3/7] clean

Signed-off-by: Isotr0py
---
 vllm/model_executor/models/moonvit.py | 25 -------------------------
 1 file changed, 25 deletions(-)

diff --git a/vllm/model_executor/models/moonvit.py b/vllm/model_executor/models/moonvit.py
index 647ede94be70..2dafa3410e50 100644
--- a/vllm/model_executor/models/moonvit.py
+++ b/vllm/model_executor/models/moonvit.py
@@ -503,31 +503,6 @@ def patch_merger(
     return outputs
 
 
-class MoonVitVLProjector(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        merge_kernel_size: list[int, int],
-        hidden_act: str = "gelu",
-        ln_eps: float = 1e-5,
-        out_dim: int = 4096,
-    ):
-        super().__init__()
-        self.hidden_size = in_channels * merge_kernel_size[0] * merge_kernel_size[1]
-
-        self.pre_norm = nn.nn.LayerNorm(in_channels, eps=ln_eps)
-        self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
-        self.act = ACT2FN[hidden_act]
-        self.linear_2 = nn.Linear(self.hidden_size, out_dim, bias=True)
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        hidden_states = self.pre_norm(hidden_states).view(-1, self.hidden_size)
-        hidden_states = self.linear_1(hidden_states)
-        hidden_states = self.act(hidden_states)
-        hidden_states = self.linear_2(hidden_states)
-        return hidden_states
-
-
 class MoonVitPretrainedModel(PreTrainedModel):
     config_class = MoonViTConfig
     model_type = "moonvit"
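Note (PATCH 3/7): MoonVitVLProjector appears to have had no call sites in this file, and its pre_norm was constructed as nn.nn.LayerNorm, which would raise AttributeError on instantiation (torch.nn has no attribute nn), so deleting the class is behavior-preserving dead-code removal.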
From 8d7e248289dbe570f679fd795ebdcc6d4ab85cc3 Mon Sep 17 00:00:00 2001
From: Isotr0py
Date: Tue, 6 Jan 2026 00:19:41 +0800
Subject: [PATCH 4/7] update dp

Signed-off-by: Isotr0py
---
 vllm/model_executor/models/moonvit.py | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/vllm/model_executor/models/moonvit.py b/vllm/model_executor/models/moonvit.py
index 2dafa3410e50..6b79677402ea 100644
--- a/vllm/model_executor/models/moonvit.py
+++ b/vllm/model_executor/models/moonvit.py
@@ -53,6 +53,7 @@
 from transformers.modeling_utils import PreTrainedModel
 
 from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
+from vllm.config import MultiModalConfig
 from vllm.model_executor.layers.conv import Conv2dLayer
 from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.models.utils import maybe_prefix
@@ -329,7 +330,7 @@ def __init__(
         hidden_dim: int,
         mlp_dim: int,
         prefix: str = "",
-        use_data_parallel: bool = False,
+        multimodal_config: MultiModalConfig | None = None,
         *,
         activation=F.gelu,
         attn_bias: bool = False,
@@ -341,12 +342,16 @@ def __init__(
 
         self.norm0 = nn.LayerNorm(hidden_dim)
         self.norm1 = nn.LayerNorm(hidden_dim)
-        self.use_data_parallel = use_data_parallel
+        self.use_data_parallel = (
+            multimodal_config.mm_encoder_tp_mode == "data"
+            if multimodal_config
+            else False
+        )
         self.mlp = MLP2(
             [hidden_dim, mlp_dim, hidden_dim],
             activation,
             prefix=f"{prefix}.mlp",
-            use_data_parallel=use_data_parallel,
+            use_data_parallel=self.use_data_parallel,
         )
         self.wqkv = ReplicatedLinear(
             hidden_dim, hidden_dim * 3, bias=attn_bias, prefix=f"{prefix}.wqkv"
@@ -437,7 +442,7 @@ def __init__(
         num_layers: int,
         block_cfg: dict,
         prefix: str = "",
-        use_data_parallel: bool = False,
+        multimodal_config: MultiModalConfig | None = None,
     ) -> None:
         super().__init__()
 
@@ -447,7 +452,7 @@ def __init__(
         self.blocks = nn.ModuleList(
             [
                 MoonVitEncoderLayer(
-                    use_data_parallel=use_data_parallel,
+                    multimodal_config=multimodal_config,
                     prefix=f"{prefix}.blocks.{layer_idx}",
                     **block_cfg,
                 )
@@ -513,14 +518,13 @@ class MoonVitPretrainedModel(PreTrainedModel):
     def __init__(
         self,
         config: MoonViTConfig,
-        use_data_parallel: bool = False,
+        multimodal_config: MultiModalConfig | None = None,
         prefix: str = "",
         *inputs,
         **kwargs,
     ):
         super().__init__(config, *inputs, **kwargs)
         config = deepcopy(config)
-        self.use_data_parallel = use_data_parallel
         self.merge_kernel_size = config.merge_kernel_size
         self.hidden_size = config.hidden_size
         self.patch_size = config.patch_size
@@ -543,6 +547,7 @@ def __init__(
                 "attn_bias": True,
             },
             prefix=f"{prefix}.encoder",
+            multimodal_config=multimodal_config,
         )
 
     def forward(
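Note (PATCH 4/7): instead of threading a bare use_data_parallel flag through every constructor, the encoder now receives the whole MultiModalConfig and derives the flag from mm_encoder_tp_mode. A hedged sketch of that derivation; the stub type below is a stand-in for vllm.config.MultiModalConfig, and the "weights" literal is an assumption (only "data" appears in the diff):

    from dataclasses import dataclass
    from typing import Literal, Optional

    @dataclass
    class MultiModalConfigStub:  # stand-in, not the real vllm.config.MultiModalConfig
        mm_encoder_tp_mode: Literal["weights", "data"] = "weights"

    def resolve_use_data_parallel(mm_cfg: Optional[MultiModalConfigStub]) -> bool:
        # Mirrors the expression added by the patch: DP only when explicitly requested.
        return mm_cfg.mm_encoder_tp_mode == "data" if mm_cfg else False

    assert resolve_use_data_parallel(None) is False
    assert resolve_use_data_parallel(MultiModalConfigStub("data")) is True

Keeping the config object also lets the next patch consult it again when sizing tensor-parallel shards, without another round of signature changes.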
From 2c1c1e4728f1a84f9f2c53a6e9c89686027ab7f4 Mon Sep 17 00:00:00 2001
From: Isotr0py
Date: Tue, 6 Jan 2026 00:42:24 +0800
Subject: [PATCH 5/7] update

Signed-off-by: Isotr0py
---
 vllm/model_executor/models/kimi_vl.py |  2 +-
 vllm/model_executor/models/moonvit.py | 65 ++++++++++++++++++++-------
 2 files changed, 49 insertions(+), 18 deletions(-)

diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py
index 85267ccda8a9..dcae9ccdebbe 100644
--- a/vllm/model_executor/models/kimi_vl.py
+++ b/vllm/model_executor/models/kimi_vl.py
@@ -325,7 +325,7 @@ def __init__(
         self.hidden_size = config.text_config.hidden_size
         self.vision_tower = MoonVitPretrainedModel(
             config.vision_config,
-            self.use_data_parallel,
+            multimodal_config=model_config.multimodal_config,
             prefix=maybe_prefix(prefix, "vision_tower"),
         )
 
diff --git a/vllm/model_executor/models/moonvit.py b/vllm/model_executor/models/moonvit.py
index 6b79677402ea..99200068c066 100644
--- a/vllm/model_executor/models/moonvit.py
+++ b/vllm/model_executor/models/moonvit.py
@@ -54,8 +54,13 @@ from transformers.modeling_utils import PreTrainedModel
 
 from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.config import MultiModalConfig
+from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.conv import Conv2dLayer
-from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.models.utils import maybe_prefix
 from vllm.platforms import current_platform
 from vllm.transformers_utils.configs.moonvit import MoonViTConfig
@@ -308,11 +313,19 @@ def __init__(
         super().__init__()
         assert len(dims) == 3
         self.use_data_parallel = use_data_parallel
-        self.fc0 = ReplicatedLinear(
-            dims[0], dims[1], bias=bias, prefix=maybe_prefix(prefix, "fc0")
+        self.fc0 = ColumnParallelLinear(
+            dims[0],
+            dims[1],
+            bias=bias,
+            prefix=maybe_prefix(prefix, "fc0"),
+            disable_tp=self.use_data_parallel,
         )
-        self.fc1 = ReplicatedLinear(
-            dims[1], dims[2], bias=bias, prefix=maybe_prefix(prefix, "fc1")
+        self.fc1 = RowParallelLinear(
+            dims[1],
+            dims[2],
+            bias=bias,
+            prefix=maybe_prefix(prefix, "fc1"),
+            disable_tp=self.use_data_parallel,
         )
         self.activation = activation
 
@@ -336,32 +349,48 @@ def __init__(
         attn_bias: bool = False,
     ):
         super().__init__()
+        self.use_data_parallel = (
+            multimodal_config.mm_encoder_tp_mode == "data"
+            if multimodal_config
+            else False
+        )
+
         self.num_heads = num_heads
         self.hidden_dim = hidden_dim
         self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
+        self.tp_size = (
+            1 if self.use_data_parallel else get_tensor_model_parallel_world_size()
+        )
+        self.num_attention_heads_per_partition = divide(num_heads, self.tp_size)
 
         self.norm0 = nn.LayerNorm(hidden_dim)
         self.norm1 = nn.LayerNorm(hidden_dim)
-        self.use_data_parallel = (
-            multimodal_config.mm_encoder_tp_mode == "data"
-            if multimodal_config
-            else False
-        )
         self.mlp = MLP2(
             [hidden_dim, mlp_dim, hidden_dim],
             activation,
             prefix=f"{prefix}.mlp",
             use_data_parallel=self.use_data_parallel,
         )
-        self.wqkv = ReplicatedLinear(
-            hidden_dim, hidden_dim * 3, bias=attn_bias, prefix=f"{prefix}.wqkv"
+        self.wqkv = QKVParallelLinear(
+            hidden_size=hidden_dim,
+            head_size=self.hidden_size_per_attention_head,
+            total_num_heads=num_heads,
+            total_num_kv_heads=num_heads,
+            bias=attn_bias,
+            prefix=f"{prefix}.wqkv",
+            disable_tp=self.use_data_parallel,
         )
-        self.wo = ReplicatedLinear(
-            hidden_dim, hidden_dim, bias=attn_bias, prefix=f"{prefix}.wo"
+        self.wo = RowParallelLinear(
+            hidden_dim,
+            hidden_dim,
+            bias=attn_bias,
+            prefix=f"{prefix}.wo",
+            disable_tp=self.use_data_parallel,
         )
         self.attn = MMEncoderAttention(
-            num_heads=self.num_heads,
+            num_heads=self.num_attention_heads_per_partition,
             head_size=self.hidden_size_per_attention_head,
+            multimodal_config=multimodal_config,
             prefix=f"{prefix}.attn",
         )
 
@@ -381,7 +410,7 @@ def attention_qkvpacked(
         qkv_shape = xqkv.size()[:-1] + (
             3,
-            self.num_heads,
+            self.num_attention_heads_per_partition,
             self.hidden_size_per_attention_head,
         )
         # xqkv: (batch_size, seqlen, 3, nheads, headdim)
@@ -399,7 +428,9 @@ def attention_qkvpacked(
             max_seqlen=max_seqlen,
         )
         attn_out = attn_out.reshape(
-            seq_length, self.num_heads * self.hidden_size_per_attention_head
+            seq_length,
+            self.num_attention_heads_per_partition
+            * self.hidden_size_per_attention_head,
         )
         attn_out, _ = self.wo(attn_out)
         return attn_out

From c1bfea4e164fb2243a84cfdbfe4fec3dcd0dcc27 Mon Sep 17 00:00:00 2001
From: Roger Wang
Date: Mon, 5 Jan 2026 10:42:05 -0800
Subject: [PATCH 6/7] Update vllm/model_executor/models/moonvit.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Roger Wang
---
 vllm/model_executor/models/moonvit.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/moonvit.py b/vllm/model_executor/models/moonvit.py
index 99200068c066..b40178c4565d 100644
--- a/vllm/model_executor/models/moonvit.py
+++ b/vllm/model_executor/models/moonvit.py
@@ -419,7 +419,7 @@ def attention_qkvpacked(
 
         xq, xk = apply_rope(xq, xk, rope_freqs_cis)
 
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         attn_out = self.attn(
             xq.unsqueeze(0),
             xk.unsqueeze(0),

From 72dd4d1d1deff15ebd831072b4ad86c1a4e49dbb Mon Sep 17 00:00:00 2001
From: h100
Date: Mon, 5 Jan 2026 19:06:06 +0000
Subject: [PATCH 7/7] revert

Signed-off-by: h100
---
 vllm/model_executor/models/moonvit.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/moonvit.py b/vllm/model_executor/models/moonvit.py
index b40178c4565d..99200068c066 100644
--- a/vllm/model_executor/models/moonvit.py
+++ b/vllm/model_executor/models/moonvit.py
@@ -419,7 +419,7 @@ def attention_qkvpacked(
 
         xq, xk = apply_rope(xq, xk, rope_freqs_cis)
 
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
         attn_out = self.attn(
             xq.unsqueeze(0),
             xk.unsqueeze(0),
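Note (PATCHES 5-7): patch 5 moves the block from replicated to tensor-parallel linear layers (QKVParallelLinear for the fused QKV projection, RowParallelLinear for the output projection, both bypassed via disable_tp when the encoder runs data-parallel) and accordingly sizes the attention by heads per rank rather than total heads. A small sketch of the partition arithmetic; the local divide below is a stand-in for vllm.distributed.divide:

    def divide(numerator: int, denominator: int) -> int:
        # Stand-in for vllm.distributed.divide: the split must be exact.
        assert numerator % denominator == 0, "num_heads must be divisible by tp_size"
        return numerator // denominator

    num_heads, head_dim = 16, 64
    for tp_size in (1, 2, 4):
        heads_per_rank = divide(num_heads, tp_size)
        # Column-parallel QKV gives each rank its own slice of heads; the
        # row-parallel output projection then reduces across ranks.
        print(f"tp={tp_size}: {heads_per_rank} heads, width {heads_per_rank * head_dim}")

Patches 6 and 7 cancel out: patch 6, applied from a gemini-code-assist suggestion, appended .item() so max_seqlen becomes a Python int, and patch 7 reverts it, leaving max_seqlen a tensor — presumably because MMEncoderAttention accepts a tensor there and .item() would force a device-to-host sync on every call.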