diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py
index 85267ccda8a9..dcae9ccdebbe 100644
--- a/vllm/model_executor/models/kimi_vl.py
+++ b/vllm/model_executor/models/kimi_vl.py
@@ -325,7 +325,7 @@ def __init__(
         self.hidden_size = config.text_config.hidden_size
         self.vision_tower = MoonVitPretrainedModel(
             config.vision_config,
-            self.use_data_parallel,
+            multimodal_config=model_config.multimodal_config,
             prefix=maybe_prefix(prefix, "vision_tower"),
         )

diff --git a/vllm/model_executor/models/moonvit.py b/vllm/model_executor/models/moonvit.py
index 63ea6b259a71..99200068c066 100644
--- a/vllm/model_executor/models/moonvit.py
+++ b/vllm/model_executor/models/moonvit.py
@@ -51,118 +51,20 @@
 import torch.nn.functional as F
 from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import is_flash_attn_2_available

+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
+from vllm.config import MultiModalConfig
+from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.conv import Conv2dLayer
-from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.models.utils import maybe_prefix
 from vllm.platforms import current_platform
 from vllm.transformers_utils.configs.moonvit import MoonViTConfig

-if is_flash_attn_2_available():
-    from flash_attn import flash_attn_varlen_func
-elif current_platform.is_xpu():
-    from vllm.attention.utils.fa_utils import flash_attn_varlen_func
-else:
-    flash_attn_varlen_func = None
-
-
-def multihead_attention(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
-    q_cu_seqlens: torch.Tensor | None = None,
-    k_cu_seqlens: torch.Tensor | None = None,
-) -> torch.Tensor:
-    """Multi-head attention using flash attention 2.
-
-    Args:
-        q: Query tensor of shape (batch_size, seqlen, num_heads, head_dim),
-            or (tot_seqlens, num_heads, head_dim) if packing.
-        k: Key tensor of shape (batch_size, seqlen, num_heads, head_dim),
-            or (tot_seqlens, num_heads, head_dim) if packing.
-        v: Value tensor of shape (batch_size, seqlen, num_heads, head_dim),
-            or (tot_seqlens, num_heads, head_dim) if packing.
-        q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q.
-            The first element should be 0 and the last element should be q.shape[0].
-        k_cu_seqlens (torch.Tensor): cumulative sequence lengths of k.
-            The first element should be 0 and the last element should be k.shape[0].
-
-    Returns:
-        output: shape (batch_size, seqlen, dim) or (tot_seqlens, dim) if packing,
-            where dim = num_heads * head_dim
-    """
-    # Unified format legal check
-    assert q.dim() == k.dim() == v.dim() == 3, "q, k, v must have 3 dims"
-    assert q_cu_seqlens[-1] == q.shape[0], "q_cu_seqlens must sum to q.shape[0]"
-    assert k_cu_seqlens[-1] == k.shape[0] == v.shape[0], (
-        "k_cu_seqlens must sum to k.shape[0]"
-    )
-    assert q.dtype in [
-        torch.bfloat16,
-        torch.float16,
-    ], f"unsupported dtype {q.dtype} for multihead attn"
-
-    max_seqlen_q = (q_cu_seqlens[1:] - q_cu_seqlens[:-1]).max().item()
-    max_seqlen_k = (k_cu_seqlens[1:] - k_cu_seqlens[:-1]).max().item()
-    attn_out = flash_attn_varlen_func(
-        q,
-        k,
-        v,
-        cu_seqlens_q=q_cu_seqlens,
-        cu_seqlens_k=k_cu_seqlens,
-        max_seqlen_q=max_seqlen_q,
-        max_seqlen_k=max_seqlen_k,
-        causal=False,
-    )
-    attn_out = attn_out.flatten(start_dim=-2)
-
-    return attn_out
-
-
-def sdpa_attention(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
-    q_cu_seqlens: torch.Tensor | None = None,
-    k_cu_seqlens: torch.Tensor | None = None,
-) -> torch.Tensor:
-    """SDPA attention.
-
-    Args:
-        q: Query tensor of shape (batch_size, seqlen, num_heads, head_dim),
-            or (tot_seqlens, num_heads, head_dim) if packing.
-        k: Key tensor of shape (batch_size, seqlen, num_heads, head_dim),
-            or (tot_seqlens, num_heads, head_dim) if packing.
-        v: Value tensor of shape (batch_size, seqlen, num_heads, head_dim),
-            or (tot_seqlens, num_heads, head_dim) if packing.
-        q_cu_seqlens: Optional cumulative sequence lengths of q.
-        k_cu_seqlens: Optional cumulative sequence lengths of k.
-    """
-    seq_length = q.shape[0]
-    attention_mask = torch.zeros(
-        [1, seq_length, seq_length], device=q.device, dtype=torch.bool
-    )
-    for i in range(1, len(q_cu_seqlens)):
-        attention_mask[
-            ...,
-            q_cu_seqlens[i - 1] : q_cu_seqlens[i],
-            q_cu_seqlens[i - 1] : q_cu_seqlens[i],
-        ] = True
-    q = q.transpose(0, 1)
-    k = k.transpose(0, 1)
-    v = v.transpose(0, 1)
-    attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
-    attn_output = attn_output.transpose(0, 1)
-    attn_output = attn_output.reshape(seq_length, -1)
-    return attn_output
-
-
-VL_VISION_ATTENTION_FUNCTIONS = {
-    "flash_attention_2": multihead_attention,
-    "sdpa": sdpa_attention,
-}
-

 def _apply_rope_input_validation(x, freqs_cis):
     assert x.ndim == freqs_cis.ndim + 1, (x.shape, freqs_cis.shape)
@@ -411,11 +313,19 @@ def __init__(
         super().__init__()
         assert len(dims) == 3
         self.use_data_parallel = use_data_parallel
-        self.fc0 = ReplicatedLinear(
-            dims[0], dims[1], bias=bias, prefix=maybe_prefix(prefix, "fc0")
+        self.fc0 = ColumnParallelLinear(
+            dims[0],
+            dims[1],
+            bias=bias,
+            prefix=maybe_prefix(prefix, "fc0"),
+            disable_tp=self.use_data_parallel,
         )
-        self.fc1 = ReplicatedLinear(
-            dims[1], dims[2], bias=bias, prefix=maybe_prefix(prefix, "fc1")
+        self.fc1 = RowParallelLinear(
+            dims[1],
+            dims[2],
+            bias=bias,
+            prefix=maybe_prefix(prefix, "fc1"),
+            disable_tp=self.use_data_parallel,
         )
         self.activation = activation

@@ -433,35 +343,55 @@ def __init__(
         hidden_dim: int,
         mlp_dim: int,
         prefix: str = "",
-        use_data_parallel: bool = False,
+        multimodal_config: MultiModalConfig | None = None,
         *,
-        attn_implementation: str = "sdpa",
         activation=F.gelu,
         attn_bias: bool = False,
     ):
         super().__init__()
+        self.use_data_parallel = (
+            multimodal_config.mm_encoder_tp_mode == "data"
+            if multimodal_config
+            else False
+        )
+
         self.num_heads = num_heads
         self.hidden_dim = hidden_dim
         self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
-        self.attn_implementation = attn_implementation
-        # use fa2 in vllm by default
-        if is_flash_attn_2_available() or current_platform.is_xpu():
-            self.attn_implementation = "flash_attention_2"
+        self.tp_size = (
+            1 if self.use_data_parallel else get_tensor_model_parallel_world_size()
+        )
+        self.num_attention_heads_per_partition = divide(num_heads, self.tp_size)

         self.norm0 = nn.LayerNorm(hidden_dim)
         self.norm1 = nn.LayerNorm(hidden_dim)
-        self.use_data_parallel = use_data_parallel
         self.mlp = MLP2(
             [hidden_dim, mlp_dim, hidden_dim],
             activation,
             prefix=f"{prefix}.mlp",
-            use_data_parallel=use_data_parallel,
+            use_data_parallel=self.use_data_parallel,
         )
-        self.wqkv = ReplicatedLinear(
-            hidden_dim, hidden_dim * 3, bias=attn_bias, prefix=f"{prefix}.wqkv"
+        self.wqkv = QKVParallelLinear(
+            hidden_size=hidden_dim,
+            head_size=self.hidden_size_per_attention_head,
+            total_num_heads=num_heads,
+            total_num_kv_heads=num_heads,
+            bias=attn_bias,
+            prefix=f"{prefix}.wqkv",
+            disable_tp=self.use_data_parallel,
         )
-        self.wo = ReplicatedLinear(
-            hidden_dim, hidden_dim, bias=attn_bias, prefix=f"{prefix}.wo"
+        self.wo = RowParallelLinear(
+            hidden_dim,
+            hidden_dim,
+            bias=attn_bias,
+            prefix=f"{prefix}.wo",
+            disable_tp=self.use_data_parallel,
+        )
+        self.attn = MMEncoderAttention(
+            num_heads=self.num_attention_heads_per_partition,
+            head_size=self.hidden_size_per_attention_head,
+            multimodal_config=multimodal_config,
+            prefix=f"{prefix}.attn",
         )

     def attention_qkvpacked(
@@ -472,14 +402,15 @@
     ):
         """
         Args:
-            x (torch.Tensor): (batch_size, seqlen, hidden_dim)
+            x (torch.Tensor): (seqlen, hidden_dim)
             cu_seqlens (torch.Tensor):
         """
+        seq_length = x.size(0)
         xqkv, _ = self.wqkv(x)
         qkv_shape = xqkv.size()[:-1] + (
             3,
-            self.num_heads,
+            self.num_attention_heads_per_partition,
             self.hidden_size_per_attention_head,
         )
         # xqkv: (batch_size, seqlen, 3, nheads, headdim)
@@ -488,9 +419,18 @@
         xq, xk = apply_rope(xq, xk, rope_freqs_cis)

-        attn_func = VL_VISION_ATTENTION_FUNCTIONS[self.attn_implementation]
-        attn_out = attn_func(
-            xq, xk, xv, q_cu_seqlens=cu_seqlens, k_cu_seqlens=cu_seqlens
+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+        attn_out = self.attn(
+            xq.unsqueeze(0),
+            xk.unsqueeze(0),
+            xv.unsqueeze(0),
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
+        )
+        attn_out = attn_out.reshape(
+            seq_length,
+            self.num_attention_heads_per_partition
+            * self.hidden_size_per_attention_head,
         )
         attn_out, _ = self.wo(attn_out)
         return attn_out

@@ -528,7 +468,7 @@ def __init__(
         num_layers: int,
         block_cfg: dict,
         prefix: str = "",
-        use_data_parallel: bool = False,
+        multimodal_config: MultiModalConfig | None = None,
     ) -> None:
         super().__init__()

@@ -538,7 +478,7 @@ def __init__(
         self.blocks = nn.ModuleList(
             [
                 MoonVitEncoderLayer(
-                    use_data_parallel=use_data_parallel,
+                    multimodal_config=multimodal_config,
                     prefix=f"{prefix}.blocks.{layer_idx}",
                     **block_cfg,
                 )
@@ -599,31 +539,6 @@ def patch_merger(
     return outputs


-class MoonVitVLProjector(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        merge_kernel_size: list[int, int],
-        hidden_act: str = "gelu",
-        ln_eps: float = 1e-5,
-        out_dim: int = 4096,
-    ):
-        super().__init__()
-        self.hidden_size = in_channels * merge_kernel_size[0] * merge_kernel_size[1]
-
-        self.pre_norm = nn.nn.LayerNorm(in_channels, eps=ln_eps)
-        self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
-        self.act = ACT2FN[hidden_act]
-        self.linear_2 = nn.Linear(self.hidden_size, out_dim, bias=True)
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        hidden_states = self.pre_norm(hidden_states).view(-1, self.hidden_size)
-        hidden_states = self.linear_1(hidden_states)
-        hidden_states = self.act(hidden_states)
-        hidden_states = self.linear_2(hidden_states)
-        return hidden_states
-
-
 class MoonVitPretrainedModel(PreTrainedModel):
     config_class = MoonViTConfig
     model_type = "moonvit"
@@ -634,14 +549,13 @@ class MoonVitPretrainedModel(PreTrainedModel):
     def __init__(
         self,
         config: MoonViTConfig,
-        use_data_parallel: bool = False,
+        multimodal_config: MultiModalConfig | None = None,
         prefix: str = "",
         *inputs,
         **kwargs,
     ):
         super().__init__(config, *inputs, **kwargs)
         config = deepcopy(config)
-        self.use_data_parallel = use_data_parallel
         self.merge_kernel_size = config.merge_kernel_size
         self.hidden_size = config.hidden_size
         self.patch_size = config.patch_size
@@ -662,9 +576,9 @@ def __init__(
                 "mlp_dim": config.intermediate_size,
                 "activation": ACT2FN["gelu_pytorch_tanh"],
                 "attn_bias": True,
-                "attn_implementation": config._attn_implementation,
             },
             prefix=f"{prefix}.encoder",
+            multimodal_config=multimodal_config,
         )

     def forward(