2 changes: 1 addition & 1 deletion vllm/model_executor/models/kimi_vl.py
@@ -325,7 +325,7 @@ def __init__(
self.hidden_size = config.text_config.hidden_size
self.vision_tower = MoonVitPretrainedModel(
config.vision_config,
self.use_data_parallel,
multimodal_config=model_config.multimodal_config,
prefix=maybe_prefix(prefix, "vision_tower"),
)

228 changes: 71 additions & 157 deletions vllm/model_executor/models/moonvit.py
@@ -51,118 +51,20 @@
import torch.nn.functional as F
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import is_flash_attn_2_available

from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import MultiModalConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.models.utils import maybe_prefix
from vllm.platforms import current_platform
from vllm.transformers_utils.configs.moonvit import MoonViTConfig

if is_flash_attn_2_available():
from flash_attn import flash_attn_varlen_func
elif current_platform.is_xpu():
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
else:
flash_attn_varlen_func = None


def multihead_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
q_cu_seqlens: torch.Tensor | None = None,
k_cu_seqlens: torch.Tensor | None = None,
) -> torch.Tensor:
"""Multi-head attention using flash attention 2.

Args:
q: Query tensor of shape (batch_size, seqlen, num_heads, head_dim),
or (tot_seqlens, num_heads, head_dim) if packing.
k: Key tensor of shape (batch_size, seqlen, num_heads, head_dim),
or (tot_seqlens, num_heads, head_dim) if packing.
v: Value tensor of shape (batch_size, seqlen, num_heads, head_dim),
or (tot_seqlens, num_heads, head_dim) if packing.
q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q.
The first element should be 0 and the last element should be q.shape[0].
k_cu_seqlens (torch.Tensor): cumulative sequence lengths of k.
The first element should be 0 and the last element should be k.shape[0].

Returns:
output: shape (batch_size, seqlen, dim) or (tot_seqlens, dim) if packing,
where dim = num_heads * head_dim
"""
# Unified format legal check
assert q.dim() == k.dim() == v.dim() == 3, "q, k, v must have 3 dims"
assert q_cu_seqlens[-1] == q.shape[0], "q_cu_seqlens must sum to q.shape[0]"
assert k_cu_seqlens[-1] == k.shape[0] == v.shape[0], (
"k_cu_seqlens must sum to k.shape[0]"
)
assert q.dtype in [
torch.bfloat16,
torch.float16,
], f"unsupported dtype {q.dtype} for multihead attn"

max_seqlen_q = (q_cu_seqlens[1:] - q_cu_seqlens[:-1]).max().item()
max_seqlen_k = (k_cu_seqlens[1:] - k_cu_seqlens[:-1]).max().item()
attn_out = flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q=q_cu_seqlens,
cu_seqlens_k=k_cu_seqlens,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
causal=False,
)
attn_out = attn_out.flatten(start_dim=-2)

return attn_out


def sdpa_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
q_cu_seqlens: torch.Tensor | None = None,
k_cu_seqlens: torch.Tensor | None = None,
) -> torch.Tensor:
"""SDPA attention.

Args:
q: Query tensor of shape (batch_size, seqlen, num_heads, head_dim),
or (tot_seqlens, num_heads, head_dim) if packing.
k: Key tensor of shape (batch_size, seqlen, num_heads, head_dim),
or (tot_seqlens, num_heads, head_dim) if packing.
v: Value tensor of shape (batch_size, seqlen, num_heads, head_dim),
or (tot_seqlens, num_heads, head_dim) if packing.
q_cu_seqlens: Optional cumulative sequence lengths of q.
k_cu_seqlens: Optional cumulative sequence lengths of k.
"""
seq_length = q.shape[0]
attention_mask = torch.zeros(
[1, seq_length, seq_length], device=q.device, dtype=torch.bool
)
for i in range(1, len(q_cu_seqlens)):
attention_mask[
...,
q_cu_seqlens[i - 1] : q_cu_seqlens[i],
q_cu_seqlens[i - 1] : q_cu_seqlens[i],
] = True
q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
attn_output = attn_output.transpose(0, 1)
attn_output = attn_output.reshape(seq_length, -1)
return attn_output


VL_VISION_ATTENTION_FUNCTIONS = {
"flash_attention_2": multihead_attention,
"sdpa": sdpa_attention,
}


def _apply_rope_input_validation(x, freqs_cis):
assert x.ndim == freqs_cis.ndim + 1, (x.shape, freqs_cis.shape)
@@ -411,11 +313,19 @@ def __init__(
super().__init__()
assert len(dims) == 3
self.use_data_parallel = use_data_parallel
self.fc0 = ReplicatedLinear(
dims[0], dims[1], bias=bias, prefix=maybe_prefix(prefix, "fc0")
self.fc0 = ColumnParallelLinear(
dims[0],
dims[1],
bias=bias,
prefix=maybe_prefix(prefix, "fc0"),
disable_tp=self.use_data_parallel,
)
self.fc1 = ReplicatedLinear(
dims[1], dims[2], bias=bias, prefix=maybe_prefix(prefix, "fc1")
self.fc1 = RowParallelLinear(
dims[1],
dims[2],
bias=bias,
prefix=maybe_prefix(prefix, "fc1"),
disable_tp=self.use_data_parallel,
)
self.activation = activation

@@ -433,35 +343,55 @@ def __init__(
hidden_dim: int,
mlp_dim: int,
prefix: str = "",
use_data_parallel: bool = False,
multimodal_config: MultiModalConfig | None = None,
*,
attn_implementation: str = "sdpa",
activation=F.gelu,
attn_bias: bool = False,
):
super().__init__()
self.use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)

self.num_heads = num_heads
self.hidden_dim = hidden_dim
self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
self.attn_implementation = attn_implementation
# use fa2 in vllm by default
if is_flash_attn_2_available() or current_platform.is_xpu():
self.attn_implementation = "flash_attention_2"
self.tp_size = (
1 if self.use_data_parallel else get_tensor_model_parallel_world_size()
)
self.num_attention_heads_per_partition = divide(num_heads, self.tp_size)

self.norm0 = nn.LayerNorm(hidden_dim)
self.norm1 = nn.LayerNorm(hidden_dim)
self.use_data_parallel = use_data_parallel
self.mlp = MLP2(
[hidden_dim, mlp_dim, hidden_dim],
activation,
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
use_data_parallel=self.use_data_parallel,
)
self.wqkv = ReplicatedLinear(
hidden_dim, hidden_dim * 3, bias=attn_bias, prefix=f"{prefix}.wqkv"
self.wqkv = QKVParallelLinear(
hidden_size=hidden_dim,
head_size=self.hidden_size_per_attention_head,
total_num_heads=num_heads,
total_num_kv_heads=num_heads,
bias=attn_bias,
prefix=f"{prefix}.wqkv",
disable_tp=self.use_data_parallel,
)
self.wo = ReplicatedLinear(
hidden_dim, hidden_dim, bias=attn_bias, prefix=f"{prefix}.wo"
self.wo = RowParallelLinear(
hidden_dim,
hidden_dim,
bias=attn_bias,
prefix=f"{prefix}.wo",
disable_tp=self.use_data_parallel,
)
self.attn = MMEncoderAttention(
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
)

def attention_qkvpacked(
@@ -472,14 +402,15 @@ def attention_qkvpacked(
):
"""
Args:
x (torch.Tensor): (batch_size, seqlen, hidden_dim)
x (torch.Tensor): (seqlen, hidden_dim)
cu_seqlens (torch.Tensor):
"""
seq_length = x.size(0)
xqkv, _ = self.wqkv(x)

qkv_shape = xqkv.size()[:-1] + (
3,
self.num_heads,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)
# xqkv: (batch_size, seqlen, 3, nheads, headdim)
@@ -488,9 +419,18 @@

xq, xk = apply_rope(xq, xk, rope_freqs_cis)

attn_func = VL_VISION_ATTENTION_FUNCTIONS[self.attn_implementation]
attn_out = attn_func(
xq, xk, xv, q_cu_seqlens=cu_seqlens, k_cu_seqlens=cu_seqlens
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
attn_out = self.attn(
xq.unsqueeze(0),
xk.unsqueeze(0),
xv.unsqueeze(0),
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
attn_out = attn_out.reshape(
seq_length,
self.num_attention_heads_per_partition
* self.hidden_size_per_attention_head,
)
attn_out, _ = self.wo(attn_out)
return attn_out
@@ -528,7 +468,7 @@ def __init__(
num_layers: int,
block_cfg: dict,
prefix: str = "",
use_data_parallel: bool = False,
multimodal_config: MultiModalConfig | None = None,
) -> None:
super().__init__()

@@ -538,7 +478,7 @@ def __init__(
self.blocks = nn.ModuleList(
[
MoonVitEncoderLayer(
use_data_parallel=use_data_parallel,
multimodal_config=multimodal_config,
prefix=f"{prefix}.blocks.{layer_idx}",
**block_cfg,
)
@@ -599,31 +539,6 @@ def patch_merger(
return outputs


class MoonVitVLProjector(nn.Module):
def __init__(
self,
in_channels: int,
merge_kernel_size: list[int, int],
hidden_act: str = "gelu",
ln_eps: float = 1e-5,
out_dim: int = 4096,
):
super().__init__()
self.hidden_size = in_channels * merge_kernel_size[0] * merge_kernel_size[1]

self.pre_norm = nn.nn.LayerNorm(in_channels, eps=ln_eps)
self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
self.act = ACT2FN[hidden_act]
self.linear_2 = nn.Linear(self.hidden_size, out_dim, bias=True)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.pre_norm(hidden_states).view(-1, self.hidden_size)
hidden_states = self.linear_1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states


class MoonVitPretrainedModel(PreTrainedModel):
config_class = MoonViTConfig
model_type = "moonvit"
@@ -634,14 +549,13 @@ class MoonVitPretrainedModel(PreTrainedModel):
def __init__(
self,
config: MoonViTConfig,
use_data_parallel: bool = False,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
*inputs,
**kwargs,
):
super().__init__(config, *inputs, **kwargs)
config = deepcopy(config)
self.use_data_parallel = use_data_parallel
self.merge_kernel_size = config.merge_kernel_size
self.hidden_size = config.hidden_size
self.patch_size = config.patch_size
@@ -662,9 +576,9 @@ def __init__(
"mlp_dim": config.intermediate_size,
"activation": ACT2FN["gelu_pytorch_tanh"],
"attn_bias": True,
"attn_implementation": config._attn_implementation,
},
prefix=f"{prefix}.encoder",
multimodal_config=multimodal_config,
)

def forward(
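
For context only (not part of this diff), a minimal sketch of how the packed cu_seqlens / max_seqlen arguments consumed by the refactored attention_qkvpacked path are typically constructed; the per-image patch counts below are made up for illustration.

import torch

# Hypothetical patch counts for three images packed into one sequence.
seqlens = torch.tensor([64, 100, 36], dtype=torch.int32)

# Cumulative sequence lengths: first element 0, last element equals the
# total number of packed tokens, matching the cu_seqlens contract described
# in the removed multihead_attention docstring.
cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
# cu_seqlens -> tensor([0, 64, 164, 200], dtype=torch.int32)

# Longest per-image sequence, computed the same way as in attention_qkvpacked.
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
# max_seqlen -> tensor(100)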